arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +12 -0
- arbor/server/api/routes/grpo.py +4 -1
- arbor/server/api/routes/inference.py +11 -16
- arbor/server/services/grpo_manager.py +179 -98
- arbor/server/services/inference/__init__.py +0 -0
- arbor/server/services/inference/vllm_client.py +445 -0
- arbor/server/services/inference/vllm_serve.py +2335 -0
- arbor/server/services/inference_manager.py +149 -219
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +157 -53
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/METADATA +4 -5
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/RECORD +20 -12
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/top_level.txt +0 -0
arbor/server/services/inference/vllm_client.py (new file)

@@ -0,0 +1,445 @@
# adapted from trl/extras/vllm_client.py (huggingface/trl)

import asyncio
import atexit
import logging
import time
from typing import Optional

import httpx
import requests
import torch
from openai import OpenAI
from requests import ConnectionError
from requests.adapters import HTTPAdapter
from torch import nn
from trl.import_utils import is_requests_available, is_vllm_available
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

logger = logging.getLogger(__name__)

# Add these new constants near the top of the file
MAX_INFERENCE_RETRIES = 3
INFERENCE_RETRY_DELAY = 1.0  # seconds
INFERENCE_BLOCKED_STATUS = 503  # HTTP status code to use when inference is blocked


# Add this new custom exception
class InferenceBlockedError(Exception):
    """Raised when inference is blocked due to weight updates in progress."""

    pass


class VLLMClient(OpenAI):
    """
    A client class to interact with a vLLM server.

    This class provides methods to generate completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            IP address of the vLLM server.
        port (`int`, *optional*, defaults to `8000`):
            Port number of the vLLM server.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
            timeout, a `ConnectionError` is raised.

    Examples:
        Run the vLLM server with the model `Qwen/Qwen2.5-7B`:

        ```
        $ trl vllm-serve --model Qwen/Qwen2.5-7B
        ...
        INFO: Application startup complete.
        INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
        ```

        Use the client to generate completions and update model weights:

        ```python
        >>> from trl.extras.vllm_client import VLLMClient
        >>> client = VLLMClient()
        >>> client.generate(["Hello, AI!", "Tell me a joke"])
        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]

        >>> from transformers import AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
        >>> client.update_model_params(model)
        ```
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8000,
        group_port: int = 51216,
        connection_timeout: float = 0.0,
    ):
        if not is_requests_available():
            raise ImportError(
                "requests is not installed. Please install it with `pip install requests`."
            )
        if not is_vllm_available():
            raise ImportError(
                "vLLM is not installed. Please install it with `pip install vllm`."
            )

        super().__init__(base_url=f"http://{host}:{port}/v1", api_key="local")
        self.session = requests.Session()
        # Configure connection pooling to handle rapid requests better
        adapter = HTTPAdapter(
            pool_connections=10, pool_maxsize=10, max_retries=3, pool_block=False
        )
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        self.host = host
        self.server_port = port  # Renamed from server_port to port to match super init
        self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
        Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
        """
        url = f"http://{self.host}:{self.server_port}/health/"
        start_time = time.time()  # Record the start time

        while True:
            try:
                response = requests.get(url)  # type: ignore
            except requests.exceptions.RequestException as exc:  # type: ignore
                # Check if the total timeout duration has passed
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(  # type: ignore
                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
                        "seconds. Make sure the server is running by running `trl vllm-serve`."
                    ) from exc
            else:
                if response.status_code == 200:
                    logger.info("Server is up!")
                    return None

            # Retry logic: wait before trying again
            logger.info(
                f"Server is not up yet. Retrying in {retry_interval} seconds..."
            )
            time.sleep(retry_interval)
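Not part of the diff itself, but a minimal construction sketch for the class above. The host, port, and timeout values are placeholders, and the import path is assumed from the file's location in this wheel; `connection_timeout` is forwarded to `check_server`, which polls `/health/` every `retry_interval` seconds and raises `ConnectionError` once the timeout elapses.

```python
# Usage sketch (assumed module path and placeholder values).
from arbor.server.services.inference.vllm_client import VLLMClient

client = VLLMClient(
    host="0.0.0.0",            # where the Arbor-managed vLLM server listens (assumed)
    port=8000,                 # OpenAI-compatible HTTP port (assumed)
    group_port=51216,          # reserved for the NCCL weight-update group
    connection_timeout=120.0,  # check_server() polls /health/ until this elapses
)
```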
    def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
        """
        logger.info(f"[VLLM_CLIENT] Starting init_communicator")

        # Get the world size from the server
        url = f"http://{self.host}:{self.server_port}/get_world_size/"
        logger.info(f"[VLLM_CLIENT] Getting world size from {url}")
        try:
            response = requests.get(url)
            logger.info(
                f"[VLLM_CLIENT] World size response: status={response.status_code}"
            )
        except Exception as e:
            logger.error(f"[VLLM_CLIENT] Failed to get world size: {e}")
            raise

        if response.status_code == 200:
            vllm_world_size = response.json()["world_size"]
            logger.info(f"[VLLM_CLIENT] vLLM world size: {vllm_world_size}")
        else:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

        world_size = vllm_world_size + 1  # add the client to the world
        self.rank = vllm_world_size  # the client's rank is the last process
        logger.info(
            f"[VLLM_CLIENT] Client rank: {self.rank}, total world size: {world_size}"
        )

        # Initialize weight update group
        url = f"http://{self.host}:{self.server_port}/init_communicator/"
        logger.info(f"[VLLM_CLIENT] Sending init_communicator request to {url}")
        # In the server side, the host is set to 0.0.0.0
        try:
            response = self.session.post(
                url,
                json={
                    "host": "0.0.0.0",
                    "port": self.group_port,
                    "world_size": world_size,
                },
            )
            logger.info(
                f"[VLLM_CLIENT] Init communicator response: status={response.status_code}"
            )
        except Exception as e:
            logger.error(f"[VLLM_CLIENT] Failed to init communicator: {e}")
            raise

        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

        # Brief delay to allow server initialization. While not strictly required (client socket will retry on
        # connection failure), this prevents log warnings like:
        # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
        time.sleep(0.1)

        # Set up the communication group for weight broadcasting
        pg = StatelessProcessGroup.create(
            host=self.host, port=self.group_port, rank=self.rank, world_size=world_size
        )
        # Use device 0 like the old code - this seems to work better for multi-GPU setups
        device = 0
        logger.info(
            f"[VLLM_CLIENT] Initializing PyNcclCommunicator on device {device}, rank {self.rank}, world_size {world_size}"
        )
        self.pynccl_comm = PyNcclCommunicator(pg, device=device)

        # When the client object is deleted, close the weight update group
        atexit.register(self.close_communicator)
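A worked example of the group-size bookkeeping above, with hypothetical numbers: if the vLLM server reports a world size of 2 (say, tensor parallelism across two GPUs), the client joins the stateless process group as one extra rank.

```python
# Hypothetical numbers, mirroring the arithmetic in init_communicator().
vllm_world_size = 2               # reported by GET /get_world_size/
world_size = vllm_world_size + 1  # the client is added to the group -> 3
client_rank = vllm_world_size     # the client takes the last rank   -> 2
assert (world_size, client_rank) == (3, 2)
```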
    async def chat(self, json_body: dict) -> dict:
        """
        Send a chat completion request with retry logic for when inference is blocked.
        """
        url = f"http://{self.host}:{self.server_port}/v1/chat/completions"

        retries = 0
        while retries < MAX_INFERENCE_RETRIES:
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.post(url, json=json_body, timeout=300)

                    if response.status_code == INFERENCE_BLOCKED_STATUS:
                        retries += 1
                        if retries < MAX_INFERENCE_RETRIES:
                            logger.warning(
                                f"Inference blocked (weight update in progress). Retry {retries}/{MAX_INFERENCE_RETRIES} in {INFERENCE_RETRY_DELAY}s"
                            )
                            await asyncio.sleep(INFERENCE_RETRY_DELAY)
                            continue
                        else:
                            raise InferenceBlockedError(
                                "Inference blocked by weight updates after max retries"
                            )

                    response.raise_for_status()
                    return response.json()

            except httpx.TimeoutException:
                logger.error("Request timed out")
                raise
            except InferenceBlockedError:
                raise
            except Exception as e:
                retries += 1
                if retries < MAX_INFERENCE_RETRIES:
                    logger.warning(
                        f"Request failed. Retry {retries}/{MAX_INFERENCE_RETRIES} in {INFERENCE_RETRY_DELAY}s. Error: {e}"
                    )
                    await asyncio.sleep(INFERENCE_RETRY_DELAY)
                else:
                    logger.error(
                        f"Request failed after {MAX_INFERENCE_RETRIES} retries"
                    )
                    raise
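A sketch of how a caller might drive the async `chat` method above. The model name and message payload are placeholders, and the response is assumed to follow the OpenAI chat-completions shape that vLLM serves; the import path is assumed from the file's location in this wheel.

```python
import asyncio

from arbor.server.services.inference.vllm_client import (  # assumed import path
    InferenceBlockedError,
    VLLMClient,
)


async def main() -> None:
    client = VLLMClient(connection_timeout=60.0)
    body = {
        "model": "Qwen/Qwen2.5-7B",  # placeholder model name
        "messages": [{"role": "user", "content": "Tell me a joke"}],
    }
    try:
        result = await client.chat(body)
        print(result["choices"][0]["message"]["content"])
    except InferenceBlockedError:
        # Only raised after MAX_INFERENCE_RETRIES consecutive 503 responses
        # while the server is mid weight-update.
        print("Inference stayed blocked; try again after the weight sync.")


asyncio.run(main())
```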
    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
            name (`str`):
                Name of the layer whose weights are being updated.
            weights (`torch.Tensor`):
                Tensor containing the updated weights.
        """
        dtype, shape = str(weights.dtype), tuple(weights.shape)
        url = f"http://{self.host}:{self.server_port}/update_named_param/"
        logger.debug(f"[VLLM_CLIENT] Sending weight update request for {name}")

        # Add timeout to prevent hanging on HTTP request
        try:
            response = self.session.post(
                url, json={"name": name, "dtype": dtype, "shape": shape}, timeout=300.0
            )
            if response.status_code != 200:
                raise Exception(
                    f"Request failed: {response.status_code}, {response.text}"
                )
        except requests.exceptions.Timeout:
            logger.error(
                f"[VLLM_CLIENT] Timeout waiting for server response for {name} after 300s"
            )
            raise Exception(f"Request timeout for {name} after 300s")
        except Exception as e:
            logger.error(f"[VLLM_CLIENT] Error sending request for {name}: {e}")
            raise

        logger.debug(
            f"[VLLM_CLIENT] Server responded, starting NCCL broadcast for {name}"
        )

        # Broadcast the weights to the other processes
        self.pynccl_comm.broadcast(weights, src=self.rank)
        logger.debug(
            f"[VLLM_CLIENT] NCCL broadcast complete, waiting at barrier for {name}"
        )
        self.pynccl_comm.group.barrier()
        logger.debug(f"[VLLM_CLIENT] Barrier passed for {name}")

    def update_model_params(self, model: nn.Module):
        """
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
        """
        for name, param in model.named_parameters():
            # Update each parameter individually
            self.update_named_param(name, param.data)
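A per-parameter sketch under the same assumptions as the earlier examples: the communicator must already be initialized, the tensor must live on the GPU backing the NCCL group, and the parameter name here is hypothetical (in practice it would come from `model.named_parameters()`).

```python
import torch

from arbor.server.services.inference.vllm_client import VLLMClient  # assumed path

client = VLLMClient(connection_timeout=60.0)
client.init_communicator()  # the NCCL group must exist before any broadcast

# Hypothetical parameter name and shape.
updated = torch.zeros(4096, 4096, dtype=torch.bfloat16, device="cuda:0")
client.update_named_param("model.layers.0.mlp.gate_proj.weight", updated)
```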
    def batch_update_model_params(self, model: nn.Module, batch_size: int = 50):
        """
        Updates all parameters of the given model in batches to reduce overhead and prevent overwhelming the server.

        This method coordinates with the server to ensure proper NCCL synchronization:
        1. Send batch of parameter metadata to server
        2. Server notifies workers for each parameter
        3. Client broadcasts each parameter via NCCL after server confirmation

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
            batch_size (`int`, *optional*, defaults to 50):
                Number of parameters to update in each batch.
        """
        # Collect all parameters
        all_params = list(model.named_parameters())
        total_params = len(all_params)

        logger.info(
            f"[VLLM_CLIENT] Starting batch update of {total_params} parameters in batches of {batch_size}"
        )

        # Process in batches
        for batch_idx, i in enumerate(range(0, total_params, batch_size)):
            batch_params = all_params[i : i + batch_size]

            # Prepare batch update request
            batch_updates = []
            for name, param in batch_params:
                batch_updates.append(
                    {
                        "name": name,
                        "dtype": str(param.data.dtype),
                        "shape": list(param.data.shape),
                    }
                )

            # Send batch update request
            url = f"http://{self.host}:{self.server_port}/batch_update_named_params/"
            logger.debug(
                f"[VLLM_CLIENT] Sending batch {batch_idx + 1} with {len(batch_updates)} parameters"
            )

            try:
                response = self.session.post(
                    url, json={"updates": batch_updates}, timeout=600.0
                )
                if response.status_code not in [200, 207]:  # 207 is Multi-Status
                    raise Exception(
                        f"Batch request failed: {response.status_code}, {response.text}"
                    )

                result = response.json()

                # Check for partial failures
                if response.status_code == 207:
                    logger.warning(
                        f"[VLLM_CLIENT] Batch had errors: {result.get('errors', [])}"
                    )

                # Get list of successfully notified parameters
                successful_params = result.get("successful", [])
                if not successful_params:
                    logger.error(
                        f"[VLLM_CLIENT] No successful parameters in batch response"
                    )
                    continue

            except requests.exceptions.Timeout:
                logger.error(
                    f"[VLLM_CLIENT] Timeout waiting for batch response after 600s"
                )
                raise Exception(f"Batch request timeout after 600s")
            except Exception as e:
                logger.error(f"[VLLM_CLIENT] Error sending batch request: {e}")
                raise

            # Now broadcast weights for successfully notified parameters
            logger.debug(
                f"[VLLM_CLIENT] Broadcasting weights for {len(successful_params)} parameters in batch {batch_idx + 1}"
            )

            for name, param in batch_params:
                if name in successful_params:
                    try:
                        # Broadcast this specific parameter
                        self.pynccl_comm.broadcast(param.data, src=self.rank)
                        self.pynccl_comm.group.barrier()
                        logger.debug(f"[VLLM_CLIENT] Broadcast complete for {name}")
                    except Exception as e:
                        logger.error(f"[VLLM_CLIENT] Failed to broadcast {name}: {e}")
                        raise
                else:
                    logger.warning(
                        f"[VLLM_CLIENT] Skipping broadcast for {name} - not in successful list"
                    )

            logger.debug(f"[VLLM_CLIENT] Completed batch {batch_idx + 1}")

        logger.info(
            f"[VLLM_CLIENT] Batch update complete for {total_params} parameters"
        )
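A sketch of the intended call pattern from a training loop; the model and batch size are placeholders and the import path is assumed as before. Note that `batch_size` only controls how many parameter descriptors go into each HTTP request to `/batch_update_named_params/`; every tensor is still broadcast individually over NCCL.

```python
from transformers import AutoModelForCausalLM

from arbor.server.services.inference.vllm_client import VLLMClient  # assumed path

client = VLLMClient(connection_timeout=120.0)
client.init_communicator()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")

for step in range(3):  # stand-in for a real training loop
    # ... an optimizer step that updates `model` would run here ...
    client.batch_update_model_params(model, batch_size=100)
```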
    def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
        """
        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
        response = self.session.post(url)
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

    def close_communicator(self):
        """
        Closes the weight update group and cleans up the communication group.
        """
        url = f"http://{self.host}:{self.server_port}/close_communicator/"

        try:
            response = self.session.post(url)
        except ConnectionError:
            # The server might be already down, so we don't need to close the communicator
            pass
        else:
            if response.status_code != 200:
                raise Exception(
                    f"Request failed: {response.status_code}, {response.text}"
                )
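To round out the new client's lifecycle, a final sketch under the same assumptions: clear the prefix cache after a weight sync (cached prefixes were computed against the old weights) and shut the NCCL group down explicitly rather than relying only on the `atexit` hook.

```python
from arbor.server.services.inference.vllm_client import VLLMClient  # assumed path

client = VLLMClient(connection_timeout=60.0)
client.init_communicator()
try:
    # ... weight updates and inference happen here ...
    client.reset_prefix_cache()  # drop KV prefixes computed with the old weights
finally:
    client.close_communicator()  # also registered via atexit during init_communicator()
```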