diff --git a/trl/experimental/online_dpo/online_dpo_config.py b/trl/experimental/online_dpo/online_dpo_config.py
index bb728e37b92..2e8986afbae 100644
--- a/trl/experimental/online_dpo/online_dpo_config.py
+++ b/trl/experimental/online_dpo/online_dpo_config.py
@@ -121,6 +121,9 @@ class may differ from those in [`~transformers.TrainingArguments`].
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        vllm_group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
+            is occupied, there is no need to change it.
 
         > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
 
@@ -347,6 +350,13 @@ class may differ from those in [`~transformers.TrainingArguments`].
             "after the timeout, a `ConnectionError` is raised.",
         },
     )
+    vllm_group_port: int = field(
+        default=51216,
+        metadata={
+            "help": "Port number for the weight update group. This is used to communicate with the vLLM server. "
+            "Unless the port is occupied, there is no need to change it.",
+        },
+    )
     vllm_tensor_parallel_size: int = field(
         default=1,
         metadata={
diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py
index ebc984744d9..b8d2311a938 100644
--- a/trl/experimental/online_dpo/online_dpo_trainer.py
+++ b/trl/experimental/online_dpo/online_dpo_trainer.py
@@ -454,7 +454,9 @@ def __init__(
                     base_url = args.vllm_server_base_url
                 else:
                     base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
-                self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+                self.vllm_client = VLLMClient(
+                    base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
+                )
 
                 # Determine device type (supports cuda, xpu, etc.)
                 accelerator_type = torch.accelerator.current_accelerator().type
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 2fb833309d1..5df991ddf11 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -137,6 +137,9 @@ class GRPOConfig(TrainingArguments):
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        vllm_group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
+            is occupied, there is no need to change it.
 
         > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
 
@@ -551,6 +554,13 @@ class GRPOConfig(TrainingArguments):
             "after the timeout, a `ConnectionError` is raised."
         },
     )
+    vllm_group_port: int = field(
+        default=51216,
+        metadata={
+            "help": "Port number for the weight update group. This is used to communicate with the vLLM server. "
+            "Unless the port is occupied, there is no need to change it.",
+        },
+    )
 
     # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
     vllm_gpu_memory_utilization: float = field(
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 08e2f49888d..36dc0ce443a 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -620,7 +620,9 @@ def cast_outputs_to_original_dtype(module, args, output):
                     base_url = args.vllm_server_base_url
                 else:
                     base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
-                self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+                self.vllm_client = VLLMClient(
+                    base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
+                )
                 self.vllm_client.init_communicator(device=torch.cuda.current_device())
 
             elif self.vllm_mode == "colocate":
diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py
index 32f768f550f..8f487074ce3 100644
--- a/trl/trainer/rloo_config.py
+++ b/trl/trainer/rloo_config.py
@@ -134,6 +134,9 @@ class RLOOConfig(TrainingArguments):
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        vllm_group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
+            is occupied, there is no need to change it.
 
         > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
 
@@ -449,6 +452,13 @@ class RLOOConfig(TrainingArguments):
             "after the timeout, a `ConnectionError` is raised."
         },
     )
+    vllm_group_port: int = field(
+        default=51216,
+        metadata={
+            "help": "Port number for the weight update group. This is used to communicate with the vLLM server. "
+            "Unless the port is occupied, there is no need to change it.",
+        },
+    )
 
     # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
     vllm_gpu_memory_utilization: float = field(
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 34858c22917..a8b883260c7 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -476,7 +476,9 @@ def __init__(
                     base_url = args.vllm_server_base_url
                 else:
                     base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
-                self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+                self.vllm_client = VLLMClient(
+                    base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
+                )
                 self.vllm_client.init_communicator(device=torch.cuda.current_device())
 
             elif self.vllm_mode == "colocate":