
Memory-efficient Model Weight Loading

  • This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited
  • Specifically, it focuses on cases where you saved the model using torch.save(model.state_dict(), "model.pth") (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning
  • While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs

from importlib.metadata import version
pkgs = [
    "memory_profiler",
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")
memory_profiler version: 0.61.0
torch version: 2.4.1+cu121

1. Benchmark utilities

  • First, let’s define some utility code to track VRAM (GPU memory)
  • Later, we will also introduce a tool to track the main system RAM (CPU memory)
  • The purpose of these functions will become clear when we apply them later
import gc
import time
import torch
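def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        # Reset the peak counter so later readings reflect only the code that follows
        torch.cuda.reset_peak_memory_stats()
    else:
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")
def print_memory_usage():
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")
def cleanup():
    gc.collect()  # drop unreachable Python objects that still hold tensors
    torch.cuda.empty_cache()  # return cached blocks to the GPU
    time.sleep(3)  # some buffer time to allow memory to clear
    torch.cuda.reset_peak_memory_stats()
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")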
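The utilities above follow a "reset the peak counter, run the code, read the peak" pattern. As a rough CPU-side analogue that runs without a GPU, the following sketch applies the same pattern with Python's standard `tracemalloc` module. The helper names `measure_peak_mb` and `allocate_and_release` are hypothetical and not part of this notebook, and `tracemalloc` only sees allocations made through Python's allocator, so this illustrates the pattern rather than replacing the CUDA counters.

```python
import tracemalloc

def measure_peak_mb(func, *args, **kwargs):
    """Return (result, peak_mb) for Python-level allocations made while func runs."""
    tracemalloc.start()
    tracemalloc.reset_peak()  # same idea as resetting the CUDA peak counter
    try:
        result = func(*args, **kwargs)
        _, peak_bytes = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak_bytes / (1024 ** 2)

def allocate_and_release(n_mb):
    # Hypothetical workload: a temporary buffer that is freed before returning
    buf = bytearray(n_mb * 1024 * 1024)
    size = len(buf)
    del buf
    return size

size, peak_mb = measure_peak_mb(allocate_and_release, 8)
print(f"Peak traced memory: {peak_mb:.1f} MB")
```

The peak stays at or above the buffer size even though the buffer is gone when the function returns, which is why a peak counter can reveal short-lived duplicates during weight loading.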

2. Model setup

  • This code section sets up the model itself
  • Here, we use the largest GPT-2 configuration, “gpt2-xl (1558M)”, to make things more interesting (you may use “gpt2-small (124M)” instead to lower the memory requirements and execution time of this notebook)
from previous_chapters import GPTModel
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
CHOOSE_MODEL = "gpt2-xl (1558M)"
# Merge the size-specific settings into the shared base configuration
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
  • Now, let’s see the GPU memory functions in action:
start_memory_tracking()
model = GPTModel(BASE_CONFIG)
device = torch.device("cuda")
model.to(device)
print_memory_usage()
Maximum GPU memory allocated: 6.4 GB
  • Additionally, let’s make sure that the model runs okay by passing in some example tensor
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
  • Next, imagine we were pretraining the model and saving it for later use
  • We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)
# Training code would go here...
torch.save(model.state_dict(), "model.pth")
  • Lastly, we delete the model and example tensor in the Python session to reset the GPU memory
del model, test_input
cleanup()
Maximum GPU memory allocated: 0.0 GB

3. Weight loading

  • Now begins the interesting part where we load the pretrained model weights
  • Let’s see how much GPU memory is required to load the previously saved model
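# Baseline approach: create the model on the GPU first
start_memory_tracking()
model = GPTModel(BASE_CONFIG)
model.to(device)
# Then load pretrained weights
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
print_memory_usage()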
Maximum GPU memory allocated: 12.8 GB
  • Notice that the peak memory is 2x as large as in the previous section
  • This is because we have the same model in memory twice, for a short period of time:
    • The first time via model.to(device)
    • The second time via the code line model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True)); eventually, the loaded model weights will be copied into the model, and the state_dict will be discarded, but for a brief amount of time, we have both the main model and the loaded state_dict in memory
  • The remaining sections focus on addressing this
  • But first, let’s test the model and reset the GPU memory
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
del model, test_input
cleanup()
Maximum GPU memory allocated: 0.0 GB
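The 2x figure in this section can be sanity-checked with some back-of-the-envelope arithmetic. The sketch below estimates the float32 parameter memory of the "gpt2-xl (1558M)" configuration. It assumes a particular layer layout for `GPTModel` (token and positional embeddings, per-block Q/K/V and output projections, a 4x feed-forward expansion, two layer norms per block, and an untied output head without bias); this layout is an assumption, since the class is imported from `previous_chapters` and not shown here. The helper names `estimate_param_count` and `to_gb` are hypothetical.

```python
def estimate_param_count(cfg):
    """Rough parameter count under the assumed GPTModel layer layout."""
    d, v, ctx = cfg["emb_dim"], cfg["vocab_size"], cfg["context_length"]
    qkv_bias = 3 * d if cfg["qkv_bias"] else 0
    attention = 3 * d * d + qkv_bias + (d * d + d)        # Q, K, V + output projection
    feed_forward = (d * 4 * d + 4 * d) + (4 * d * d + d)  # expand to 4*d, project back
    layer_norms = 2 * (2 * d)                             # two norms, scale + shift each
    per_block = attention + feed_forward + layer_norms
    embeddings = v * d + ctx * d                          # token + positional embeddings
    final_norm = 2 * d
    out_head = d * v                                      # assumed untied, no bias
    return embeddings + cfg["n_layers"] * per_block + final_norm + out_head

def to_gb(n_params, bytes_per_param=4):
    """Convert a parameter count to GB, assuming float32 by default."""
    return n_params * bytes_per_param / (1024 ** 3)

xl_cfg = {
    "vocab_size": 50257, "context_length": 1024, "qkv_bias": True,
    "emb_dim": 1600, "n_layers": 48, "n_heads": 25,
}
n_params = estimate_param_count(xl_cfg)
one_copy_gb = to_gb(n_params)
print(f"Estimated parameters: {n_params:,}")
print(f"One float32 copy: {one_copy_gb:.1f} GB, two copies: {2 * one_copy_gb:.1f} GB")
```

Under these assumptions, the parameters alone come to about 6.1 GB per copy. The measured 6.4 GB is somewhat higher, presumably because non-parameter buffers and allocator overhead are also counted. Either way, holding the model and the loaded state_dict on the GPU at the same time roughly doubles the footprint.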

4. Loading weights sequentially

  • One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially
  • Below, we:
    • first load the model into GPU memory
    • then load the model weights into CPU memory
    • and finally copy each parameter one by one into GPU memory
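start_memory_tracking()
model = GPTModel(BASE_CONFIG).to(device)
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
print_memory_usage()
# Sequentially copy weights to the model's parameters
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            # Only this one tensor is moved to the GPU before being copied in place
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")
print_memory_usage()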
Maximum GPU memory allocated: 6.4 GB
Maximum GPU memory allocated: 6.7 GB
  • As we can see above, the memory usage is much lower than before
  • Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus 1 parameter tensor in memory (we temporarily move each parameter tensor to the GPU with ".to" so that it can be copied into the model)
  • Overall, this is a significant improvement
  • Again, let’s briefly test the model and then reset the GPU memory for the next section
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
del model, test_input, state_dict, param
cleanup()
Maximum GPU memory allocated: 0.0 GB
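To try the sequential copy without a large checkpoint, here is a minimal sketch of the same loop on a tiny `torch.nn.Linear` layer. The layer, the temporary file name `tiny.pth`, and the CPU fallback are illustrative assumptions; the loading logic mirrors the steps in this section (state_dict on the CPU, then one in-place copy per parameter).

```python
import os
import tempfile

import torch

torch.manual_seed(0)
source = torch.nn.Linear(4, 2)  # stands in for the trained model

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny.pth")  # hypothetical checkpoint name
    torch.save(source.state_dict(), path)
    # Keep the loaded weights on the CPU so they do not duplicate the model on the GPU
    state_dict = torch.load(path, map_location="cpu", weights_only=True)

# Fall back to the CPU so the sketch also runs without CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target = torch.nn.Linear(4, 2).to(device)  # freshly initialized, different weights

missing = []
with torch.no_grad():
    for name, param in target.named_parameters():
        if name in state_dict:
            # Move one tensor at a time, then overwrite the parameter in place
            param.copy_(state_dict[name].to(device))
        else:
            missing.append(name)

print("Missing parameters:", missing)
print("Weights match:", torch.equal(target.weight.cpu(), source.weight))
```

Collecting missing names in a list instead of printing inside the loop makes it easy to fail loudly afterwards if the checkpoint and the model architecture disagree.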

5. Loading the model with low CPU memory

  • In the previous section, we reduced GPU memory use by loading the weights (state_dict) into CPU memory first before copying them one-by-one into the model
  • However, what do we do if we have limited CPU memory?
  • This section uses PyTorch’s so-called "meta" device approach to load a model on machines with large GPU memory but small CPU memory
  • But first, let’s define a convenience function to monitor CPU memory
import os
import psutil
from threading import Thread
def memory_usage_in_gb(func, *args, **kwargs):
    process = psutil.Process(os.getpid())
    # Measure the baseline memory usage before running the function
    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB
    # Start monitoring memory in a separate thread
    mem_usage = []
    done = False
    def monitor_memory():
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)  # sample periodically instead of busy-waiting
    t = Thread(target=monitor_memory)
    t.start()
    # Run the function
    func(*args, **kwargs)
    # Stop monitoring and wait for the sampling thread to finish
    done = True
    t.join()
    peak_mem_usage_gb = max(mem_usage) - baseline_mem
    return peak_mem_usage_gb
  • To start with, let’s track the CPU memory of the sequential weight loading approach from the previous section
def load_sequentially():
    start_memory_tracking()
    model = GPTModel(BASE_CONFIG).to(device)
    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
    print_memory_usage()
    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))
            else:
                print(f"Warning: {name} not found in state_dict.")
    print_memory_usage()
peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 6.4 GB
Maximum GPU memory allocated: 6.7 GB
-> Maximum CPU memory allocated: 6.3 GB
  • Now, suppose we have a machine with low CPU memory but large GPU memory
  • We can trade off CPU memory and GPU memory usage by introducing PyTorch’s so-called “meta” device
  • PyTorch’s meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating “meta” tensors
  • This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation
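As a small illustration of the meta device described above, the following sketch creates a layer on "meta", inspects its shape, and then materializes it with `to_empty`. The layer size is arbitrary, and the sketch materializes on the CPU only so that it runs without a GPU. Note that `to_empty` allocates uninitialized storage, so real weights still have to be copied in afterwards.

```python
import torch

# Parameters created under the meta device have shapes and dtypes but no data
with torch.device("meta"):
    layer = torch.nn.Linear(256, 128)

was_meta = layer.weight.is_meta
print(layer.weight.shape, layer.weight.device, was_meta)

# Allocate real (uninitialized) storage on a concrete device
materialized = layer.to_empty(device="cpu")
print(materialized.weight.device, materialized.weight.is_meta)
```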
def load_sequentially_with_meta():
    start_memory_tracking()
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)
    # Allocate uninitialized parameter storage directly on the GPU
    model = model.to_empty(device=device)
    state_dict = torch.load("model.pth", map_location=device, weights_only=True)
    print_memory_usage()
    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])
            else:
                print(f"Warning: {name} not found in state_dict.")
    print_memory_usage()
peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 12.8 GB
Maximum GPU memory allocated: 12.8 GB
-> Maximum CPU memory allocated: 1.3 GB
  • As we can see above, by creating the model on the meta device and loading the weights directly into GPU memory, we reduced the CPU memory requirements (1.3 GB instead of 6.3 GB), at the cost of a higher peak GPU memory of 12.8 GB
  • One might ask: “Is the sequential weight loading still necessary then, and how does that compare to the original approach?”
  • Let’s check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):
def baseline():
    start_memory_tracking()
    model = GPTModel(BASE_CONFIG)
    model.to(device)
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    print_memory_usage()
peak_memory_used = memory_usage_in_gb(baseline)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 12.8 GB
-> Maximum CPU memory allocated: 4.4 GB
  • As we can see above, the “simple” weight loading without the meta device uses more CPU memory (4.4 GB compared to 1.3 GB)
  • In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage

6. Using mmap=True (recommended)

  • As an intermediate or advanced torch.load user, you may wonder how these approaches compare to the mmap=True setting in PyTorch
  • The mmap=True setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited
  • Also, see the helpful comment by mikaylagawarecki
  • At first glance, it may look less efficient than the sequential approaches above:
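def best_practices():
  with torch.device("meta"):
      model = GPTModel(BASE_CONFIG)
  # assign=True makes the loaded tensors become the model's parameters instead of copying into them
  model.load_state_dict(
      torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
      assign=True
  )
  print_memory_usage()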
peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 5.9 GB
  • The CPU RAM usage is this high only because this machine has enough CPU RAM available, so there is no pressure to keep the file contents out of RAM
  • However, if you were to run this on a machine with limited CPU RAM, the mmap approach would use less memory
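To make the idea of memory-mapped file I/O more concrete, here is a minimal sketch using Python's standard `mmap` module rather than PyTorch. It illustrates the general operating-system mechanism and is not a description of how `torch.load` implements `mmap=True` internally. The file name `weights.bin` and the helper `read_slice_via_mmap` are hypothetical.

```python
import mmap
import os
import tempfile

def read_slice_via_mmap(path, start, length):
    """Read a byte range through a memory map instead of reading the whole file."""
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
            # Only the pages backing this slice need to be brought into RAM
            return bytes(mapped[start:start + length])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.bin")  # hypothetical stand-in for a checkpoint
    with open(path, "wb") as f:
        f.write(bytes(range(256)) * 4096)    # 1 MiB of repeating byte values
    file_size = os.path.getsize(path)
    chunk = read_slice_via_mmap(path, 256 * 10 + 5, 4)

print(f"File size: {file_size} bytes, slice read: {list(chunk)}")
```

The operating system decides which mapped pages stay resident, which is consistent with the observation above that measured RAM usage depends on how much RAM the machine has available.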

7. Other methods

  • This notebook is focused on simple, built-in methods for loading weights in PyTorch
  • The recommended approach for limited CPU memory cases is the mmap=True approach explained in the previous section
  • Alternatively, one other option is a brute-force approach that saves and loads each weight tensor separately:
model = GPTModel(BASE_CONFIG)
# Assume `model` is your trained model
state_dict = model.state_dict()
# Create a directory to store individual parameter files
os.makedirs("model_parameters", exist_ok=True)
# Save each parameter tensor separately
for name, param in state_dict.items():
    torch.save(param.cpu(), f"model_parameters/{name}.pt")
del model
def load_individual_weights():
    start_memory_tracking()
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)
    model = model.to_empty(device=device)
    print_memory_usage()
    param_dir = "model_parameters"
    with torch.no_grad():
        for name, param in model.named_parameters():
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param_data = torch.load(weight_path, map_location="cpu", weights_only=True)
                # Copy this single tensor into the GPU-resident parameter
                param.copy_(param_data)
                del param_data  # Free memory
            else:
                print(f"Warning: {name} not found in {param_dir}.")
    print_memory_usage()
peak_memory_used = memory_usage_in_gb(load_individual_weights)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 6.4 GB
Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 0.3 GB