Skip to content

Commit ac76708

Browse files
committed
Add VGG model implementation and hub configuration
1 parent bdda63f commit ac76708

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

hubconf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
dependencies = ["torch"]
2+
3+
from vgg import cifar10_vgg9_bn

vgg.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
import torch.nn as nn
3+
from collections import OrderedDict, defaultdict
4+
5+
class VGG(nn.Module):
6+
ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
7+
8+
def __init__(self, state_dict=None) -> None:
9+
super().__init__()
10+
11+
layers = []
12+
counts = defaultdict(int)
13+
14+
def add(name: str, layer: nn.Module) -> None:
15+
layers.append((f"{name}{counts[name]}", layer))
16+
counts[name] += 1
17+
18+
in_channels = 3
19+
for x in self.ARCH:
20+
if x != 'M':
21+
# conv-bn-relu
22+
add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
23+
add("bn", nn.BatchNorm2d(x))
24+
add("relu", nn.ReLU(True))
25+
in_channels = x
26+
else:
27+
# maxpool
28+
add("pool", nn.MaxPool2d(2))
29+
30+
self.backbone = nn.Sequential(OrderedDict(layers))
31+
self.classifier = nn.Linear(512, 10)
32+
33+
self.state_dict = state_dict
34+
if state_dict is not None:
35+
self.recover_model()
36+
37+
def forward(self, x: torch.Tensor) -> torch.Tensor:
38+
# backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
39+
x = self.backbone(x)
40+
41+
# avgpool: [N, 512, 2, 2] => [N, 512]
42+
x = x.mean([2, 3])
43+
44+
# classifier: [N, 512] => [N, 10]
45+
x = self.classifier(x)
46+
return x
47+
48+
def recover_model(self):
49+
if self.state_dict is not None:
50+
self.load_state_dict(self.state_dict)
51+
52+
53+
def cifar10_vgg9_bn(pretrained=False, **kwargs):
54+
if pretrained:
55+
state_dict = torch.hub.load_state_dict_from_url(
56+
'https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth',
57+
progress=True)
58+
state_dict = state_dict['state_dict']
59+
else:
60+
state_dict = None
61+
model = VGG(state_dict)
62+
return model

0 commit comments

Comments
 (0)