From fdec98aa5a4188eccf18b0682b2200bdf89afaf6 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Mon, 17 Feb 2025 19:34:10 -0800 Subject: [PATCH 1/5] it is working on digit recog --- trl/environment/__init__.py | 2 + trl/environment/env_protocol.py | 12 ++++ trl/trainer/qwen_grpo_trainer.py | 96 ++++++++++++++++---------------- 3 files changed, 61 insertions(+), 49 deletions(-) create mode 100644 trl/environment/env_protocol.py diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py index 0ab3e0d628d..4f3bfa6ecb8 100644 --- a/trl/environment/__init__.py +++ b/trl/environment/__init__.py @@ -19,10 +19,12 @@ _import_structure = { "base_environment": ["TextEnvironment", "TextHistory"], + "env_protocol": ["Environment"], } if TYPE_CHECKING: from .base_environment import TextEnvironment, TextHistory + from .env_protocol import Environment else: import sys diff --git a/trl/environment/env_protocol.py b/trl/environment/env_protocol.py new file mode 100644 index 00000000000..4897c587be2 --- /dev/null +++ b/trl/environment/env_protocol.py @@ -0,0 +1,12 @@ +from typing import Any, List, Protocol + + +class Environment(Protocol): + """ + A protocol describing the minimal interface needed for integration + with the trainer. Your environment can run any multi-step logic, + but must ultimately return token sequences akin to selecting token_ids from + vllm.LLM's generate() output. https://docs.vllm.ai/en/stable/api/offline_inference/llm.html + """ + + def generate(self, vllm_inputs, processing_class, vlm, sampling_params) -> List[Any]: ... diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 9d3ceeb6cf2..4ffd9394067 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -43,6 +43,7 @@ from transformers.utils import is_peft_available from ..data_utils import apply_chat_template, is_conversational +from ..environment.env_protocol import Environment from ..import_utils import is_vllm_available from ..models import ( create_reference_model, @@ -218,6 +219,7 @@ def __init__( model: PreTrainedModel, reward_funcs: Union[RewardFunc, list[RewardFunc]], processing_class: PreTrainedTokenizerBase, + # TODO: remove this function. tokenize_and_inject_images: Callable, args: GRPOConfig = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, @@ -227,6 +229,7 @@ def __init__( optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, shuffle_dataset: bool = True, + env: Optional[Environment] = None, ): # Add shuffle_dataset to instance variables self.shuffle_dataset = shuffle_dataset @@ -464,6 +467,12 @@ def data_collator(features): # No data collation is needed in GRPO pad_token_id=processing_class.tokenizer.pad_token_id, ) + self.env = env + if self.env is not None: + if not self.use_vllm: + # env assumes using vllm for generation, so we raise an error if it's not the case + raise ValueError("env is not supported when use_vllm is False - vLLM is required for env") + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set # self.model_accepts_loss_kwargs to False to enable scaling. 
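For reference, a minimal environment satisfying the Environment protocol can simply forward its inputs to vLLM and return one token-id list per sampled completion, mirroring how the trainer called vLLM directly before this change. The class below is an illustrative sketch, not part of the patch; only the generate signature is taken from env_protocol.py, and the single-step behaviour, class name, and lack of multi-turn logic are assumptions.

    from typing import Any, List


    class SingleStepEnvironment:
        """One-shot environment: no multi-turn logic, just a single vLLM generation call."""

        def generate(self, vllm_inputs, processing_class, vlm, sampling_params) -> List[Any]:
            # processing_class is unused in this minimal case; vlm is expected to be a
            # vllm.LLM instance, as in the trainer.
            outputs = vlm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
            # Flatten to one token-id sequence per sampled completion, matching what the
            # trainer previously extracted from vLLM's request outputs.
            return [out.token_ids for request in outputs for out in request.outputs]
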
@@ -575,12 +584,19 @@ def _move_model_to_vllm(self): def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device - prompt_inputs, vllm_inputs, prompts_text, prompts = self.tokenize_and_inject_images( + if not self.env: + raise ValueError("No environment provided. Only supporting envs now. ") + + # conversations: list of conversations + # prompts_text: list of prompts as strings + # env_inputs: data in the format our env/vllm expects + # prompt_inputs: tokenized data (with image tokens injected) that we will use to compute log probs on the base model. + conversations, prompts_text, prompt_inputs, env_inputs = self.env.prepare_data( inputs=inputs, processing_class=self.processing_class ) + # unpack prompt_inputs prompt_inputs = super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask, pixel_values, image_grid_thw = ( prompt_inputs["input_ids"], prompt_inputs["attention_mask"], @@ -589,13 +605,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s ) if self.max_prompt_length is not None: - if self.use_vllm: - raise ValueError( - "max_prompt_length is not supported when using vLLM. Please set it to None if vLLM is used. This is because we don't control tokenization when using vLLM." - ) - - prompt_ids = prompt_ids[:, -self.max_prompt_length :] - prompt_mask = prompt_mask[:, -self.max_prompt_length :] + raise ValueError("max_prompt_length is not supported.") # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -605,54 +615,42 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s self._last_loaded_step = self.state.global_step # Generate completions using vLLM: gather all prompt inputs and use them in a single call in the main process - all_vllm_inputs = gather_object(vllm_inputs) - all_prompts_text = gather_object(prompts_text) + all_env_inputs = gather_object(env_inputs) + all_conversations = gather_object(conversations) if self.accelerator.is_main_process: - outputs = self.vlm.generate( - all_vllm_inputs, - sampling_params=self.sampling_params, - use_tqdm=False, - ) - completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] + if self.env is None: + raise ValueError("No environment provided. Only supporting envs now. ") + else: + completion_ids = self.env.generate( + conversations=all_conversations, + vlm_inputs=all_env_inputs, + vlm=self.vlm, + sampling_params=self.sampling_params, + ) + else: - completion_ids = [None] * len(all_prompts_text) + completion_ids = [None] * len(all_env_inputs) # Broadcast the completions from the main process to all processes, ensuring each process receives its # corresponding slice. 
completion_ids = broadcast_object_list(completion_ids, from_process=0) process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), ) completion_ids = completion_ids[process_slice] - # Pad the completions, and concatenate them with the prompts + eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device) + + # Pad completion_ids to uniform length, mask from last output token (EOS) completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id) + sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) else: - # Regular generation path - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - prompt_completion_ids = unwrapped_model.generate( - prompt_ids, - attention_mask=prompt_mask, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - generation_config=self.generation_config, - ) - - # Compute prompt length and extract completion ids - prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] - completion_ids = prompt_completion_ids[:, prompt_length:] - - # Mask everything after the first EOS token - is_eos = completion_ids == self.processing_class.tokenizer.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + raise ValueError("Attempted to generate with HF. 
Only supporting vllm now.") # Concatenate prompt_mask with completion_mask for logit computation attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) @@ -684,7 +682,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] - for prompt, completion in zip(prompts, completions_text): + for prompt, completion in zip(conversations, completions_text): bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" if isinstance(bootstrap, list): if len(bootstrap) > 1: @@ -695,16 +693,16 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s else: completions = completions_text - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + rewards_per_func = torch.zeros(len(conversations), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class) in enumerate( zip(self.reward_funcs, self.reward_processing_classes) ): if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + messages = [{"messages": p + c} for p, c in zip(conversations, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: - texts = [p + c for p, c in zip(prompts, completions)] + texts = [p + c for p, c in zip(conversations, completions)] reward_inputs = reward_processing_class( texts, return_tensors="pt", @@ -720,7 +718,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] reward_kwargs = {key: [example[key] for example in inputs] for key in keys} reward_kwargs["prompts_text"] = prompts_text - output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs) rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the @@ -741,8 +739,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # Slice to keep only the local part of the data process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), + self.accelerator.process_index * len(conversations), + (self.accelerator.process_index + 1) * len(conversations), ) advantages = advantages[process_slice] From 70bb5eea8bac3f6b7126e600731a1b67e96232e5 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 18 Feb 2025 19:54:52 -0800 Subject: [PATCH 2/5] rm tokenize and inject --- trl/trainer/qwen_grpo_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 4ffd9394067..56342bb12b2 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -219,8 +219,6 @@ def __init__( model: PreTrainedModel, reward_funcs: Union[RewardFunc, list[RewardFunc]], processing_class: PreTrainedTokenizerBase, - # TODO: remove this function. 
- tokenize_and_inject_images: Callable, args: GRPOConfig = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, @@ -330,8 +328,6 @@ def __init__( reward_processing_classes[i] = reward_processing_class self.reward_processing_classes = reward_processing_classes - self.tokenize_and_inject_images = tokenize_and_inject_images - # Data collator def data_collator(features): # No data collation is needed in GRPO return features From fe01d38a8260be7c086fb1e56967d249420ade1d Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 18 Feb 2025 20:19:21 -0800 Subject: [PATCH 3/5] some cleanup. --- trl/trainer/qwen_grpo_trainer.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 56342bb12b2..e02485b77c2 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -30,7 +30,6 @@ from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, - GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, Qwen2_5_VLForConditionalGeneration, @@ -219,6 +218,7 @@ def __init__( model: PreTrainedModel, reward_funcs: Union[RewardFunc, list[RewardFunc]], processing_class: PreTrainedTokenizerBase, + env: Environment, args: GRPOConfig = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, @@ -227,10 +227,8 @@ def __init__( optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, shuffle_dataset: bool = True, - env: Optional[Environment] = None, ): - # Add shuffle_dataset to instance variables - self.shuffle_dataset = shuffle_dataset + # Args if args is None: @@ -337,7 +335,7 @@ def data_collator(features): # No data collation is needed in GRPO self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper self.use_vllm = args.use_vllm - print(f"use_vllm: {self.use_vllm}") + self.shuffle_dataset = shuffle_dataset self.beta = args.beta @@ -456,18 +454,10 @@ def data_collator(features): # No data collation is needed in GRPO # synchronize all processes after vLLM has been fully initialized. self.accelerator.wait_for_everyone() else: - self.generation_config = GenerationConfig( - max_new_tokens=self.max_completion_length, - do_sample=True, - temperature=args.temperature, - pad_token_id=processing_class.tokenizer.pad_token_id, - ) + raise ValueError("use_vllm must be True") self.env = env - if self.env is not None: - if not self.use_vllm: - # env assumes using vllm for generation, so we raise an error if it's not the case - raise ValueError("env is not supported when use_vllm is False - vLLM is required for env") + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. 
We set @@ -585,8 +575,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # conversations: list of conversations # prompts_text: list of prompts as strings - # env_inputs: data in the format our env/vllm expects # prompt_inputs: tokenized data (with image tokens injected) that we will use to compute log probs on the base model. + # env_inputs: data in the format our env/vllm expects conversations, prompts_text, prompt_inputs, env_inputs = self.env.prepare_data( inputs=inputs, processing_class=self.processing_class ) @@ -603,20 +593,19 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if self.max_prompt_length is not None: raise ValueError("max_prompt_length is not supported.") - # Generate completions using either vLLM or regular generation + # Generate completions using vLLM if self.use_vllm: # First, have main process load weights if needed if self.state.global_step != self._last_loaded_step: self._move_model_to_vllm() self._last_loaded_step = self.state.global_step - # Generate completions using vLLM: gather all prompt inputs and use them in a single call in the main process all_env_inputs = gather_object(env_inputs) all_conversations = gather_object(conversations) if self.accelerator.is_main_process: if self.env is None: - raise ValueError("No environment provided. Only supporting envs now. ") + raise ValueError("No environment provided. Only supporting envs now.") else: completion_ids = self.env.generate( conversations=all_conversations, From 252bdb6b902c199dce4c20992809c0f1c4a165f9 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Mon, 3 Mar 2025 00:30:43 -0800 Subject: [PATCH 4/5] fix bug where different gpus got different data --- trl/trainer/qwen_grpo_trainer.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index e02485b77c2..7de7af04b18 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -15,6 +15,7 @@ import textwrap import warnings from collections import defaultdict +from copy import deepcopy from typing import Any, Callable, Optional, Sized, Union from unittest.mock import patch @@ -228,8 +229,6 @@ def __init__( peft_config: Optional["PeftConfig"] = None, shuffle_dataset: bool = True, ): - - # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path @@ -458,7 +457,6 @@ def data_collator(features): # No data collation is needed in GRPO self.env = env - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set # self.model_accepts_loss_kwargs to False to enable scaling. @@ -573,6 +571,22 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if not self.env: raise ValueError("No environment provided. Only supporting envs now. ") + # TODO: This is a hack that we should probably fix. + # without this, each gpu receives different inputs, screwing up the advantage computation. 
+ # Simple synchronization of inputs across processes + if self.accelerator.num_processes > 1: + # Make sure all processes have a non-None value to gather + # Use an empty list for non-main processes + local_inputs = inputs if self.accelerator.process_index == 0 else [] + + self.accelerator.wait_for_everyone() + + # Gather from all processes using torch.distributed.gather_object + all_inputs = gather_object(local_inputs) + + # each process takes the inputs from process 0 as its inputs + inputs = deepcopy(all_inputs) + # conversations: list of conversations # prompts_text: list of prompts as strings # prompt_inputs: tokenized data (with image tokens injected) that we will use to compute log probs on the base model. @@ -710,6 +724,18 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # completions may be distributed across processes rewards_per_func = gather(rewards_per_func) + # # DEBUG: Verify prompt consistency across completions in each group + # TODO: remove this probably? + if self.accelerator.is_main_process: + all_prompts = gather_object(prompts_text) + + if not len(all_prompts) == self.num_generations: + raise ValueError( + f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" + ) + if not len(set(all_prompts)) == 1: + raise ValueError(f"All prompts should be the same. {all_prompts=}") + # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) From 546cab4fcd2f96c36116018f4828dad5654ba26c Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Mon, 3 Mar 2025 00:58:01 -0800 Subject: [PATCH 5/5] fix desync --- trl/trainer/qwen_grpo_trainer.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 7de7af04b18..6730793810d 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -579,14 +579,14 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # Use an empty list for non-main processes local_inputs = inputs if self.accelerator.process_index == 0 else [] - self.accelerator.wait_for_everyone() - # Gather from all processes using torch.distributed.gather_object all_inputs = gather_object(local_inputs) # each process takes the inputs from process 0 as its inputs inputs = deepcopy(all_inputs) + self.accelerator.wait_for_everyone() + # conversations: list of conversations # prompts_text: list of prompts as strings # prompt_inputs: tokenized data (with image tokens injected) that we will use to compute log probs on the base model. @@ -726,15 +726,19 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # # DEBUG: Verify prompt consistency across completions in each group # TODO: remove this probably? - if self.accelerator.is_main_process: - all_prompts = gather_object(prompts_text) - - if not len(all_prompts) == self.num_generations: - raise ValueError( - f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" - ) - if not len(set(all_prompts)) == 1: - raise ValueError(f"All prompts should be the same. 
{all_prompts=}") + # if self.accelerator.is_main_process: + # all_prompts = gather_object(prompts_text) + + # if not len(all_prompts) == self.num_generations: + # raise ValueError( + # f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" + # ) + # if not len(set(all_prompts)) == 1: + # raise ValueError(f"All prompts should be the same. {all_prompts=}") + # print("PASSED PROMPT CONSISTENCY CHECK") + + # # Add synchronization point to prevent processes from getting out of sync + # self.accelerator.wait_for_everyone() # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
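
To make the generation flow added to _prepare_inputs in patch 1 easier to follow in isolation: each rank gathers its environment inputs, the main process runs the environment once over the full batch, the completions are broadcast back, and each rank keeps only its own slice. The helper below is a condensed sketch built on the same accelerate primitives the trainer uses; the function name and the generate_fn stand-in for env.generate are assumptions.

    from accelerate.utils import broadcast_object_list, gather_object


    def generate_on_main(accelerator, conversations, env_inputs, generate_fn):
        # Collect every rank's inputs so the main process sees the full batch.
        all_env_inputs = gather_object(env_inputs)
        all_conversations = gather_object(conversations)

        if accelerator.is_main_process:
            # Stand-in for env.generate(conversations=..., vlm_inputs=..., vlm=..., sampling_params=...).
            completion_ids = generate_fn(all_conversations, all_env_inputs)
        else:
            # Placeholder list of the right length for broadcast_object_list to fill.
            completion_ids = [None] * len(all_env_inputs)

        # Send the main process's completions to every rank.
        completion_ids = broadcast_object_list(completion_ids, from_process=0)

        # Keep only the completions for the prompts this rank contributed.
        local = slice(
            accelerator.process_index * len(env_inputs),
            (accelerator.process_index + 1) * len(env_inputs),
        )
        return completion_ids[local]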
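
The padding and masking introduced in patch 1 assume the environment returns already-truncated sequences, so the last token of each completion is treated as its EOS instead of searching for the first EOS token as the old HF generation path did. A standalone sketch of that math, using torch.nn.utils.rnn.pad_sequence in place of trl's pad helper and toy token ids (both assumptions):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad_token_id = 0  # assumed value for illustration

    # Two completions of different lengths, as token-id lists returned by env.generate().
    completion_ids = [[11, 12, 13, 2], [21, 22, 2]]

    # The last token of each already-truncated sequence is treated as its EOS.
    eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids])   # tensor([3, 2])

    padded = pad_sequence(
        [torch.tensor(ids) for ids in completion_ids],
        batch_first=True,
        padding_value=pad_token_id,
    )                                                                  # shape (2, 4)

    sequence_indices = torch.arange(padded.size(1)).expand(padded.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    # completion_mask:
    # tensor([[1, 1, 1, 1],
    #         [1, 1, 1, 0]])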
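
Patches 4 and 5 work around ranks receiving different batches by making every process adopt process 0's inputs before prepare_data runs, so each group of num_generations completions is scored against the same prompt and the group-wise advantage normalization stays meaningful. Factored out as a helper for clarity; the function name and standalone form are assumptions, the logic matches the patch.

    from copy import deepcopy

    from accelerate import Accelerator
    from accelerate.utils import gather_object


    def sync_inputs_from_main(accelerator: Accelerator, inputs: list) -> list:
        """Return process 0's batch on every process."""
        if accelerator.num_processes == 1:
            return inputs
        # Only the main process contributes its batch; the other ranks contribute
        # nothing, so the gathered list is exactly process 0's batch on every rank.
        local_inputs = inputs if accelerator.process_index == 0 else []
        all_inputs = gather_object(local_inputs)
        inputs = deepcopy(all_inputs)
        accelerator.wait_for_everyone()
        return inputs

This works because gather_object returns the concatenation of every rank's contribution on all ranks, so contributing only from process 0 effectively broadcasts its batch.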