Training Checkpoints

- Checkpoints are really big: with Adam, optimizer states (2x) + model weights (1x) + gradients (1x) add up to roughly 4x the model size (see the size sketch below)
- They take a long time to save and load
- A full checkpoint might not fit in the memory of a single node
- PyTorch has support for sharded checkpointing (see the sketch after the size estimate)
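As a back-of-the-envelope check of the 4x figure, here is a minimal sketch; the parameter count and dtype are illustrative assumptions, not from the source:

```python
# Rough checkpoint-size estimate; all concrete numbers are illustrative assumptions.
params = 7e9       # e.g. a 7B-parameter model (assumption)
bytes_per = 4      # fp32 (assumption)

weights_gb = params * bytes_per / 1e9   # model weights: ~28 GB
adam_gb = 2 * weights_gb                # Adam's two moment buffers: ~56 GB
grads_gb = weights_gb                   # gradients: ~28 GB

total_gb = weights_gb + adam_gb + grads_gb  # ~112 GB, i.e. 4x the weights
print(f"weights={weights_gb:.0f} GB, total={total_gb:.0f} GB")
```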
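A minimal sketch of sharded checkpointing with PyTorch's distributed checkpoint API (`torch.distributed.checkpoint`, assuming PyTorch >= 2.2); `model`, `optimizer`, and the checkpoint path are placeholders. Each rank writes and reads only its own shard, so no single node ever has to hold the full state:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

CHECKPOINT_DIR = "checkpoints/step_1000"  # assumption: a path on shared storage

def save_sharded(model, optimizer):
    # Sharded state dicts: each rank contributes only its local shards.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=CHECKPOINT_DIR)

def load_sharded(model, optimizer):
    # Load shards in place into the current sharded layout, then apply them.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id=CHECKPOINT_DIR)
    set_state_dict(model, optimizer,
                   model_state_dict=model_sd, optim_state_dict=optim_sd)
```

Because saving and loading are collective operations over the process group, all ranks must call these functions; this also lets resharding happen on load if the world size changed.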