PyTorch 模型的保存与加载机制详解

本文为与 DeepSeek 对话记录,深入探讨 PyTorch 中模型的保存和加载机制,包括不同的保存方式、最佳实践以及常见问题的解决方案。

Q: PyTorch 中所谓的保存模型/加载模型,指的是保存/加载什么?是 state_dict 吗?

Reasoning

在 PyTorch 中,保存和加载模型的核心是 state_dict,但具体操作可以灵活选择。以下是详细说明:


1. state_dict 是什么?

  • state_dict 是一个 Python 字典对象,它将模型的每一层(如卷积层、全连接层等)映射到对应的参数张量(权重和偏置)。
  • 例如,一个简单的神经网络可能包含:
    model.state_dict() = {
        'conv1.weight': tensor(...),
        'conv1.bias': tensor(...),
        'fc1.weight': tensor(...),
        'fc1.bias': tensor(...),
        ...
    }
  • 优化器(如 torch.optim.SGD)也有自己的 state_dict,保存优化器的状态(如动量缓冲等)。

2. 保存模型的两种方式

方式 1:仅保存 state_dict(推荐)

  • 代码示例
    torch.save(model.state_dict(), 'model_weights.pth')
  • 优点
    • 文件体积小,只保存参数,不保存模型结构。
    • 加载灵活:需先实例化模型结构,再加载参数。
    • 避免因代码变更(如类名修改)导致的加载失败。

方式 2:保存整个模型(不推荐)

  • 代码示例
    torch.save(model, 'entire_model.pth')
  • 缺点
    • 文件体积大,包含模型结构和参数。
    • 依赖原始代码环境(模型类的定义必须存在),否则加载时会报错。

3. 加载模型的对应方法

加载 state_dict(需先构建模型结构)

model = MyModelClass()  # 必须与保存时的模型结构一致
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # 切换到推理模式(关闭 dropout 等)

加载整个模型(直接使用)

model = torch.load('entire_model.pth')
model.eval()

4. 实际场景中的最佳实践

  • 保存训练检查点(Checkpoint): 通常需要保存 state_dict 之外的信息(如优化器状态、训练步数等):

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': 100,
        'loss': loss,
    }
    torch.save(checkpoint, 'checkpoint.pth')
  • 加载检查点

    checkpoint = torch.load('checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']

5. 关键总结

  • 保存模型:本质是保存模型的 state_dict(推荐),但也可以保存整个模型。
  • 加载模型
    • 若保存的是 state_dict,需先实例化模型结构,再加载参数。
    • 若保存的是整个模型,直接加载但需保证代码兼容性。
  • 灵活性:使用 state_dict 更安全、更灵活,适合跨项目或模型结构可能变化的场景。

Q: model 的 config 又是什么东西?是超参吗?

Reasoning

在机器学习中,模型的 config(配置) 通常指定义模型结构的所有参数,而超参(Hyperparameters)更多指控制训练过程的参数。这两者有时会有重叠,但核心区别在于作用对象。以下是详细解释:


1. 什么是模型的 config

  • 模型配置(Model Configuration)是定义模型结构的一组参数,例如:

    • 神经网络的层数、每层的神经元数量
    • 卷积核大小、步长(Stride)、填充(Padding)
    • 激活函数类型(如 ReLU、Sigmoid)
    • 注意力机制的头数(如 Transformer 中的 num_heads
    • 词嵌入维度(如 NLP 模型中的 hidden_size
  • 示例(以 Transformer 模型为例):

    config = {
        "num_layers": 12,          # Transformer 层数
        "hidden_size": 768,         # 隐藏层维度
        "num_attention_heads": 12,  # 注意力头数
        "intermediate_size": 3072,  # FeedForward 层维度
        "vocab_size": 30522,        # 词表大小
    }

2. 模型 config 与超参数的区别

类别模型配置(Config)超参数(Hyperparameters)
作用对象模型结构训练过程
典型参数层数、维度、注意力头数、卷积核大小学习率、批量大小(Batch Size)、优化器类型、训练轮数(Epochs)
保存目的重建模型结构(如加载 state_dict 前)复现训练过程(如调参、实验对比)
修改频率低频(模型结构稳定后很少改动)高频(需反复调整优化)

3. 为什么需要单独管理 config

  • 场景 1:重建模型
    若仅保存 state_dict(模型权重),加载时必须先根据 config 重建模型结构,才能加载权重:

    # 保存
    torch.save(model.state_dict(), "model_weights.pth")
    with open("model_config.json", "w") as f:
        json.dump(config, f)
     
    # 加载
    with open("model_config.json", "r") as f:
        config = json.load(f)
    model = MyModel(**config)  # 根据 config 实例化模型
    model.load_state_dict(torch.load("model_weights.pth"))
  • 场景 2:版本控制与实验复现
    config 保存为 JSON/YAML 文件,便于记录不同模型结构的实验效果。

  • 场景 3:共享模型
    开源社区常用模式(如 HuggingFace Transformers):

    • config.json:定义模型结构
    • pytorch_model.bin:保存 state_dict

4. config 与超参数的重叠情况

某些参数可能同时影响模型结构和训练,例如:

  • Dropout 比率:既是模型结构的一部分,也被视为超参数。
  • 词嵌入维度:影响模型结构,但可能因训练资源调整(如显存不足时调小维度)。

此时,可以将两者统一管理,但建议分开存储:

# 统一配置文件(config.yaml)
model:
  num_layers: 12
  hidden_size: 768
training:
  learning_rate: 1e-4
  batch_size: 32

5. 实际代码示例

定义模型时使用 config

class TransformerModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config["vocab_size"], config["hidden_size"])
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config["num_layers"])])
        # ...
 
