import torch
import torch.nn as nn
from collections import OrderedDict, defaultdict

class VGG(nn.Module):
    # Channel widths for each conv block; 'M' marks a 2x2 max-pooling layer.
    ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

    def __init__(self, state_dict=None) -> None:
        super().__init__()

        layers = []
        counts = defaultdict(int)

        def add(name: str, layer: nn.Module) -> None:
            # Name layers conv0, bn0, relu0, conv1, ... so state-dict keys stay stable.
            layers.append((f"{name}{counts[name]}", layer))
            counts[name] += 1

        in_channels = 3
        for x in self.ARCH:
            if x != 'M':
                # conv-bn-relu
                add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
                add("bn", nn.BatchNorm2d(x))
                add("relu", nn.ReLU(True))
                in_channels = x
            else:
                # maxpool
                add("pool", nn.MaxPool2d(2))

        self.backbone = nn.Sequential(OrderedDict(layers))
        self.classifier = nn.Linear(512, 10)

        # Keep the checkpoint under a private name so we do not shadow
        # nn.Module.state_dict().
        self._pretrained_state_dict = state_dict
        if state_dict is not None:
            self.recover_model()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
        x = self.backbone(x)

        # avgpool: [N, 512, 2, 2] => [N, 512]
        x = x.mean([2, 3])

        # classifier: [N, 512] => [N, 10]
        x = self.classifier(x)
        return x

    def recover_model(self):
        # Load the pretrained weights into this module, if a checkpoint was provided.
        if self._pretrained_state_dict is not None:
            self.load_state_dict(self._pretrained_state_dict)


def cifar10_vgg9_bn(pretrained=False, **kwargs):
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            'https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth',
            progress=True)
        state_dict = state_dict['state_dict']
    else:
        state_dict = None
    model = VGG(state_dict)
    return model
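

# Minimal smoke test sketch: build the model without the pretrained checkpoint
# and push a random CIFAR-10-sized batch through it to confirm the output shape
# implied by the comments in forward(). The batch size of 4 is arbitrary.
if __name__ == "__main__":
    model = cifar10_vgg9_bn(pretrained=False)
    dummy_input = torch.randn(4, 3, 32, 32)  # [N, 3, 32, 32]
    logits = model(dummy_input)
    print(logits.shape)  # expected: torch.Size([4, 10])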