Skip to content

Commit ddaf46b

Browse files
committed
first commit
1 parent ead5aef commit ddaf46b

File tree

5 files changed

+338
-0
lines changed

5 files changed

+338
-0
lines changed

config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from attrdict import AttrDict
2+
import os
3+
4+
cfg = AttrDict({
5+
# 'exp_name': 'test-len10-delta',
6+
# 'exp_name': 'test-len1-fixedscale-aggre-super',
7+
# 'exp_name': 'test-aggre-super',
8+
# 'exp_name': 'test-mask',
9+
'exp_name': 'test-proposal',
10+
'resume': True,
11+
'device': 'cuda:0',
12+
# 'device': 'cpu',
13+
14+
'train': {
15+
'batch_size': 100,
16+
'model_lr': 1e-4,
17+
'max_epochs': 1000
18+
},
19+
'valid': {
20+
'batch_size': 64
21+
},
22+
'num_train': 60000,
23+
'logdir': 'logs/',
24+
'checkpointdir': 'checkpoints/',
25+
})

generator.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import numpy as np
2+
import random
3+
import cPickle as pickle
4+
5+
num_train = 60000
6+
num_val = 10000
7+
num_test = 10000
8+
9+
step_num = 4
10+
elem_num = 26 + 10 + 1
11+
12+
x_train = np.zeros([num_train, step_num * 2 + 3, elem_num], dtype=np.float32)
13+
x_val = np.zeros([num_val, step_num * 2 + 3, elem_num], dtype=np.float32)
14+
x_test = np.zeros([num_test, step_num * 2 + 3, elem_num], dtype=np.float32)
15+
16+
y_train = np.zeros([num_train, elem_num], dtype=np.float32)
17+
y_val = np.zeros([num_val, elem_num], dtype=np.float32)
18+
y_test = np.zeros([num_test, elem_num], dtype=np.float32)
19+
20+
21+
def get_one_hot(c):
22+
a = np.zeros([elem_num])
23+
if ord('a') <= ord(c) <= ord('z'):
24+
a[ord(c) - ord('a')] = 1
25+
elif ord('0') <= ord(c) <= ord('9'):
26+
a[ord(c) - ord('0') + 26] = 1
27+
else:
28+
a[-1] = 1
29+
return a
30+
31+
32+
def generate_one():
33+
a = np.zeros([step_num * 2 + 3, elem_num])
34+
d = {}
35+
st = ''
36+
37+
for i in range(0, step_num):
38+
c = random.randint(0, 25)
39+
while d.has_key(c):
40+
c = random.randint(0, 25)
41+
b = random.randint(0, 9)
42+
d[c] = b
43+
s, t = chr(c + ord('a')), chr(b + ord('0'))
44+
st += s + t
45+
a[i*2] = get_one_hot(s)
46+
a[i*2+1] = get_one_hot(t)
47+
48+
s = random.choice(d.keys())
49+
t = chr(s + ord('a'))
50+
r = chr(d[s] + ord('0'))
51+
a[step_num * 2] = get_one_hot('?')
52+
a[step_num * 2 + 1] = get_one_hot('?')
53+
a[step_num * 2 + 2] = get_one_hot(t)
54+
st += '??' + t + r
55+
e = get_one_hot(r)
56+
return a, e
57+
58+
if __name__ == '__main__':
59+
for i in range(0, num_train):
60+
x_train[i], y_train[i] = generate_one()
61+
62+
for i in range(0, num_test):
63+
x_test[i], y_test[i] = generate_one()
64+
65+
for i in range(0, num_val):
66+
x_val[i], y_val[i] = generate_one()
67+
68+
d = {
69+
'x_train': x_train,
70+
'x_test': x_test,
71+
'x_val': x_val,
72+
'y_train': y_train,
73+
'y_test': y_test,
74+
'y_val': y_val
75+
}
76+
with open('associative-retrieval.pkl', 'wb') as f:
77+
pickle.dump(d, f, protocol=2)

