以下是一个保存和加载PyTorch模型的示例代码:
保存模型:
import torch
import torch.nn as nn
# 创建模型
model = nn.Linear(10, 2)
# 保存模型的state_dict
torch.save(model.state_dict(), 'model.pth')
加载模型:
import torch
import torch.nn as nn
# 创建模型
model = nn.Linear(10, 2)
# 加载模型的state_dict
model.load_state_dict(torch.load('model.pth'))
如果需要保存整个模型(包括模型的结构和参数),可以使用torch.save()
和torch.load()
函数来保存和加载整个模型:
保存整个模型:
import torch
import torch.nn as nn
# 创建模型
model = nn.Linear(10, 2)
# 保存整个模型
torch.save(model, 'model.pth')
加载整个模型:
import torch
import torch.nn as nn
# 加载整个模型
model = torch.load('model.pth')
请注意,加载整个模型时,需要确保加载的模型与保存的模型具有相同的代码结构。