11from __future__ import print_function
22
3+ import os
34import torch
45import numpy as np
56import torch .nn as nn
@@ -25,42 +26,46 @@ class fast_weights_model(nn.Module):
2526 """docstring for fast_weights_model"""
2627 def __init__ (self , batch_size , step_num , elem_num , hidden_num ):
2728 super (fast_weights_model , self ).__init__ ()
29+ self .batch_size = batch_size
2830 self .x = Variable (torch .randn (batch_size , step_num , elem_num ).type (torch .float32 ))
2931 self .y = Variable (torch .randn (batch_size , elem_num ).type (torch .float32 ))
30- self .l = torch .tensor ([0.9 ], dtype = torch .float32 )
31- self .e = torch .tensor ([0.5 ], dtype = torch .float32 )
32+ self .l = nn . Parameter ( torch .tensor ([0.9 ], dtype = torch .float32 ) )
33+ self .e = nn . Parameter ( torch .tensor ([0.5 ], dtype = torch .float32 ) )
3234
33- self .w1 = Variable (torch .empty (elem_num , 50 ).uniform_ (- np .sqrt (0.02 ), np .sqrt (0.02 )))
34- self .b1 = Variable (torch .zeros ([1 , 50 ]).type (torch .float32 ))
35- self .w2 = Variable (torch .empty (500 , 100 ).uniform_ (- np .sqrt (0.01 ), np .sqrt (0.01 )))
36- self .b2 = Variable (torch .zeros ([1 , 100 ]).type (torch .float32 ))
37- self .w3 = Variable (torch .empty (hidden_num , 100 ).uniform_ (- np .sqrt (0.01 ), np .sqrt (0.01 )))
38- self .b3 = Variable (torch .zeros ([1 , 100 ]).type (torch .float32 ))
39- self .w4 = Variable (torch .empty (100 , elem_num ).uniform_ (- np .sqrt (1.0 / elem_num ), np .sqrt (1.0 / elem_num )))
40- self .b4 = Variable (torch .zeros ([1 , elem_num ]).type (torch .float32 ))
35+ self .w1 = nn . Parameter (torch .empty (elem_num , 50 ).uniform_ (- np .sqrt (0.02 ), np .sqrt (0.02 )), requires_grad = True )
36+ self .b1 = nn . Parameter (torch .zeros ([1 , 50 ]).type (torch .float32 ), requires_grad = True )
37+ self .w2 = nn . Parameter (torch .empty (50 , 100 ).uniform_ (- np .sqrt (0.01 ), np .sqrt (0.01 )), requires_grad = True )
38+ self .b2 = nn . Parameter (torch .zeros ([1 , 100 ]).type (torch .float32 ), requires_grad = True )
39+ self .w3 = nn . Parameter (torch .empty (hidden_num , 100 ).uniform_ (- np .sqrt (0.01 ), np .sqrt (0.01 )), requires_grad = True )
40+ self .b3 = nn . Parameter (torch .zeros ([1 , 100 ]).type (torch .float32 ), requires_grad = True )
41+ self .w4 = nn . Parameter (torch .empty (100 , elem_num ).uniform_ (- np .sqrt (1.0 / elem_num ), np .sqrt (1.0 / elem_num )), requires_grad = True )
42+ self .b4 = nn . Parameter (torch .zeros ([1 , elem_num ]).type (torch .float32 ), requires_grad = True )
4143
42- self .w = Variable (torch .tensor (0.05 * np .identity (hidden_num )).type (torch .float32 ))
44+ self .w = nn . Parameter (torch .tensor (0.05 * np .identity (hidden_num )).type (torch .float32 ), requires_grad = True )
4345
44- self .c = Variable (torch .empty (100 , hidden_num ).uniform_ (- np .sqrt (hidden_num ), np .sqrt (hidden_num )))
46+ self .c = nn . Parameter (torch .empty (100 , hidden_num ).uniform_ (- np .sqrt (hidden_num ), np .sqrt (hidden_num )), requires_grad = True )
4547
46- self .g = Variable (torch .ones ([1 , hidden_num ]).type (torch .float32 ))
47- self .b = Variable (torch .ones ([1 , hidden_num ]).type (torch .float32 ))
48+ self .g = nn . Parameter (torch .ones ([1 , hidden_num ]).type (torch .float32 ), requires_grad = True )
49+ self .b = nn . Parameter (torch .ones ([1 , hidden_num ]).type (torch .float32 ), requires_grad = True )
4850
49- def forward (self , bx , by )
50- self .x = bx
51- self .y = by
52- a = torch .zeros ([batch_size , hidden_num , hidden_num ]).type (torch .float32 )
53- h = torch .zeros ([batch_size , hidden_num ]).type (torch .float32 )
51+ def forward (self , bx , by ):
52+ self .x = torch .tensor (bx )
53+ self .y = torch .tensor (by )
54+ #print(bx.size)
55+ #print(by.size)
56+ a = torch .zeros ([self .batch_size , HIDDEN_NUM , HIDDEN_NUM ]).type (torch .float32 )
57+ h = torch .zeros ([self .batch_size , HIDDEN_NUM ]).type (torch .float32 )
5458
5559 la = []
5660
57- for i in range (0 , step_num ):
58- s1 = torch .relu (torch .matmul (self .x [:, t , :], self .w1 ) + self .b1 )
61+ for i in range (0 , STEP_NUM ):
62+ s1 = torch .relu (torch .matmul (self .x [:, i , :], self .w1 ) + self .b1 )
63+ #print(s1.shape, self.w2.shape)
5964 z = torch .relu (torch .matmul (s1 , self .w2 ) + self .b2 )
6065
6166 h = torch .relu (torch .matmul (h , self .w ) + torch .matmul (z , self .c ))
6267
63- hs = torch .reshape (h , [batch_size , 1 , hidden_num ])
68+ hs = torch .reshape (h , [self . batch_size , 1 , HIDDEN_NUM ])
6469
6570 hh = hs
6671
@@ -75,7 +80,7 @@ def forward(self, bx, by)
7580 sig = torch .sqrt (torch .mean (torch .pow ((hs - mu ), 2 ), 0 ))
7681 hs = torch .relu (torch .div (torch .mul (self .g , (hs - mu )), sig ) + self .b )
7782
78- h = torch .reshape (hs , [batch_size , hidden_num ])
83+ h = torch .reshape (hs , [self . batch_size , HIDDEN_NUM ])
7984
8085 h = torch .relu (torch .matmul (h , self .w3 ) + self .b3 )
8186 logits = torch .matmul (h , self .w4 ) + self .b4
@@ -85,12 +90,12 @@ def forward(self, bx, by)
8590
8691 return self .loss , self .acc
8792
88- def train (self , save = 0 , verbose = 0 ):
89- model = fast_weights_model (STEP_NUM , ELEM_NUM , HIDDEN_NUM )
90- model .train ()
93+ def train (save = 0 , verbose = 0 ):
9194 batch_size = cfg .train .batch_size
95+ model = fast_weights_model (batch_size , STEP_NUM , ELEM_NUM , HIDDEN_NUM )
9296 start_time = time .time ()
93- optimizer = torch .optim .Adam (model .paramters (), lr = cfg .train .model_lr )
97+ optimizer = torch .optim .Adam (model .parameters (), lr = cfg .train .model_lr )
98+ model .train ()
9499 writer = SummaryWriter (logdir = os .path .join (cfg .logdir , cfg .exp_name ), flush_secs = 30 )
95100 checkpointer = Checkpointer (os .path .join (cfg .checkpointdir , cfg .exp_name ))
96101 start_epoch = 0
@@ -99,7 +104,8 @@ def train(self, save = 0, verbose = 0):
99104 for epoch in range (start_epoch , cfg .train .max_epochs ):
100105 for idx in range (batch_idxs ):
101106 gloabl_step = epoch * cfg .num_train + idx + 1
102- bx , by = ar_data .train .next_batch (batch_size = cfg .batch_size )
107+ #print(ar_data.train._x)
108+ bx , by = ar_data .train .next_batch (batch_size = 100 )
103109 loss , acc = model (bx , by )
104110 optimizer .zero_grad ()
105111 loss .backward ()
0 commit comments