model.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from __future__ import print_function
2+
3+
import torch
4+
import numpy as np
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from config import cfg
8+
from tensorboardX import SummaryWriter
9+
from torch.autograd import Variable
10+
import time
11+
from retrieval import read_data
12+
from util import Checkpointer
13+
14+
ar_data = read_data()
15+
16+
STEP_NUM = 11
17+
ELEM_NUM = 26 + 10 + 1
18+
HIDDEN_NUM = 20
19+
20+
def softmax_cross_entropy_with_logits(logits, labels):
21+
loss = torch.sum(-labels * F.log_softmax(logits, -1), -1)
22+
return loss
23+
24+
class fast_weights_model(nn.Module):
25+
"""docstring for fast_weights_model"""
26+
def __init__(self, batch_size, step_num, elem_num, hidden_num):
27+
super(fast_weights_model, self).__init__()
28+
self.x = Variable(torch.randn(batch_size, step_num, elem_num).type(torch.float32))
29+
self.y = Variable(torch.randn(batch_size, elem_num).type(torch.float32))
30+
self.l = torch.zeros(1, dtype=torch.float32)
31+
self.e = torch.zeros(1, dtype=torch.float32)
32+
33+
self.w1 = Variable(torch.empty(elem_num, 50).uniform_(-np.sqrt(0.02), np.sqrt(0.02)))
34+
self.b1 = Variable(torch.zeros([1, 50]).type(torch.float32))
35+
self.w2 = Variable(torch.empty(500, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)))
36+
self.b2 = Variable(torch.zeros([1, 100]).type(torch.float32))
37+
self.w3 = Variable(torch.empty(hidden_num, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)))
38+
self.b3 = Variable(torch.zeros([1, 100]).type(torch.float32))
39+
self.w4 = Variable(torch.empty(100, elem_num).uniform_(-np.sqrt(1.0 / elem_num), np.sqrt(1.0 / elem_num)))
40+
self.b4 = Variable(torch.zeros([1, elem_num]).type(torch.float32))
41+
42+
self.w = Variable(torch.tensor(0.05 * np.identity(hidden_num)).type(torch.float32))
43+
44+
self.c = Variable(torch.empty(100, hidden_num).uniform_(-np.sqrt(hidden_num), np.sqrt(hidden_num)))
45+
46+
self.g = Variable(torch.ones([1, hidden_num]).type(torch.float32))
47+
self.b = Variable(torch.ones([1, hidden_num]).type(torch.float32))
48+
49+
def forward(self, bx, by)
50+
a = torch.zeros([batch_size, hidden_num, hidden_num]).type(torch.float32)
51+
h = torch.zeros([batch_size, hidden_num]).type(torch.float32)
52+
53+
la = []
54+
55+
for i in range(0, step_num):
56+
s1 = torch.relu(torch.matmul(self.x[:, t, :], self.w1) + self.b1)
57+
z = torch.relu(torch.matmul(s1, self.w2) + self.b2)
58+
59+
h = torch.relu(torch.matmul(h, self.w) + torch.matmul(z, self.c))
60+
61+
hs = torch.reshape(h, [batch_size, 1, hidden_num])
62+
63+
hh = hs
64+
65+
a = self.l * a + self.e * torch.matmul(hs.transpose(1,2), hs)
66+
67+
la.append(torch.mean(torch.pow(a,2)))
68+
69+
for s in range(1):
70+
hs = torch.reshape(torch.matmul(h, self.w), hh.shape) + \
71+
torch.reshape(torch.matmul(z, self.c), hh.shape) + torch.matmul(hs, a)
72+
mu = torch.mean(hs, 0)
73+
sig = torch.sqrt(torch.mean(torch.pow((hs - mu), 2), 0))
74+
hs = torch.relu(torch.div(torch.mul(self.g, (hs - mu)), sig) + self.b)
75+
76+
h = torch.reshape(hs, [batch_size, hidden_num])
77+
78+
h = torch.relu(torch.matmul(h, self.w3) + self.b3)
79+
logits = torch.matmul(h, self.w4) + self.b4
80+
correct = torch.argmax(logits, dim=1).eq(torch.argmax(self.y, dim=1))
81+
self.loss = softmax_cross_entropy_with_logits(logits, self.y).mean()
82+
self.acc = torch.mean(correct.type(torch.float32))
83+
84+
return self.loss, self.acc
85+
86+
def train(self, save = 0, verbose = 0):
87+
model = fast_weights_model(STEP_NUM, ELEM_NUM, HIDDEN_NUM)
88+
model.train()
89+
batch_size = cfg.train.batch_size
90+
start_time = time.time()
91+
optimizer = torch.optim.Adam(model.paramters(), lr=cfg.train.model_lr)
92+
writer = SummaryWriter(logdir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30)
93+
checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name))
94+
start_epoch = 0
95+
batch_idxs = 600
96+
for epoch in range(start_epoch, cfg.train.max_epochs):
97+
for idx in range(batch_idxs):
98+
gloabl_step = epoch * cfg.num_train + idx + 1
99+
bx, by = ar_data.train.next_batch(batch_size=cfg.batch_size)
100+
loss, acc = model(bx, by)
101+
optimizer.zero_grad()
102+
loss.backward()
103+
optimizer.step()
104+
writer.add_scalar('loss/loss', loss, gloabl_step)
105+
writer.add_scalar('acc/acc', acc, gloabl_step)
106+
if verbose > 0 and idx % verbose == 0:
107+
print('Epoch: [{:4d}] [{:4d}/{:4d}] time: {:.4f}, loss: {:.8f}, acc: {:.2f}'.format(
108+
epoch, idx, batch_idxs, time.time() - start_time, loss, acc
109+
))
110+
checkpointer.save(model, optimizer, epoch+1)
111+
112+
113+
if __name__ == "__main__":
114+
train(verbose = 10)
115+
116+
117+
118+
119+

