@@ -27,8 +27,8 @@ def __init__(self, batch_size, step_num, elem_num, hidden_num):
2727 super (fast_weights_model , self ).__init__ ()
2828 self .x = Variable (torch .randn (batch_size , step_num , elem_num ).type (torch .float32 ))
2929 self .y = Variable (torch .randn (batch_size , elem_num ).type (torch .float32 ))
30- self .l = torch .zeros ( 1 , dtype = torch .float32 )
31- self .e = torch .zeros ( 1 , dtype = torch .float32 )
30+ self .l = torch .tensor ([ 0.9 ] , dtype = torch .float32 )
31+ self .e = torch .tensor ([ 0.5 ] , dtype = torch .float32 )
3232
3333 self .w1 = Variable (torch .empty (elem_num , 50 ).uniform_ (- np .sqrt (0.02 ), np .sqrt (0.02 )))
3434 self .b1 = Variable (torch .zeros ([1 , 50 ]).type (torch .float32 ))
@@ -47,6 +47,8 @@ def __init__(self, batch_size, step_num, elem_num, hidden_num):
4747 self .b = Variable (torch .ones ([1 , hidden_num ]).type (torch .float32 ))
4848
4949 def forward (self , bx , by )
50+ self .x = bx
51+ self .y = by
5052 a = torch .zeros ([batch_size , hidden_num , hidden_num ]).type (torch .float32 )
5153 h = torch .zeros ([batch_size , hidden_num ]).type (torch .float32 )
5254
@@ -92,6 +94,7 @@ def train(self, save = 0, verbose = 0):
9294 writer = SummaryWriter (logdir = os .path .join (cfg .logdir , cfg .exp_name ), flush_secs = 30 )
9395 checkpointer = Checkpointer (os .path .join (cfg .checkpointdir , cfg .exp_name ))
9496 start_epoch = 0
97+ start_epoch = checkpointer .load (model , optimizer )
9598 batch_idxs = 600
9699 for epoch in range (start_epoch , cfg .train .max_epochs ):
97100 for idx in range (batch_idxs ):
0 commit comments