# 初始化模型
config = {"vocab_size": 30000, "hidden_size": 512, "num_layers": 6}
model = TransformerModel(config)

保存与加载完整模型(含结构和权重):

# 保存(不推荐,但某些场景需要)
full_model = {"config": config, "state_dict": model.state_dict()}
torch.save(full_model, "full_model.pth")
 
# 加载
loaded = torch.load("full_model.pth")
model = TransformerModel(loaded["config"])
model.load_state_dict(loaded["state_dict"])

6. 总结

  • 模型 config:定义模型结构的参数(如层数、维度),是重建模型的必需信息。
  • 超参数:控制训练过程的参数(如学习率、批量大小),影响模型如何学习。
  • 最佳实践
    • 始终保存 config(如 JSON 文件),确保模型可复现。
    • 分开管理模型配置和训练超参数,避免混淆。
亮曰
  • Config 主要是定义模型结构。会根据 Config 决定有多少层等。
  • 超参主要是训练过程中的学习率、 Epochs 等,复现训练过程时需要,但与模型结构无关,只与训练过程有关。

Q: 所谓"权重"就是指 state_dict 中的内容吗?

Reasoning

在 PyTorch 中,"权重"确实是 state_dict 中的核心内容,但 state_dict 包含的信息可能比"权重"更广泛。以下是详细解释:


1. "权重"是什么?

  • 权重(Weights) 是模型中可学习的参数,例如:
    • 全连接层(nn.Linear)中的权重矩阵(weight)和偏置(bias)。
    • 卷积层(nn.Conv2d)的卷积核参数。
    • 嵌入层(nn.Embedding)的词向量矩阵。
  • 这些参数通过反向传播和优化器(如 SGD、Adam)在训练过程中逐步更新。

2. state_dict 中的内容

state_dict 是 PyTorch 中模型或优化器的状态字典,具体包含:

  • 对模型而言
    • 所有权重(weight)、偏置(bias)。
    • 其他与层相关的参数,例如:
      • BatchNorm 层的 running_meanrunning_var(在推理时使用的均值和方差)。
      • LSTMGRU 中的隐藏状态缓存(如果模型设计时显式存储了这些状态)。
    # 示例:一个简单模型的 state_dict
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.BatchNorm1d(20)
    )
    print(model.state_dict().keys())
    # 输出:
    # ['0.weight', '0.bias', '2.weight', '2.bias', '2.running_mean', '2.running_var', '2.num_batches_tracked']
  • 对优化器而言
    • 优化器的状态(如动量缓存、历史梯度等)。
    optimizer = torch.optim.Adam(model.parameters())
    print(optimizer.state_dict().keys())
    # 输出:
    # ['state', 'param_groups']

3. "权重"与 state_dict 的关系

  • state_dict 包含权重,但不仅限于权重
    • 权重是 state_dict 中的一部分(如 conv1.weight, fc.bias)。
    • state_dict 还可能包含非可学习的参数(如 BatchNorm 的统计量 running_mean)。
    • 对于复杂模型(如 Transformer),还可能包含位置编码、注意力掩码等参数。
  • 总结
    • 权重是 state_dict 的核心内容,但 state_dict 是更全面的模型状态表示。

4. 为什么需要 state_dict 而不仅仅是权重?

  • 完整恢复模型状态: 如果仅保存权重,忽略 BatchNormrunning_mean 等参数,在推理时会导致统计量错误,影响模型性能。
  • 兼容性: 加载模型时需确保 state_dict 的键(如层名称)与当前模型结构严格匹配,否则会报错 Missing keysUnexpected keys
  • 灵活性: 可以手动修改 state_dict,例如迁移学习中仅加载部分层的权重。

5. 操作 state_dict 的常见场景

(1) 查看模型权重

for name, param in model.named_parameters():
    print(f"层名: {name}, 参数形状: {param.shape}")

(2) 选择性加载部分权重(迁移学习)

pretrained_dict = torch.load('pretrained.pth')
model_dict = model.state_dict()
 
# 筛选出名称和形状匹配的权重
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

(3) 手动修改权重

# 将第一层卷积的权重置零
with torch.no_grad():
    model.conv1.weight.copy_(torch.zeros_like(model.conv1.weight))

6. 总结

  • 权重是模型的可学习参数,是 state_dict 的核心内容。
  • state_dict 是模型状态的完整表示,包括:
    • 所有权重、偏置。
    • 非可学习的运行时参数(如 BatchNorm 的统计量)。
  • 最佳实践
    • 保存模型时使用 torch.save(model.state_dict(), ...)
    • 加载时需先构建相同结构的模型,再调用 model.load_state_dict()
    • 若需调整模型结构,可选择性加载匹配的权重。
亮曰

state_dict 包括 Parameter 与 Buffer ,模型的可学习参数——权重主要指其中的 Parameter 。