training

  • Checkpoints are really big
    • optimizer states (2x for Adam: first and second moments) + model weights (1x) + gradients (1x) = 4x the model size
  • They take a long time to save and load
  • They might not fit in the memory of a single node
  • PyTorch supports sharded checkpointing (`torch.distributed.checkpoint`), where each rank saves only its own shard of the state
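
The 4x arithmetic above can be sketched with a small helper. This is a hypothetical illustration (the function name, the fp32 assumption, and the 7B example are not from the source): Adam keeps two per-parameter states, so checkpointed training state is roughly model + gradients + two optimizer buffers.

```python
def checkpoint_bytes(n_params: int, bytes_per_param: int = 4) -> dict:
    """Rough checkpoint footprint for a model trained with Adam.

    Adam stores two optimizer states per parameter (first and second
    moments), so total state is model (1x) + gradients (1x) +
    optimizer states (2x) = 4x the raw model size.
    Assumes fp32 (4 bytes per parameter) throughout.
    """
    model = n_params * bytes_per_param
    grads = n_params * bytes_per_param
    optim = 2 * n_params * bytes_per_param  # exp_avg + exp_avg_sq
    return {"model": model, "grads": grads, "optim": optim,
            "total": model + grads + optim}

# A 7B-parameter model in fp32: ~28 GB of weights,
# but ~112 GB of total checkpointed training state.
sizes = checkpoint_bytes(7_000_000_000)
print(sizes["total"] / 1e9)  # → 112.0
```

At that scale, saving the full state from one rank is slow and may exceed a single node's memory, which is the motivation for letting each rank write only its shard.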