diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py
index 0ab3e0d628d..4f3bfa6ecb8 100644
--- a/trl/environment/__init__.py
+++ b/trl/environment/__init__.py
@@ -19,10 +19,12 @@
 _import_structure = {
     "base_environment": ["TextEnvironment", "TextHistory"],
+    "env_protocol": ["Environment"],
 }
 
 if TYPE_CHECKING:
     from .base_environment import TextEnvironment, TextHistory
+    from .env_protocol import Environment
 else:
     import sys
diff --git a/trl/environment/env_protocol.py b/trl/environment/env_protocol.py
new file mode 100644
index 00000000000..4897c587be2
--- /dev/null
+++ b/trl/environment/env_protocol.py
@@ -0,0 +1,14 @@
+from typing import Any, List, Protocol
+
+
+class Environment(Protocol):
+    """
+    A protocol describing the minimal interface needed for integration with the trainer.
+    Your environment can run any multi-step logic, but must ultimately return token
+    sequences, akin to selecting token_ids from vllm.LLM's generate() output:
+    https://docs.vllm.ai/en/stable/api/offline_inference/llm.html
+    """
+
+    def prepare_data(self, inputs, processing_class) -> Any: ...
+
+    def generate(self, conversations, vlm_inputs, vlm, sampling_params) -> List[Any]: ...
diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 9d3ceeb6cf2..6730793810d 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -15,6 +15,7 @@
 import textwrap
 import warnings
 from collections import defaultdict
+from copy import deepcopy
 from typing import Any, Callable, Optional, Sized, Union
 from unittest.mock import patch
 
@@ -30,7 +31,6 @@
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
-    GenerationConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     Qwen2_5_VLForConditionalGeneration,
@@ -43,6 +43,7 @@
 from transformers.utils import is_peft_available
 
 from ..data_utils import apply_chat_template, is_conversational
