
Commit cef7963

Commit message: bugs fixed
1 parent: b7731c4

6 files changed: +72 additions, -34 deletions

README.md

Lines changed: 32 additions & 0 deletions

# fast-weights-pytorch

PyTorch implementation of the paper [Using Fast Weights to Attend to the Recent Past](https://arxiv.org/abs/1610.06258).

Code for generating sequential data is forked from [jiamings/fast-weights](https://github.com/jiamings/fast-weights/tree/master).

## Dependencies

- Python >= 3.6
- PyTorch
- tensorboardX
- NumPy
- pickle (Python standard library)
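
These can typically be installed with pip; the PyPI package names below are assumed (`pickle` ships with Python itself, and GPU builds of PyTorch may need a platform-specific command from pytorch.org):

```
$ pip install torch tensorboardX numpy
```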

## Usage

Generate a dataset:

```
$ python generator.py
```

Train the fast-weights model:

```
$ python fast_weights.py
```

## Training Result

![](fig/acc.png)

![](fig/loss.png)

### References

[Using Fast Weights to Attend to the Recent Past](https://arxiv.org/abs/1610.06258). Jimmy Ba, Geoffrey Hinton, Volodymyr Mnih, Joel Z. Leibo, Catalin Ionescu.

[Layer Normalization](https://arxiv.org/abs/1607.06450). Jimmy Ba, Ryan Kiros, Geoffrey Hinton.

model.py renamed to fast_weights.py

Lines changed: 34 additions & 28 deletions

```diff
@@ -1,5 +1,6 @@
 from __future__ import print_function
+import os
 import torch
 import numpy as np
 import torch.nn as nn
```

```diff
@@ -25,42 +26,46 @@ class fast_weights_model(nn.Module):
     """docstring for fast_weights_model"""
     def __init__(self, batch_size, step_num, elem_num, hidden_num):
         super(fast_weights_model, self).__init__()
+        self.batch_size = batch_size
         self.x = Variable(torch.randn(batch_size, step_num, elem_num).type(torch.float32))
         self.y = Variable(torch.randn(batch_size, elem_num).type(torch.float32))
-        self.l = torch.tensor([0.9], dtype=torch.float32)
-        self.e = torch.tensor([0.5], dtype=torch.float32)
+        self.l = nn.Parameter(torch.tensor([0.9], dtype=torch.float32))
+        self.e = nn.Parameter(torch.tensor([0.5], dtype=torch.float32))
 
-        self.w1 = Variable(torch.empty(elem_num, 50).uniform_(-np.sqrt(0.02), np.sqrt(0.02)))
-        self.b1 = Variable(torch.zeros([1, 50]).type(torch.float32))
-        self.w2 = Variable(torch.empty(500, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)))
-        self.b2 = Variable(torch.zeros([1, 100]).type(torch.float32))
-        self.w3 = Variable(torch.empty(hidden_num, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)))
-        self.b3 = Variable(torch.zeros([1, 100]).type(torch.float32))
-        self.w4 = Variable(torch.empty(100, elem_num).uniform_(-np.sqrt(1.0 / elem_num), np.sqrt(1.0 / elem_num)))
-        self.b4 = Variable(torch.zeros([1, elem_num]).type(torch.float32))
+        self.w1 = nn.Parameter(torch.empty(elem_num, 50).uniform_(-np.sqrt(0.02), np.sqrt(0.02)), requires_grad=True)
+        self.b1 = nn.Parameter(torch.zeros([1, 50]).type(torch.float32), requires_grad=True)
+        self.w2 = nn.Parameter(torch.empty(50, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)), requires_grad=True)
+        self.b2 = nn.Parameter(torch.zeros([1, 100]).type(torch.float32), requires_grad=True)
+        self.w3 = nn.Parameter(torch.empty(hidden_num, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)), requires_grad=True)
+        self.b3 = nn.Parameter(torch.zeros([1, 100]).type(torch.float32), requires_grad=True)
+        self.w4 = nn.Parameter(torch.empty(100, elem_num).uniform_(-np.sqrt(1.0 / elem_num), np.sqrt(1.0 / elem_num)), requires_grad=True)
+        self.b4 = nn.Parameter(torch.zeros([1, elem_num]).type(torch.float32), requires_grad=True)
 
-        self.w = Variable(torch.tensor(0.05 * np.identity(hidden_num)).type(torch.float32))
+        self.w = nn.Parameter(torch.tensor(0.05 * np.identity(hidden_num)).type(torch.float32), requires_grad=True)
 
-        self.c = Variable(torch.empty(100, hidden_num).uniform_(-np.sqrt(hidden_num), np.sqrt(hidden_num)))
+        self.c = nn.Parameter(torch.empty(100, hidden_num).uniform_(-np.sqrt(hidden_num), np.sqrt(hidden_num)), requires_grad=True)
 
-        self.g = Variable(torch.ones([1, hidden_num]).type(torch.float32))
-        self.b = Variable(torch.ones([1, hidden_num]).type(torch.float32))
+        self.g = nn.Parameter(torch.ones([1, hidden_num]).type(torch.float32), requires_grad=True)
+        self.b = nn.Parameter(torch.ones([1, hidden_num]).type(torch.float32), requires_grad=True)
 
-    def forward(self, bx, by)
-        self.x = bx
-        self.y = by
-        a = torch.zeros([batch_size, hidden_num, hidden_num]).type(torch.float32)
-        h = torch.zeros([batch_size, hidden_num]).type(torch.float32)
+    def forward(self, bx, by):
+        self.x = torch.tensor(bx)
+        self.y = torch.tensor(by)
+        #print(bx.size)
+        #print(by.size)
+        a = torch.zeros([self.batch_size, HIDDEN_NUM, HIDDEN_NUM]).type(torch.float32)
+        h = torch.zeros([self.batch_size, HIDDEN_NUM]).type(torch.float32)
 
         la = []
 
-        for i in range(0, step_num):
-            s1 = torch.relu(torch.matmul(self.x[:, t, :], self.w1) + self.b1)
+        for i in range(0, STEP_NUM):
+            s1 = torch.relu(torch.matmul(self.x[:, i, :], self.w1) + self.b1)
+            #print(s1.shape, self.w2.shape)
             z = torch.relu(torch.matmul(s1, self.w2) + self.b2)
 
             h = torch.relu(torch.matmul(h, self.w) + torch.matmul(z, self.c))
 
-            hs = torch.reshape(h, [self.batch_size, 1, HIDDEN_NUM])
+            hs = torch.reshape(h, [self.batch_size, 1, HIDDEN_NUM])
 
             hh = hs
```
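
For context on the quantities this hunk introduces: `a` is the fast-weight matrix A(t), and the new learnable scalars `self.l` and `self.e` are the decay rate λ and fast learning rate η from Ba et al. (2016), whose update rule is A(t) = λ·A(t−1) + η·h(t)h(t)ᵀ. The loop body that applies it sits between this hunk and the next and is untouched by the commit; the following is only a minimal standalone sketch of that step, using this file's shapes, not the repo's exact code:

```python
import torch

def fast_weight_step(a, hs, lam, eta, inner_steps=1):
    """One fast-weight memory update, per Ba et al. (2016).

    a:   [batch, hidden, hidden] fast-weight matrix A(t-1)
    hs:  [batch, 1, hidden] current hidden state (row vector per sample)
    lam: decay rate (self.l in the model); eta: fast learning rate (self.e)
    """
    # A(t) = lam * A(t-1) + eta * h h^T  (a batched outer product)
    a = lam * a + eta * torch.matmul(torch.transpose(hs, 1, 2), hs)
    # S inner steps attend to the recent past through A(t); only the
    # A(t) term is shown here -- the real loop also re-adds the
    # slow-weight preliminary state before the nonlinearity.
    for _ in range(inner_steps):
        hs = torch.matmul(hs, a)
    return a, hs
```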


```diff
@@ -75,7 +80,7 @@ def forward(self, bx, by)
             sig = torch.sqrt(torch.mean(torch.pow((hs - mu), 2), 0))
             hs = torch.relu(torch.div(torch.mul(self.g, (hs - mu)), sig) + self.b)
 
-            h = torch.reshape(hs, [batch_size, hidden_num])
+            h = torch.reshape(hs, [self.batch_size, HIDDEN_NUM])
 
         h = torch.relu(torch.matmul(h, self.w3) + self.b3)
         logits = torch.matmul(h, self.w4) + self.b4
```
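
The `mu`/`sig`/`self.g`/`self.b` lines above implement the layer normalization of Ba, Kiros & Hinton (see References). One caveat worth noting: this code averages over dimension 0, whereas the paper normalizes each sample over its feature dimension, as in this standalone sketch (the `eps` guard is an addition here, not in the repo):

```python
import torch

def layer_norm(x, gain, bias, eps=1e-5):
    """Layer normalization (Ba, Kiros & Hinton, 2016).

    Normalizes each sample's features to zero mean and unit variance,
    then applies a learned gain and bias; eps avoids division by zero.
    """
    mu = x.mean(dim=-1, keepdim=True)
    sig = torch.sqrt(((x - mu) ** 2).mean(dim=-1, keepdim=True))
    return gain * (x - mu) / (sig + eps) + bias
```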

```diff
@@ -85,12 +90,12 @@ def forward(self, bx, by)
 
         return self.loss, self.acc
 
-def train(self, save = 0, verbose = 0):
-    model = fast_weights_model(STEP_NUM, ELEM_NUM, HIDDEN_NUM)
-    model.train()
+def train(save = 0, verbose = 0):
     batch_size = cfg.train.batch_size
+    model = fast_weights_model(batch_size, STEP_NUM, ELEM_NUM, HIDDEN_NUM)
     start_time = time.time()
-    optimizer = torch.optim.Adam(model.paramters(), lr=cfg.train.model_lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.model_lr)
+    model.train()
     writer = SummaryWriter(logdir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30)
     checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name))
     start_epoch = 0
```
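
The `Variable` to `nn.Parameter` conversion is the heart of this commit: plain tensors (or the deprecated `Variable`) assigned as module attributes are not registered by `nn.Module`, so the old `model.parameters()` was empty and Adam would have had nothing to optimize (besides the `paramters` typo crashing first). A minimal sketch of the difference:

```python
import torch
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.randn(3, 3)  # plain tensor: invisible to .parameters()

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3, 3))  # registered automatically

print(len(list(Broken().parameters())))  # 0 -> optimizer would train nothing
print(len(list(Fixed().parameters())))   # 1
```

(`requires_grad=True` is already the default for `nn.Parameter`, so the explicit flag added throughout the diff is redundant but harmless.)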

```diff
@@ -99,7 +104,8 @@ def train(save = 0, verbose = 0):
     for epoch in range(start_epoch, cfg.train.max_epochs):
         for idx in range(batch_idxs):
             global_step = epoch * cfg.num_train + idx + 1
-            bx, by = ar_data.train.next_batch(batch_size=cfg.batch_size)
+            #print(ar_data.train._x)
+            bx, by = ar_data.train.next_batch(batch_size=100)
             loss, acc = model(bx, by)
             optimizer.zero_grad()
             loss.backward()
```

fig/acc.png (new image, 54.9 KB)

fig/loss.png (new image, 54.8 KB)

generator.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 import numpy as np
 import random
-import cPickle as pickle
+import pickle
 
 num_train = 60000
 num_val = 10000
```
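
This change is all Python 3 needs: `cPickle` existed only in Python 2, and the Python 3 `pickle` module transparently delegates to the C-accelerated `_pickle` implementation when available, so nothing is lost by the plain import.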

retrieval.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -1,9 +1,6 @@
 import numpy as np
 import collections
-try:
-    import cPickle as pickle
-except ImportError:
-    import pickle
+import pickle
 
 
 Datasets = collections.namedtuple('Datasets', ['train', 'val', 'test'])
```

```diff
@@ -40,10 +37,13 @@ def next_batch(self, batch_size):
             start = 0
             self._index_in_epoch = batch_size
         end = self._index_in_epoch
+        #print(end)
+        #print(self._x[self.perm[start:end]], self._x[self.perm[start:end]].type)
         return self._x[self.perm[start:end]], self._y[self.perm[start:end]]
 
 
-def read_data(data_path='associative-retrieval.pkl'):
+def read_data():
+    data_path = 'associative-retrieval.pkl'
     with open(data_path, 'rb') as f:
         d = pickle.load(f)
         x_train = d['x_train']
```
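
Putting the pieces together, a usage sketch (assuming, per the `Datasets` namedtuple above and the call in fast_weights.py, that `read_data()` returns `Datasets(train, val, test)` and that generator.py has already written `associative-retrieval.pkl`):

```python
import retrieval

# Load the pickled associative-retrieval dataset produced by generator.py.
ar_data = retrieval.read_data()

# Draw one shuffled batch, exactly as the training loop in fast_weights.py does.
bx, by = ar_data.train.next_batch(batch_size=100)
print(bx.shape, by.shape)
```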
