When running the command
python unllama_token_clf.py conll2003 7b
I get the following error:
RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and query.dtype: c10::BFloat16 instead.
I am running on an A100 with CUDA 12.1, transformers 4.37.2, and torch 2.1.2.
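
For reference, the message matches the dtype check in `torch.nn.functional.scaled_dot_product_attention`, which in torch 2.1 requires the attention mask to be boolean or to share the query's dtype. Below is a minimal sketch that tries to reproduce the mismatch outside the repository, assuming a float32 additive mask is being passed alongside bfloat16 queries. The tensor shapes and the helper name `sdpa_with_matching_mask` are made up for illustration and are not taken from `unllama_token_clf.py`. Newer torch releases may accept the float32 mask, so the first call is wrapped in a try/except rather than assumed to fail.

```python
import torch
import torch.nn.functional as F


def sdpa_with_matching_mask(q, k, v, attn_mask):
    # Hypothetical helper: cast a floating-point additive mask to the query
    # dtype before calling SDPA; boolean masks are passed through unchanged.
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(q.dtype)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)


# Toy shapes: (batch, heads, seq_len, head_dim). Runs on CPU, no GPU needed.
q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
v = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
# float32 additive mask of zeros, i.e. every position is visible.
mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)

try:
    # bfloat16 query with a float32 mask: the combination from the error above.
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print("float32 mask accepted by this torch version")
except RuntimeError as err:
    print("mismatch rejected:", err)

# With the mask cast to the query dtype, the call goes through.
out = sdpa_with_matching_mask(q, k, v, mask)
print(out.dtype, tuple(out.shape))
```

If this is indeed the cause, casting the mask to the query dtype (or using a boolean mask) before the attention call might be enough to get past the error, though I have not confirmed where in the model the mask is built.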