你好,想请教个问题。
我的运行下来,报错在
y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())
因为y_one_hot和self.logits的总元素数量不同,所以不能reshape。
我推算了一下:
-
inputs的shape是(num_seqs, num_steps),经过tf.one_hot以后,lstm_inputs的shape变成(num_seqs, num_steps, num_classes)
-
我用的是cell是一层的lstm,lstm_inputs经过tf.nn.dynamic(cell, lstm_inputs, initial_state=self.initial_state)后,lstm_outputs的shape是(num_seqs, num_steps, lstm_size)
-
lstm_outputs经过tf.concat(lstm_outputs, 1)以后,shape没有任何变化,再经过一些列运算后,shape就会有问题。
所以想问一下tf.concat(lstm_outputs, 1)这一步是做什么的?
感谢~