[Translation] PyTorch Official Tutorials (Chinese Edition): Save and Load the Model


This article is part of the "PyTorch Official Tutorials (Chinese Edition)" series. Table of contents: [Translation] PyTorch Official Tutorials (Chinese Edition): Table of Contents

This article is translated from the PyTorch official website. Original page: Save and Load the Model

Save and Load the Model

This section covers how to persist model state by saving and loading models, and how to run the loaded model for prediction.

import torch
import torchvision.models as models

Saving and Loading Model Weights

PyTorch models store their learned parameters in an internal state dictionary called state_dict. It can be persisted with the torch.save method:

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

The above code produces the following output:

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth

  0%|          | 0.00/528M [00:00<?, ?B/s]
  4%|3         | 18.7M/528M [00:00<00:02, 196MB/s]
  8%|7         | 39.7M/528M [00:00<00:02, 210MB/s]
 12%|#1        | 60.8M/528M [00:00<00:02, 215MB/s]
 15%|#5        | 81.7M/528M [00:00<00:02, 217MB/s]
 20%|#9        | 103M/528M [00:00<00:02, 219MB/s]
 25%|##4       | 131M/528M [00:00<00:01, 246MB/s]
 31%|###       | 164M/528M [00:00<00:01, 276MB/s]
 37%|###7      | 196M/528M [00:00<00:01, 296MB/s]
 42%|####2     | 224M/528M [00:00<00:01, 274MB/s]
 47%|####7     | 251M/528M [00:01<00:01, 255MB/s]
 52%|#####2    | 275M/528M [00:01<00:01, 244MB/s]
 57%|#####6    | 299M/528M [00:01<00:01, 231MB/s]
 61%|######    | 321M/528M [00:01<00:00, 228MB/s]
 66%|######5   | 346M/528M [00:01<00:00, 237MB/s]
 72%|#######1  | 378M/528M [00:01<00:00, 264MB/s]
 77%|#######6  | 406M/528M [00:01<00:00, 273MB/s]
 82%|########1 | 432M/528M [00:01<00:00, 256MB/s]
 87%|########6 | 457M/528M [00:01<00:00, 245MB/s]
 91%|#########1| 481M/528M [00:02<00:00, 220MB/s]
 95%|#########5| 502M/528M [00:02<00:00, 220MB/s]
 99%|#########9| 524M/528M [00:02<00:00, 217MB/s]
100%|##########| 528M/528M [00:02<00:00, 240MB/s]
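As a small supplement to the tutorial, the sketch below inspects what the saved state_dict actually contains by printing each parameter's name and shape. It assumes the model_weights.pth file created above exists in the working directory; the weights_only flag mentioned in the comment is only available on recent PyTorch releases.

# Supplementary sketch (not part of the original tutorial): inspect the saved state_dict.
# Assumes 'model_weights.pth' was written by the torch.save call above.
state_dict = torch.load('model_weights.pth')  # on PyTorch 1.13+ you can add weights_only=True for safer loading
for name, tensor in state_dict.items():
    # Each entry maps a parameter/buffer name to its tensor of values.
    print(name, tuple(tensor.shape))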

To load model weights, first create an instance of the same model, then load the parameters with the load_state_dict() method.

model = models.vgg16() # we do not specify weights, i.e. create an untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

The above code produces the following output:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Note: be sure to call model.eval() before running inference. This switches the model to evaluation mode, in which layers such as dropout and batch normalization behave deterministically; forgetting to do so will produce inconsistent inference results.
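To make the note above concrete, here is a minimal inference sketch: it switches the loaded model to evaluation mode, disables gradient tracking with torch.no_grad(), and runs a random dummy input through the network. The (1, 3, 224, 224) input shape matches what VGG16 expects; the dummy tensor is only an illustration and is not part of the original tutorial.

# Minimal inference sketch (illustration only): model.eval() plus torch.no_grad().
model.eval()                               # evaluation mode: layers like dropout behave deterministically
dummy_input = torch.rand(1, 3, 224, 224)   # a random image-shaped tensor, just for demonstration
with torch.no_grad():                      # gradients are not needed for inference
    logits = model(dummy_input)
predicted_class = logits.argmax(dim=1)
print(predicted_class)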

Saving and Loading the Model Structure

When loading only the model weights, we need to know which class the model was built from and instantiate that class first, because the class defines the structure of the network. If we want to save the model's structure together with its weights, we can pass the model itself (rather than model.state_dict()) to the saving function:

torch.save(model, 'model.pth')

The model can then be loaded like this:

model = torch.load('model.pth')

Note: this approach uses Python's pickle module to serialize the model, so it depends on the actual class definition (translator's note: the class that defines the model structure must be importable when the model is loaded).
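A hedged sketch of what that dependency means in practice: the class below (TinyNet, a made-up example, not from the tutorial) must be importable in any process that later calls torch.load on the pickled file, otherwise unpickling fails.

import torch
from torch import nn

# Hypothetical example class; the same definition must be importable wherever the model is loaded.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()
torch.save(net, 'tiny_net.pth')   # pickles the full object, including a reference to the TinyNet class

# In another script, torch.load('tiny_net.pth') only works if TinyNet can be imported there;
# newer PyTorch releases may also require weights_only=False to unpickle a full model.
# Saving and loading the state_dict instead avoids this coupling.
restored = torch.load('tiny_net.pth')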

Related Tutorials


First published on 芸芸小站.

Last edited: August 11, 2023. © All rights reserved; please keep a link to the original when reposting.