diff --git a/gemma/gm/data/_functional.py b/gemma/gm/data/_functional.py index b9e4fef5..d1bbfc47 100644 --- a/gemma/gm/data/_functional.py +++ b/gemma/gm/data/_functional.py @@ -14,6 +14,8 @@ """Functional version of the `gm.data` transforms.""" +import warnings + from etils import enp import flax import jax @@ -134,6 +136,17 @@ def make_seq2seq_fields( Returns: The input, target and mask, all of length `prompt_len + response_len - 1`. """ + # Handle empty prompt: issue warning and use a default BOS token (2) as fallback. + if len(prompt) == 0: + warnings.warn( + 'Empty prompt provided. Using default BOS token (2) as prompt. ' + 'Empty prompts are not recommended for sequence-to-sequence training.', + UserWarning, + stacklevel=2, + ) + # Use BOS token (2) as a default prompt token for seq2seq compatibility. + prompt = np.array([2], dtype=np.int32) + # Concatenate the prompt and response tokens. sequence = np.concatenate([prompt, response])