Correct the return annotation of _compute_concentration_and_inv_freq: the
new annotation declares a (float, torch.Tensor) pair instead of a bare
torch.Tensor.

The typing module exports `Tuple`, not `tuple`; importing the lowercase
name raises ImportError as soon as the module is loaded, so the
capitalised generic is imported and used in the annotation.

diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py
index 9180d493..512a37a3 100644
--- a/gpt_oss/torch/model.py
+++ b/gpt_oss/torch/model.py
@@ -2,6 +2,7 @@
 import math
 import os
 from dataclasses import dataclass
+from typing import Tuple
 
 import torch
 import torch.distributed as dist
@@ -82,7 +83,7 @@ def __init__(
         self.ntk_beta = ntk_beta
         self.device = device
 
-    def _compute_concentration_and_inv_freq(self) -> torch.Tensor:
+    def _compute_concentration_and_inv_freq(self) -> Tuple[float, torch.Tensor]:
         """See YaRN paper: https://arxiv.org/abs/2309.00071"""
         freq = self.base ** (
             torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)