retrieval.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import numpy as np
2+
import collections
3+
try:
4+
import cPickle as pickle
5+
except ImportError:
6+
import pickle
7+
8+
9+
Datasets = collections.namedtuple('Datasets', ['train', 'val', 'test'])
10+
11+
12+
class Dataset(object):
13+
def __init__(self, x, y):
14+
self._x = x
15+
self._y = y
16+
self._epoch_completed = 0
17+
self._index_in_epoch = 0
18+
self._num_examples = self.x.shape[0]
19+
self.perm = np.random.permutation(np.arange(self._num_examples))
20+
21+
@property
22+
def x(self):
23+
return self._x
24+
25+
@property
26+
def y(self):
27+
return self._y
28+
29+
@property
30+
def num_examples(self):
31+
return self._num_examples
32+
33+
def next_batch(self, batch_size):
34+
assert batch_size <= self._num_examples
35+
start = self._index_in_epoch
36+
self._index_in_epoch += batch_size
37+
if self._index_in_epoch >= self.num_examples:
38+
self._epoch_completed += 1
39+
np.random.shuffle(self.perm)
40+
start = 0
41+
self._index_in_epoch = batch_size
42+
end = self._index_in_epoch
43+
return self._x[self.perm[start:end]], self._y[self.perm[start:end]]
44+
45+
46+
def read_data(data_path='associative-retrieval.pkl'):
47+
with open(data_path, 'rb') as f:
48+
d = pickle.load(f)
49+
x_train = d['x_train']
50+
x_val = d['x_val']
51+
x_test = d['x_test']
52+
y_train = d['y_train']
53+
y_val = d['y_val']
54+
y_test = d['y_test']
55+
train = Dataset(x_train, y_train)
56+
test = Dataset(x_test, y_test)
57+
val = Dataset(x_val, y_val)
58+
return Datasets(train=train, val=val, test=test)

util.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from collections import defaultdict, deque
2+
import pickle
3+
from attrdict import AttrDict
4+
import os
5+
import numpy as np
6+
import torch
7+
from torch import nn
8+
from torch import optim
9+
from tensorboardX import SummaryWriter
10+
11+
class Checkpointer:
12+
def __init__(self, path, max_num=3):
13+
self.max_num = max_num
14+
self.path = path
15+
if not os.path.exists(path):
16+
os.makedirs(path)
17+
self.listfile = os.path.join(path, 'model_list.pkl')
18+
if not os.path.exists(self.listfile):
19+
with open(self.listfile, 'wb') as f:
20+
model_list = []
21+
pickle.dump(model_list, f)
22+
23+
24+
def save(self, model, optimizer, epoch):
25+
checkpoint = {
26+
'model': model.state_dict(),
27+
'optimizer': optimizer.state_dict(),
28+
'epoch': epoch
29+
}
30+
filename = os.path.join(self.path, 'model_{:05}.pth'.format(epoch))
31+
32+
with open(self.listfile, 'rb+') as f:
33+
model_list = pickle.load(f)
34+
if len(model_list) >= self.max_num:
35+
if os.path.exists(model_list[0]):
36+
os.remove(model_list[0])
37+
del model_list[0]
38+
model_list.append(filename)
39+
with open(self.listfile, 'rb+') as f:
40+
pickle.dump(model_list, f)
41+
42+
with open(filename, 'wb') as f:
43+
torch.save(checkpoint, f)
44+
45+
def load(self, model, optimizer):
46+
"""
47+
Return starting epoch
48+
"""
49+
with open(self.listfile, 'rb') as f:
50+
model_list = pickle.load(f)
51+
if len(model_list) == 0:
52+
print('No checkpoint found. Starting from scratch')
53+
return 0
54+
else:
55+
checkpoint = torch.load(model_list[-1])
56+
model.load_state_dict(checkpoint['model'])
57+
optimizer.load_state_dict(checkpoint['optimizer'])
58+
print('Load checkpoint from {}.'.format(model_list[-1]))
59+
return checkpoint['epoch']

0 commit comments

Comments
 (0)