PyTorch series --- loading models
Published: 2019-05-24


Reposted from

At the end of this post I have added another way to reload a model.

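In general, a model is saved by dumping all of its parameters with model.cpu().state_dict(), and loaded with model.load_state_dict(torch.load(model_path)). Note that what torch.load returns here is an OrderedDict (the state_dict that was saved).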
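A minimal, self-contained sketch of that save/load round trip. The small nn.Sequential model and the in-memory io.BytesIO buffer (standing in for a file such as t.pth) are assumptions made only for illustration:

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn

# Throwaway model used only to illustrate the round trip.
model = nn.Sequential(nn.Conv2d(1, 2, 3), nn.ReLU(True), nn.Conv2d(2, 1, 3))

# Save only the parameters (the state_dict), moved to the CPU first.
buffer = io.BytesIO()  # stands in for a checkpoint file path
torch.save(model.cpu().state_dict(), buffer)
buffer.seek(0)

# torch.load gives back what was saved: an OrderedDict mapping key -> tensor.
state = torch.load(buffer)
print(isinstance(state, OrderedDict), list(state.keys()))

# Load the parameters into a freshly built model with the same architecture.
restored = nn.Sequential(nn.Conv2d(1, 2, 3), nn.ReLU(True), nn.Conv2d(2, 1, 3))
restored.load_state_dict(state)

# Every tensor of the restored model now equals the original one.
same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              restored.state_dict().values()))
print(same)
```

The ReLU layers hold no parameters, so only indices 0 and 2 appear in the keys. The original example below applies the same pattern to a Sequential-based network.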

```python
import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        return self.nets(x)


class Net_new(nn.Module):
    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x


network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
print(pretrained_net)

# enumerate() over a dict yields (index, key) pairs
for idx, key in enumerate(pretrained_net):
    print(idx, key)
```

The output:

```
OrderedDict([('nets.0.weight',
(0 ,0 ,.,.) =
 -0.2436  0.2523  0.3097
 -0.0315 -0.1307  0.0759
  0.0750  0.1894 -0.0761

(1 ,0 ,.,.) =
  0.0280 -0.2178  0.0914
  0.3227 -0.0121 -0.0016
 -0.0654 -0.0584 -0.1655
[torch.FloatTensor of size 2x1x3x3]
), ('nets.0.bias',
-0.0507
-0.2836
[torch.FloatTensor of size 2]
), ('nets.2.weight',
(0 ,0 ,.,.) =
 -0.2233  0.0279 -0.0511
 -0.0242 -0.1240 -0.0511
  0.2266  0.1385 -0.1070

(0 ,1 ,.,.) =
 -0.0943 -0.1403  0.0979
 -0.2163  0.1906 -0.2269
 -0.1984  0.0843 -0.0719
[torch.FloatTensor of size 1x2x3x3]
), ('nets.2.bias',
-0.1420
[torch.FloatTensor of size 1]
), ('nets.4.weight',
(0 ,0 ,.,.) =
  0.1981 -0.0250  0.2429
  0.3012  0.2428 -0.0114
  0.2878 -0.2134  0.1173
[torch.FloatTensor of size 1x1x3x3]
), ('nets.4.bias',
1.00000e-02 *
 -5.8426
[torch.FloatTensor of size 1]
)])
0 nets.0.weight
1 nets.0.bias
2 nets.2.weight
3 nets.2.bias
4 nets.4.weight
5 nets.4.bias
```

This shows that .state_dict() simply stores all of the model's parameters as an OrderedDict. Iterating with

```python
for idx, key in enumerate(pretrained_net):
    print(idx, key)
```

tells us the order of the parameters. To see the actual values, iterate over the items:

```python
for key, v in pretrained_net.items():
    print(key, v)
```
```
nets.0.weight
(0 ,0 ,.,.) =
 -0.2444 -0.3148  0.1626
  0.2531 -0.0859 -0.0236
  0.1635  0.1113 -0.1110

(1 ,0 ,.,.) =
  0.2374 -0.2931 -0.1806
 -0.1456  0.2264 -0.0114
  0.1813  0.1134 -0.2095
[torch.FloatTensor of size 2x1x3x3]

nets.0.bias
-0.3087
-0.2407
[torch.FloatTensor of size 2]

nets.2.weight
(0 ,0 ,.,.) =
 -0.2206 -0.1151 -0.0783
  0.0723 -0.2008  0.0568
 -0.0964 -0.1505 -0.1203

(0 ,1 ,.,.) =
  0.0131  0.1329 -0.1763
  0.1276 -0.2025 -0.0075
 -0.1167 -0.1833  0.1103
[torch.FloatTensor of size 1x2x3x3]

nets.2.bias
-0.1858
[torch.FloatTensor of size 1]

nets.4.weight
(0 ,0 ,.,.) =
 -0.1019  0.0534  0.2018
 -0.0600 -0.1389 -0.0275
  0.0696  0.0360  0.1560
[torch.FloatTensor of size 1x1x3x3]

nets.4.bias
1.00000e-03 *
 -5.6003
[torch.FloatTensor of size 1]
```
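Printing whole tensors quickly becomes unreadable for real models. A more compact sketch lists each key with its shape instead; Net_old is re-declared here so the snippet runs on its own:

```python
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            nn.Conv2d(1, 2, 3), nn.ReLU(True),
            nn.Conv2d(2, 1, 3), nn.ReLU(True),
            nn.Conv2d(1, 1, 3),
        )

    def forward(self, x):
        return self.nets(x)


state = Net_old().state_dict()

# enumerate() over the dict yields (position, key) pairs in insertion order.
for idx, key in enumerate(state):
    print(idx, key, tuple(state[key].shape))

# A key -> shape mapping is often all that is needed to compare checkpoints.
shapes = {k: tuple(v.shape) for k, v in state.items()}
```

Keys skip indices 1 and 3 because the ReLU layers in the Sequential have no parameters.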

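Suppose that one day we need to rewrite this network, for example as Net_new, which makes each layer an attribute of the class. If we load the checkpoint directly: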

```python
import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        return self.nets(x)


class Net_new(nn.Module):
    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x


network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key, v in pretrained_net.items():
    print(key)

# Define new network, and directly load the state_dict
new_network = Net_new()
new_network.load_state_dict(pretrained_net)
```

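we get an "unexpected key" error. (This traceback comes from the original run under Python 2.7 and an early PyTorch release; newer versions raise a RuntimeError that lists both the missing and the unexpected keys instead.)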

```
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
Traceback (most recent call last):
  File "Blog.py", line 44, in <module>
    new_network.load_state_dict(pretrained_net)
  File "/home/vis/xxx/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
    .format(name))
KeyError: 'unexpected key "nets.0.weight" in state_dict'
```

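This is because the layers of the new network are all stored as attributes, so its keys are different. The new network's state_dict contains: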

```
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
```
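Before calling load_state_dict, it can help to diff the two key sets so that a mismatch is visible up front. A sketch using plain set arithmetic; Net_old and Net_new are re-declared without forward methods, since only their parameters matter here:

```python
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super().__init__()
        self.nets = nn.Sequential(
            nn.Conv2d(1, 2, 3), nn.ReLU(True),
            nn.Conv2d(2, 1, 3), nn.ReLU(True),
            nn.Conv2d(1, 1, 3),
        )


class Net_new(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.r1 = nn.ReLU(True)
        self.conv2 = nn.Conv2d(2, 1, 3)
        self.r2 = nn.ReLU(True)
        self.conv3 = nn.Conv2d(1, 1, 3)


checkpoint_keys = set(Net_old().state_dict())
model_keys = set(Net_new().state_dict())

# Keys the new model expects but the checkpoint lacks, and the reverse.
missing = sorted(model_keys - checkpoint_keys)
unexpected = sorted(checkpoint_keys - model_keys)
print('missing:', missing)
print('unexpected:', unexpected)
```

Here the two sets do not overlap at all, which is exactly why the direct load above fails.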

What strict=False really does when loading a model

You might then decide to pass strict=False:

```python
import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        return self.nets(x)


class Net_new(nn.Module):
    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x


old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key, v in pretrained_net.items():
    print(key)

print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
for key, _ in new_network.state_dict().items():
    print(key)

print('-----After loading------')
new_network.load_state_dict(pretrained_net, strict=False)
# So you think that these two values are the same?? Hah!
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
for key, _ in new_network.state_dict().items():
    print(key)
```

Output:

```
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
-0.882688805461
0.34207585454
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
-----After loading------
-0.882688805461
0.34207585454
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
```

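The values did not change at all, which shows that strict=False is not that smart! It simply ignores keys that have no counterpart: a key present in both dicts is copied, and anything else is skipped without assignment. To confirm this, the next version of Net_new also gets an attribute named nets, so that its keys nets.0.weight and nets.0.bias match the checkpoint: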

```python
import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        return self.nets(x)


class Net_new(nn.Module):
    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)
        # Net_new also gets a 'nets' attribute, so 'nets.0.weight' and 'nets.0.bias' match
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        x = self.nets(x)
        return x


old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key, v in pretrained_net.items():
    print(key)

print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
print(torch.sum(new_network.nets[0].weight.data))
for key, _ in new_network.state_dict().items():
    print(key)

print('-----After loading------')
new_network.load_state_dict(pretrained_net, strict=False)
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
# Hopefully, this value equals the sum of 'old_network.nets[0].weight'
print(torch.sum(new_network.nets[0].weight.data))
for key, _ in new_network.state_dict().items():
    print(key)
```

Result:

```
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
-0.197643771768
0.862508803606
1.21658478677
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
nets.0.weight
nets.0.bias
-----After loading------
-0.197643771768
0.862508803606
-0.197643771768
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
nets.0.weight
nets.0.bias
```

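After loading, the two values that were expected to match are identical: new_network.nets[0] received the weights of old_network.nets[0], while conv1 kept its own random initialization.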
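Comparing sums works, but in recent PyTorch versions load_state_dict also returns a named tuple with missing_keys and unexpected_keys, which tells you directly what strict=False skipped. A sketch of the same experiment, assuming such a version (the older release used in this post may not return it); the classes are re-declared without forward methods:

```python
import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super().__init__()
        self.nets = nn.Sequential(
            nn.Conv2d(1, 2, 3), nn.ReLU(True),
            nn.Conv2d(2, 1, 3), nn.ReLU(True),
            nn.Conv2d(1, 1, 3),
        )


class Net_new(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.r1 = nn.ReLU(True)
        self.conv2 = nn.Conv2d(2, 1, 3)
        self.r2 = nn.ReLU(True)
        self.conv3 = nn.Conv2d(1, 1, 3)
        # Only this attribute shares keys with Net_old's state_dict.
        self.nets = nn.Sequential(nn.Conv2d(1, 2, 3))


old_network, new_network = Net_old(), Net_new()
result = new_network.load_state_dict(old_network.state_dict(), strict=False)

# Keys of new_network that received nothing, and checkpoint keys that were dropped.
print('missing:', result.missing_keys)
print('unexpected:', result.unexpected_keys)
print(torch.equal(old_network.nets[0].weight, new_network.nets[0].weight))
```

Checking that missing_keys contains only what you expect is a cheap guard against a load that silently did nothing.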

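Summary: loading with strict=False follows the rule "load what fits, drop what doesn't". load_state_dict matches parameters by key, and by default any key mismatch raises an error. With strict=False, unmatched keys are simply ignored and matched keys are assigned as usual.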
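One caveat worth checking in your own PyTorch version: in recent releases, strict=False only relaxes missing and unexpected keys. A key that matches but holds a tensor of a different shape still raises a RuntimeError. A small sketch; the two nn.Linear layers are purely illustrative:

```python
import torch.nn as nn

src = nn.Linear(4, 3)  # keys 'weight' (3x4) and 'bias' (3,)
dst = nn.Linear(4, 5)  # same keys, but shapes 5x4 and (5,)

message = ''
try:
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as err:
    # strict=False does not cover this case: matching keys must also match in shape.
    message = str(err)

print(message)
```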

What strict=False is good for

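So if you have trained a model and want to add a few layers to it, strict=False makes it easy to load the pretrained parameters (just check that the keys match). Every parameter whose key matches is assigned correctly.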
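A sketch of that use case with two hypothetical classes, Base and Extended: the extended model keeps the original layer names and adds one new layer, head. Loading with strict=False fills every pretrained tensor, and the returned missing_keys (available in recent PyTorch versions) should list only the new layer:

```python
import torch
import torch.nn as nn


class Base(nn.Module):  # hypothetical "already trained" model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.conv2 = nn.Conv2d(2, 1, 3)


class Extended(Base):  # same layer names, plus one new layer
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(1, 1, 3)


pretrained = Base()
model = Extended()
result = model.load_state_dict(pretrained.state_dict(), strict=False)

# Only the newly added layer is left at its random initialization.
print(result.missing_keys)
print(result.unexpected_keys)
print(torch.equal(model.conv1.weight, pretrained.conv1.weight))
```

Subclassing keeps the attribute names identical by construction, which is what makes the keys line up.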

The "unexpected key module.xxx.weight" problem

If the model was wrapped in nn.DataParallel when it was saved, every key in the saved dict gets a module. prefix.

Loading such a checkpoint into a model that is not wrapped then fails with unexpected keys. The fix is simply to strip the prefix:

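```python
from collections import OrderedDict

import torch

pretrained_net = Net_old()
pretrained_net_dict = torch.load(save_path)  # save_path: checkpoint saved from the DataParallel model
new_state_dict = OrderedDict()
for k, v in pretrained_net_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove the 7-character `module.` prefix
    new_state_dict[name] = v
# load params
pretrained_net.load_state_dict(new_state_dict)
```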
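A standalone sketch of the same fix. It wraps a toy model in nn.DataParallel (on a CPU-only machine the wrapper just forwards to the module, but its state_dict keys still carry the module. prefix) and strips the prefix only from keys that actually start with it. Another common option is to avoid the prefix at save time by saving model.module.state_dict() instead:

```python
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(1, 2, 3))
wrapped = nn.DataParallel(net)

saved = wrapped.state_dict()  # keys look like 'module.0.weight'
print(list(saved))

prefix = 'module.'
cleaned = {
    # Strip only a leading prefix; other keys pass through unchanged.
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in saved.items()
}

plain = nn.Sequential(nn.Conv2d(1, 2, 3))
plain.load_state_dict(cleaned)
print(list(cleaned))
```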

Summary

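  • The saved dict stores each entry under the attribute name followed by the parameter name, e.g. conv1.weight. If the attribute is a Sequential, the key also contains the layer index, e.g. seqConvs.0.weight.
    Inside the class definition, a single layer of the Sequential is accessed with [], e.g. self.seqConvs[0].weight.
  • strict=False is not that smart: it assigns values whose keys match and silently drops everything else.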
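The naming rule can be checked directly. A sketch with a hypothetical Net that has one plain attribute and one Sequential named seqConvs, matching the example names above:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 2, 3)  # plain attribute -> keys 'conv0.weight', 'conv0.bias'
        self.seqConvs = nn.Sequential(nn.Conv2d(2, 2, 3), nn.ReLU(True), nn.Conv2d(2, 1, 3))


net = Net()
sd = net.state_dict()
print(list(sd))

# The key 'seqConvs.2.weight' and the expression net.seqConvs[2].weight name the same tensor.
print(torch.equal(sd['seqConvs.2.weight'], net.seqConvs[2].weight))
```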

Addendum

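One problem from the first part is still unsolved: how do you load the parameters of a network defined with Sequential into a network whose layers are defined one by one as attributes?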

Here is a rather ugly way to do it:

```python
import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        return self.nets(x)


class Net_new(nn.Module):
    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x

    def _initialize_weights_from_net(self):
        save_path = 't.pth'
        # First load the old network from the checkpoint.
        pretrained_net = Net_old()
        pretrained_net_dict = torch.load(save_path)
        pretrained_net.load_state_dict(pretrained_net_dict)
        print('Successfully load model ' + save_path)

        new_convs = self.get_convs()
        cnt = 0
        # nn.Sequential is iterable: walk its layers in order and copy
        # each Conv2d into the next conv layer of this network.
        for layer in pretrained_net.nets:
            if isinstance(layer, torch.nn.Conv2d):
                print('Assign weight of pretrained model layer : ', layer, ' to layer: ', new_convs[cnt])
                new_convs[cnt].weight.data = layer.weight.data
                new_convs[cnt].bias.data = layer.bias.data
                cnt += 1

    def get_convs(self):
        return [self.conv1, self.conv2, self.conv3]


old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key, v in pretrained_net.items():
    print(key)

print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))

print('-----New loading method------')
new_network._initialize_weights_from_net()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
```

Output:

```
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
0.510313585401
0.198701560497
-----New loading method------
Successfully load model t.pth
('Assign weight of pretrained model layer : ', Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)))
('Assign weight of pretrained model layer : ', Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1)))
('Assign weight of pretrained model layer : ', Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)))
0.510313585401
0.510313585401
```

Done!

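Everything above is from the original author's blog; many thanks to the author for sharing. Below is another, more convenient way to load the model: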

```python
import torch
from collections import OrderedDict
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        return self.nets(x)


class Net_new(nn.Module):
    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x


old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
new_network = Net_new()
pretrained_net = torch.load('t.pth')

# New key names, listed in the same order as the entries of the old state_dict
new_dict = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias']
new_state_dict = OrderedDict()
# Rename the i-th entry of the old state_dict to the i-th new key
for i, (k, v) in enumerate(pretrained_net.items()):
    new_state_dict[new_dict[i]] = v

print('****Before loading********')
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))

print('-----New loading method------')
new_network.load_state_dict(new_state_dict)
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
```

Output:

```
****Before loading********
tensor(1.3759)
tensor(0.3301)
-----New loading method------
tensor(1.3759)
tensor(1.3759)
```

It works perfectly!
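The hardcoded new_dict list has to be kept in sync with the model by hand. A sketch of the same positional renaming that takes the target names from the new model's own state_dict() and checks each shape along the way; remap_by_position is a hypothetical helper, and like the method above it assumes both models register their tensors in the same order:

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class Net_old(nn.Module):
    def __init__(self):
        super().__init__()
        self.nets = nn.Sequential(nn.Conv2d(1, 2, 3), nn.ReLU(True),
                                  nn.Conv2d(2, 1, 3), nn.ReLU(True),
                                  nn.Conv2d(1, 1, 3))


class Net_new(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.r1 = nn.ReLU(True)
        self.conv2 = nn.Conv2d(2, 1, 3)
        self.r2 = nn.ReLU(True)
        self.conv3 = nn.Conv2d(1, 1, 3)


def remap_by_position(src_state, dst_model):
    """Rename the i-th entry of src_state to the i-th key of dst_model's state_dict."""
    dst_state = dst_model.state_dict()
    if len(src_state) != len(dst_state):
        raise ValueError('state dicts have a different number of entries')
    remapped = OrderedDict()
    for (src_key, tensor), (dst_key, target) in zip(src_state.items(), dst_state.items()):
        if tensor.shape != target.shape:
            raise ValueError(f'{src_key} {tuple(tensor.shape)} does not fit '
                             f'{dst_key} {tuple(target.shape)}')
        remapped[dst_key] = tensor
    return remapped


old_network, new_network = Net_old(), Net_new()
new_network.load_state_dict(remap_by_position(old_network.state_dict(), new_network))
print(torch.equal(old_network.nets[0].weight, new_network.conv1.weight))
```

The shape check turns a silent misalignment (for example, a layer added in the middle of one model) into an explicit error instead of a wrong load.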
