10 changes: 10 additions & 0 deletions trl/experimental/online_dpo/online_dpo_config.py

@@ -121,6 +121,9 @@ class may differ from those in [`~transformers.TrainingArguments`].
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        vllm_group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
+            is occupied, there is no need to change it.
 
         > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
 
@@ -347,6 +350,13 @@ class may differ from those in [`~transformers.TrainingArguments`].
             "after the timeout, a `ConnectionError` is raised.",
         },
     )
+    vllm_group_port: int = field(
+        default=51216,
+        metadata={
+            "help": "Port number for the weight update group. This is used to communicate with the vLLM server. "
+            "Unless the port is occupied, there is no need to change it.",
+        },
+    )
     vllm_tensor_parallel_size: int = field(
         default=1,
         metadata={
4 changes: 3 additions & 1 deletion trl/experimental/online_dpo/online_dpo_trainer.py

@@ -454,7 +454,9 @@ def __init__(
                 base_url = args.vllm_server_base_url
             else:
                 base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
-            self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+            self.vllm_client = VLLMClient(
+                base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
+            )
 
         # Determine device type (supports cuda, xpu, etc.)
         accelerator_type = torch.accelerator.current_accelerator().type
10 changes: 10 additions & 0 deletions trl/trainer/grpo_config.py

@@ -137,6 +137,9 @@ class GRPOConfig(TrainingArguments):
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        vllm_group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
+            is occupied, there is no need to change it.
 
         > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
 
@@ -551,6 +554,13 @@ class GRPOConfig(TrainingArguments):
             "after the timeout, a `ConnectionError` is raised."
         },
    )
+    vllm_group_port: int = field(
+        default=51216,
+        metadata={
+            "help": "Port number for the weight update group. This is used to communicate with the vLLM server. "
+            "Unless the port is occupied, there is no need to change it.",
+        },
+    )
 
     # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
     vllm_gpu_memory_utilization: float = field(
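The new option only matters when two weight-update groups would otherwise collide, for example two server-mode training jobs launched on the same host. A minimal sketch of overriding it via `GRPOConfig`; the port values are arbitrary, and `use_vllm` is assumed to be the flag that enables vLLM generation (it is not part of this diff):

```python
from trl import GRPOConfig

# Two server-mode jobs on one host: each talks to its own vLLM server and
# needs its own weight-update-group port, or the second job fails to bind
# the default 51216.
job_a = GRPOConfig(
    output_dir="run_a",
    use_vllm=True,            # assumed flag, not shown in this diff
    vllm_mode="server",
    vllm_server_port=8000,
    vllm_group_port=51216,    # the default, spelled out for contrast
)
job_b = GRPOConfig(
    output_dir="run_b",
    use_vllm=True,
    vllm_mode="server",
    vllm_server_port=8001,
    vllm_group_port=51217,    # any free port resolves the collision
)
```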
4 changes: 3 additions & 1 deletion trl/trainer/grpo_trainer.py

@@ -620,7 +620,9 @@ def cast_outputs_to_original_dtype(module, args, output):
                 base_url = args.vllm_server_base_url
             else:
                 base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
-            self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+            self.vllm_client = VLLMClient(
+                base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
+            )
             self.vllm_client.init_communicator(device=torch.cuda.current_device())
 
         elif self.vllm_mode == "colocate":
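The trainer-side call can also be reproduced standalone, which is handy for checking connectivity before a long run. A sketch assuming the client lives at `trl.extras.vllm_client` and a vLLM server is already up on `localhost:8000`; the arguments themselves are the ones this diff wires through:

```python
import torch

from trl.extras.vllm_client import VLLMClient  # assumed import path

# Mirrors the construction above: base_url carries HTTP generation requests,
# while group_port is the rendezvous port for the weight-update group.
client = VLLMClient(
    base_url="http://localhost:8000",
    group_port=51217,          # must be free on the server host
    connection_timeout=240.0,  # same default as vllm_server_timeout
)
client.init_communicator(device=torch.cuda.current_device())
```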
10 changes: 10 additions & 0 deletions trl/trainer/rloo_config.py

@@ -134,6 +134,9 @@ class RLOOConfig(TrainingArguments):
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        vllm_group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
+            is occupied, there is no need to change it.
 
         > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
 
@@ -449,6 +452,13 @@ class RLOOConfig(TrainingArguments):
             "after the timeout, a `ConnectionError` is raised."
         },
    )
+    vllm_group_port: int = field(
+        default=51216,
+        metadata={
+            "help": "Port number for the weight update group. This is used to communicate with the vLLM server. "
+            "Unless the port is occupied, there is no need to change it.",
+        },
+    )
 
     # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
     vllm_gpu_memory_utilization: float = field(
4 changes: 3 additions & 1 deletion trl/trainer/rloo_trainer.py

@@ -476,7 +476,9 @@ def __init__(
                 base_url = args.vllm_server_base_url
             else:
                 base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
-            self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+            self.vllm_client = VLLMClient(
+                base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
+            )
             self.vllm_client.init_communicator(device=torch.cuda.current_device())
 
         elif self.vllm_mode == "colocate":
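Since all three docstrings say to change the port only when it is occupied, a stdlib probe (a hypothetical helper, not part of TRL) can confirm whether the default is actually free before launching:

```python
import socket

def group_port_is_free(port: int = 51216, host: str = "0.0.0.0") -> bool:
    """Return True if `port` can currently be bound on `host`."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind((host, port))
            return True
        except OSError:
            return False

if not group_port_is_free():
    print("51216 is in use; pass a different vllm_group_port")
```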