Description
Hi,
First of all, I am not very familiar with JAX. My conda environment was built from the distributed yml file, so it has JAX 0.3.24, as shown below:
jax 0.3.24 pypi_0 pypi
jaxlib 0.3.24 pypi_0 pypi
However, when I run af2_interface_metrics.py on the silent files, I get the following error. I get the same error with both AF 2.3.1 and AF 2.2.4. By contrast, af2_metrics.py runs without any issue. Any thoughts on this?
Traceback (most recent call last):
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 597, in
    predict_structure(tag_buffer, feature_dict_dict, binderlen_dict, initial_guess_dict, sfd_out, scorefilename)
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 431, in predict_structure
    prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params,
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
TypeError: _forward_fn() takes 1 positional argument but 2 were given
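One plausible reading of this traceback, sketched in plain Python so it runs without JAX or Haiku: `in_axes=(None,None,0,0)` means `model_runner.apply` is called with four positional arguments, the first two broadcast and the last two mapped over. The Haiku frame at line 357 then forwards the remaining inputs to the wrapped function via `f(*args, **kwargs)`. If the wrapped forward function accepts only one input, it receives two and Python raises exactly this kind of `TypeError`. The names `transform_apply`, `toy_vmap`, `forward_batch_only` and `forward_with_guess` are hypothetical stand-ins for illustration, not Haiku's, JAX's or AlphaFold's actual code.

```python
# Hypothetical stand-ins for illustration only; NOT Haiku's or JAX's real
# implementation, just the argument-forwarding pattern the traceback shows.

def transform_apply(forward_fn):
    """Mimic a transformed apply: consume params and rng, forward the rest."""
    def apply_fn(params, rng, *args, **kwargs):
        return forward_fn(*args, **kwargs)
    return apply_fn


def toy_vmap(fn, in_axes):
    """Loop-based stand-in for vmap: None = broadcast, 0 = map over axis 0."""
    def mapped(*args):
        size = len(next(a for a, ax in zip(args, in_axes) if ax == 0))
        results = []
        for i in range(size):
            call_args = [a if ax is None else a[i] for a, ax in zip(args, in_axes)]
            results.append(fn(*call_args))
        return results
    return mapped


def forward_batch_only(batch):
    # Accepts a single input, like a forward function with one parameter.
    return sum(batch)


def forward_with_guess(batch, initial_guess):
    # Hypothetical variant that also accepts an initial guess.
    return sum(batch) + initial_guess


params, rng = {}, None
batches = [[1, 2], [3, 4]]  # mapped over (in_axes entry 0)
guesses = [10, 20]          # mapped over (in_axes entry 0)
in_axes = (None, None, 0, 0)

try:
    toy_vmap(transform_apply(forward_batch_only), in_axes)(params, rng, batches, guesses)
except TypeError as err:
    # forward_batch_only() takes 1 positional argument but 2 were given
    print(err)

# With a forward function that accepts both inputs, the same call succeeds.
print(toy_vmap(transform_apply(forward_with_guess), in_axes)(params, rng, batches, guesses))
# [13, 27]
```

Under this reading, the mismatch is between the number of inputs the script passes to `apply` and the number of parameters `_forward_fn` accepts in the installed AlphaFold code, rather than the JAX version itself; that is an assumption worth checking against the AlphaFold source the script expects.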
I'd really appreciate any help.
Thanks!