+from ..environment.env_protocol import Environment
 from ..import_utils import is_vllm_available
 from ..models import (
     create_reference_model,
@@ -218,7 +219,7 @@ def __init__(
         model: PreTrainedModel,
         reward_funcs: Union[RewardFunc, list[RewardFunc]],
         processing_class: PreTrainedTokenizerBase,
-        tokenize_and_inject_images: Callable,
+        env: Environment,
         args: GRPOConfig = None,
         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
         eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
@@ -228,9 +229,6 @@ def __init__(
         peft_config: Optional["PeftConfig"] = None,
         shuffle_dataset: bool = True,
     ):
-        # Add shuffle_dataset to instance variables
-        self.shuffle_dataset = shuffle_dataset
-
         # Args
         if args is None:
             model_name = model if isinstance(model, str) else model.config._name_or_path
@@ -327,8 +325,6 @@ def __init__(
                 reward_processing_classes[i] = reward_processing_class
         self.reward_processing_classes = reward_processing_classes
 
-        self.tokenize_and_inject_images = tokenize_and_inject_images
-
         # Data collator
         def data_collator(features):  # No data collation is needed in GRPO
             return features
@@ -338,7 +334,7 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
-        print(f"use_vllm: {self.use_vllm}")
+        self.shuffle_dataset = shuffle_dataset
 
         self.beta = args.beta
 
@@ -457,12 +453,9 @@ def data_collator(features):  # No data collation is needed in GRPO
             # synchronize all processes after vLLM has been fully initialized.
             self.accelerator.wait_for_everyone()
         else:
-            self.generation_config = GenerationConfig(
-                max_new_tokens=self.max_completion_length,
-                do_sample=True,
-                temperature=args.temperature,
-                pad_token_id=processing_class.tokenizer.pad_token_id,
-            )
+            raise ValueError("use_vllm must be True")
+
+        self.env = env
 
         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
@@ -575,12 +568,35 @@ def _move_model_to_vllm(self):
 
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
-        prompt_inputs, vllm_inputs, prompts_text, prompts = self.tokenize_and_inject_images(
+        if not self.env:
+            raise ValueError("No environment provided. Only environment-based generation is supported.")
+
+        # TODO: This is a hack that we should probably fix.
+        # Without this, each GPU receives different inputs, which corrupts the advantage computation.
+        # Simple synchronization of inputs across processes.
+        if self.accelerator.num_processes > 1:
+            # Make sure all processes have a non-None value to gather;
+            # use an empty list for non-main processes.
+            local_inputs = inputs if self.accelerator.process_index == 0 else []
+
+            # Gather from all processes using torch.distributed.gather_object
+            all_inputs = gather_object(local_inputs)
+
+            # Each process takes the inputs from process 0 as its own
+            inputs = deepcopy(all_inputs)
+
+            self.accelerator.wait_for_everyone()
+
+        # conversations: list of conversations
+        # prompts_text: list of prompts as strings
+        # prompt_inputs: tokenized data (with image tokens injected) used to compute log probs on the base model
+        # env_inputs: data in the format the env/vLLM expects
+        conversations, prompts_text, prompt_inputs, env_inputs = self.env.prepare_data(
             inputs=inputs, processing_class=self.processing_class
         )
+
+        # Unpack prompt_inputs
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
-
         prompt_ids, prompt_mask, pixel_values, image_grid_thw = (
             prompt_inputs["input_ids"],
             prompt_inputs["attention_mask"],
@@ -589,70 +605,51 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         )
 
         if self.max_prompt_length is not None:
-            if self.use_vllm:
-                raise ValueError(
-                    "max_prompt_length is not supported when using vLLM. Please set it to None if vLLM is used. This is because we don't control tokenization when using vLLM."
-                )
-
-            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
-            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+            raise ValueError("max_prompt_length is not supported; please set it to None.")
 
-        # Generate completions using either vLLM or regular generation
+        # Generate completions using vLLM
         if self.use_vllm:
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step
 
-            # Generate completions using vLLM: gather all prompt inputs and use them in a single call in the main process
-            all_vllm_inputs = gather_object(vllm_inputs)
-            all_prompts_text = gather_object(prompts_text)
+            all_env_inputs = gather_object(env_inputs)
+            all_conversations = gather_object(conversations)
             if self.accelerator.is_main_process:
-                outputs = self.vlm.generate(
-                    all_vllm_inputs,
-                    sampling_params=self.sampling_params,
-                    use_tqdm=False,
-                )
-                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
+                if self.env is None:
+                    raise ValueError("No environment provided. Only environment-based generation is supported.")
+                else:
+                    completion_ids = self.env.generate(
+                        conversations=all_conversations,
+                        vlm_inputs=all_env_inputs,
+                        vlm=self.vlm,
+                        sampling_params=self.sampling_params,
+                    )
+
             else:
-                completion_ids = [None] * len(all_prompts_text)
+                completion_ids = [None] * len(all_env_inputs)
 
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
             # corresponding slice.
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
             process_slice = slice(
-                self.accelerator.process_index * len(prompts),
-                (self.accelerator.process_index + 1) * len(prompts),
+                self.accelerator.process_index * len(inputs),
+                (self.accelerator.process_index + 1) * len(inputs),
             )
             completion_ids = completion_ids[process_slice]
 
-            # Pad the completions, and concatenate them with the prompts
+            eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)
+
+            # Pad completion_ids to uniform length; mask up to the last output token (EOS)
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)
+            sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
+            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
-            # Regular generation path
-            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
-                prompt_completion_ids = unwrapped_model.generate(
-                    prompt_ids,
-                    attention_mask=prompt_mask,
-                    pixel_values=pixel_values,
-                    image_grid_thw=image_grid_thw,
-                    generation_config=self.generation_config,
-                )
-
-            # Compute prompt length and extract completion ids
-            prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
-            completion_ids = prompt_completion_ids[:, prompt_length:]
-
-            # Mask everything after the first EOS token
-            is_eos = completion_ids == self.processing_class.tokenizer.eos_token_id
-            eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
-            eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
-            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+            raise ValueError("Attempted to generate with Hugging Face generation; only vLLM is supported now.")
Only supporting vllm now.") # Concatenate prompt_mask with completion_mask for logit computation attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) @@ -684,7 +681,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] - for prompt, completion in zip(prompts, completions_text): + for prompt, completion in zip(conversations, completions_text): bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" if isinstance(bootstrap, list): if len(bootstrap) > 1: @@ -695,16 +692,16 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s else: completions = completions_text - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + rewards_per_func = torch.zeros(len(conversations), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class) in enumerate( zip(self.reward_funcs, self.reward_processing_classes) ): if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + messages = [{"messages": p + c} for p, c in zip(conversations, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: - texts = [p + c for p, c in zip(prompts, completions)] + texts = [p + c for p, c in zip(conversations, completions)] reward_inputs = reward_processing_class( texts, return_tensors="pt", @@ -720,13 +717,29 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] reward_kwargs = {key: [example[key] for example in inputs] for key in keys} reward_kwargs["prompts_text"] = prompts_text - output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs) rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the # completions may be distributed across processes rewards_per_func = gather(rewards_per_func) + # # DEBUG: Verify prompt consistency across completions in each group + # TODO: remove this probably? + # if self.accelerator.is_main_process: + # all_prompts = gather_object(prompts_text) + + # if not len(all_prompts) == self.num_generations: + # raise ValueError( + # f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" + # ) + # if not len(set(all_prompts)) == 1: + # raise ValueError(f"All prompts should be the same. 
{all_prompts=}") + # print("PASSED PROMPT CONSISTENCY CHECK") + + # # Add synchronization point to prevent processes from getting out of sync + # self.accelerator.wait_for_everyone() + # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) @@ -741,8 +754,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # Slice to keep only the local part of the data process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), + self.accelerator.process_index * len(conversations), + (self.accelerator.process_index + 1) * len(conversations), ) advantages = advantages[process_slice]