This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited
Specifically, it focuses on cases where you saved the model using torch.save(model.state_dict(), "model.pth") (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning
While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs
1. Benchmark utilities
First, let’s define some utility code to track VRAM (GPU memory)
Later, we will also introduce a tool to track the main system RAM (CPU memory)
The purpose of these functions will become clear when we apply them later
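A minimal sketch of such utilities might look as follows (the function names `start_memory_tracking`, `print_memory_usage`, and `cleanup` are just illustrative helpers used throughout this notebook, not part of PyTorch):

```python
import gc
import time
import torch


def start_memory_tracking():
    """Reset PyTorch's peak-memory statistics so we can measure a fresh peak."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("No CUDA GPU found; the VRAM measurements below will not apply.")


def print_memory_usage():
    """Report the peak GPU memory allocated since the last reset, in gigabytes."""
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")


def cleanup():
    """Free unreferenced tensors and reset the peak-memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # give the allocator a moment before re-reading the stats
    torch.cuda.reset_peak_memory_stats()
    print_memory_usage()
```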
2. Model setup
This code section sets up the model itself
Here, we use the “large” GPT-2 model to make things more interesting (you may use the “gpt2-small (124M)” configuration instead to lower the memory requirements and execution time of this notebook)
Now, let’s see the GPU memory functions in action:
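For instance, assuming the GPTModel class from the book's earlier chapters is importable (the import path and the exact configuration keys below are illustrative and should be adjusted to your own setup), the model setup and first memory measurement might look like this:

```python
import torch
from previous_chapters import GPTModel  # the GPTModel class implemented in earlier chapters

BASE_CONFIG = {
    "vocab_size": 50257,     # vocabulary size
    "context_length": 1024,  # context length
    "drop_rate": 0.0,        # dropout rate
    "qkv_bias": True,        # query-key-value bias
    "emb_dim": 1280,         # GPT-2 "large" (774M) embedding dimension
    "n_layers": 36,          # number of transformer blocks
    "n_heads": 20,           # number of attention heads
}

start_memory_tracking()

model = GPTModel(BASE_CONFIG)

device = torch.device("cuda")
model.to(device)

print_memory_usage()
```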
Additionally, let’s make sure that the model runs okay by passing in an example tensor
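For example, a quick sanity check could look like this (the token IDs are arbitrary):

```python
# Run a quick forward pass with a small batch of dummy token IDs
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)
```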
Next, imagine we were pretraining the model and saving it for later use
We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)
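Since we skip the training loop, saving boils down to a single call (the file name model.pth is just the example used throughout this notebook):

```python
# In a real workflow, the pretraining loop would run before this line;
# here we simply save the (randomly initialized) weights as our "checkpoint"
torch.save(model.state_dict(), "model.pth")
```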
Lastly, we delete the model and example tensor in the Python session to reset the GPU memory
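Using the cleanup utility sketched in section 1, this reset might look as follows:

```python
# Remove the Python references and free the cached GPU memory
del model, test_input
cleanup()
```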
3. Weight loading
Now begins the interesting part where we load the pretrained model weights
Let’s see how much GPU memory is required to load the previously saved model
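Reusing the configuration and utilities from above, the straightforward loading approach is roughly:

```python
start_memory_tracking()

# Re-create the model architecture and move it to the GPU (first copy of the weights)
model = GPTModel(BASE_CONFIG)
model.to(device)

# Load the saved weights onto the GPU (a second, temporary copy) and copy them into the model
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model.eval()

print_memory_usage()
```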
Notice that the memory usage is 2x as large as in the previous section
This is because we have the same model in memory twice, for a short period of time:
The first time via model.to(device)
The second time via the code line model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True)); the loaded weights are eventually copied into the model and the state_dict is discarded, but for a brief moment we have both the main model and the loaded state_dict in memory
The remaining sections focus on addressing this
But first, let’s test the model and reset the GPU memory
4. Loading weights sequentially
One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially
Below, we:
first load the model into GPU memory
then load the model weights into CPU memory
and finally copy each parameter one by one into GPU memory
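A sketch of this sequential loading procedure, reusing the names introduced earlier:

```python
start_memory_tracking()

# 1) Instantiate the model on the GPU
model = GPTModel(BASE_CONFIG).to(device)

# 2) Load the saved weights into CPU memory
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

print_memory_usage()

# 3) Copy the weights into the model one parameter at a time
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")

print_memory_usage()
```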
As we can see above, the memory usage is much lower than before
Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus one parameter tensor in memory (we temporarily move each parameter tensor to the GPU via .to(device) so we can assign it to the model)
Overall, this is a significant improvement
Again, let’s briefly test the model and then reset the GPU memory for the next section
5. Loading the model with low CPU memory
In the previous section, we reduced GPU memory use by loading the weights (state_dict) into CPU memory first before copying them one-by-one into the model
However, what do we do if we have limited CPU memory?
This section uses PyTorch’s so-called "meta" device approach to load a model on machines with large GPU memory but small CPU memory
But first, let’s define a convenience function to monitor CPU memory
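One way to implement such a helper is with the third-party psutil package (which needs to be installed separately, e.g., via pip install psutil); the memory_usage_in_gb function below is an illustrative sketch that runs a given function while polling the process's resident memory in a background thread and reports the peak usage:

```python
import os
import time
from threading import Thread

import psutil


def memory_usage_in_gb(func, *args, **kwargs):
    """Run `func` and return the peak CPU (RSS) memory it used, in gigabytes."""
    process = psutil.Process(os.getpid())

    # Baseline memory usage before calling the function
    baseline_mem = process.memory_info().rss / 1024 ** 3

    # Poll memory usage in a background thread while the function runs
    mem_usage = [baseline_mem]
    done = False

    def monitor_memory():
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()

    func(*args, **kwargs)

    done = True
    t.join()

    return max(mem_usage) - baseline_mem
```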
To start with, let’s track the CPU memory of the sequential weight loading approach from the previous section
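For example, we can wrap the sequential-loading code from the previous section in a function and pass it to the tracker:

```python
def load_sequentially():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)

    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

    print_memory_usage()

    # Sequential copying, as in the previous section
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
```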
Now, suppose we have a machine with low CPU memory but large GPU memory
We can trade off CPU memory and GPU memory usage by introducing PyTorch’s so-called “meta” device
PyTorch’s meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating “meta” tensors
This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation
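A sketch of the meta-device variant of the sequential loading approach (note that Module.to_empty and the torch.device context manager require a fairly recent PyTorch version):

```python
def load_sequentially_with_meta():
    start_memory_tracking()

    # Instantiate the model on the meta device: parameters get their shapes and
    # dtypes, but no actual memory is allocated for the weight data
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    # Materialize empty (uninitialized) parameters directly on the GPU
    model = model.to_empty(device=device)

    # Load the saved weights straight onto the GPU, skipping a full CPU copy
    state_dict = torch.load("model.pth", map_location=device, weights_only=True)

    print_memory_usage()

    # Copy the weights into the model one parameter at a time
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
```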
As we can see above, by creating the model on the meta-device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements
One might ask: “Is the sequential weight loading still necessary then, and how does that compare to the original approach?”
Let’s check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):
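That is, we wrap the straightforward loading code from section 3 in a function and measure its peak CPU memory usage as well:

```python
def baseline_loading():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG)
    model.to(device)

    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True)
    )
    model.eval()

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(baseline_loading)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
```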
As we can see above, the “simple” weight loading without the meta device uses more memory
In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage
6. Using mmap=True (recommended)
As an intermediate or advanced torch.load user, you may wonder how these approaches compare to the mmap=True setting in PyTorch
The mmap=True setting in PyTorch enables memory-mapped file I/O, which lets tensors access their data directly from disk storage instead of loading the entire file into RAM, reducing memory usage when RAM is limited
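Combined with the meta-device approach and load_state_dict's assign=True flag (both mmap=True and assign=True require PyTorch 2.1 or newer), a sketch looks as follows:

```python
def best_practices():
    start_memory_tracking()

    # Create the model skeleton without allocating weight memory
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    # Memory-map the checkpoint and assign the loaded tensors directly
    # to the model's parameters instead of copying them
    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
```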