Skip to content

Commit b7731c4

Browse files
committed
debug forward
1 parent ddaf46b commit b7731c4

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def __init__(self, batch_size, step_num, elem_num, hidden_num):
2727
super(fast_weights_model, self).__init__()
2828
self.x = Variable(torch.randn(batch_size, step_num, elem_num).type(torch.float32))
2929
self.y = Variable(torch.randn(batch_size, elem_num).type(torch.float32))
30-
self.l = torch.zeros(1, dtype=torch.float32)
31-
self.e = torch.zeros(1, dtype=torch.float32)
30+
self.l = torch.tensor([0.9], dtype=torch.float32)
31+
self.e = torch.tensor([0.5], dtype=torch.float32)
3232

3333
self.w1 = Variable(torch.empty(elem_num, 50).uniform_(-np.sqrt(0.02), np.sqrt(0.02)))
3434
self.b1 = Variable(torch.zeros([1, 50]).type(torch.float32))
@@ -47,6 +47,8 @@ def __init__(self, batch_size, step_num, elem_num, hidden_num):
4747
self.b = Variable(torch.ones([1, hidden_num]).type(torch.float32))
4848

4949
def forward(self, bx, by)
50+
self.x = bx
51+
self.y = by
5052
a = torch.zeros([batch_size, hidden_num, hidden_num]).type(torch.float32)
5153
h = torch.zeros([batch_size, hidden_num]).type(torch.float32)
5254

@@ -92,6 +94,7 @@ def train(self, save = 0, verbose = 0):
9294
writer = SummaryWriter(logdir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30)
9395
checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name))
9496
start_epoch = 0
97+
start_epoch = checkpointer.load(model, optimizer)
9598
batch_idxs = 600
9699
for epoch in range(start_epoch, cfg.train.max_epochs):
97100
for idx in range(batch_idxs):

0 commit comments

Comments
 (0)