fix support for multi-dim observations#22
Conversation
|
Thanks for pointing out the bug! |
|
Hmm, no, the code is working correctly in my project. I didn't test it with the examples tho. Will do it and fix the errors. |
|
Ok, so 2 more dimensions are present in the >>> obs_buffer.shape
(1, 64, 1024, 28, 28, 1)Their meaning is: While >>> running_mean.shape
(784,)My code handles the last dimensions, they are expected. But the first 2 are causing the error.
|
|
hey @lerrytang! how any update on this issue? |
|
for example, take a look at the https://github.com/google/brax/blob/main/brax/training/normalization.py They have a |
Hey! I found a bug in the observations normalization code. The bug occurs when the observations are not a flat array, but a multi-dim array. This happens because the obs_normalizer params are stored as a flat array. The code fails in this case. Here is the fix for this bug.