diff --git a/interact.py b/interact.py index d368204..5e3cb06 100644 --- a/interact.py +++ b/interact.py @@ -66,7 +66,7 @@ def sample_sequence(personality, history, tokenizer, model, args, current_output input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0) token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0) - logits = model(input_ids, token_type_ids=token_type_ids) + logits = model(input_ids, token_type_ids=token_type_ids).logits if isinstance(logits, tuple): # for gpt2 and maybe others logits = logits[0] logits = logits[0, -1, :] / args.temperature