arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,444 @@
+ # adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
+
+ import asyncio
+ import atexit
+ import logging
+ import time
+ from typing import Optional
+
+ import httpx
+ import requests
+ import torch
+ from requests import ConnectionError
+ from requests.adapters import HTTPAdapter
+ from torch import nn
+ from trl.import_utils import is_requests_available, is_vllm_available
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+ from vllm.distributed.utils import StatelessProcessGroup
+
+ logger = logging.getLogger(__name__)
+
+ # Retry behavior for inference requests issued while a weight update is in progress
+ MAX_INFERENCE_RETRIES = 3
+ INFERENCE_RETRY_DELAY = 1.0  # seconds
+ INFERENCE_BLOCKED_STATUS = 503  # HTTP status code to use when inference is blocked
+
+
+ class InferenceBlockedError(Exception):
+     """Raised when inference is blocked due to weight updates in progress."""
+
+     pass
+
+
+ class VLLMClient:
+     """
+     A client class to interact with a vLLM server.
+
+     This class provides methods to request chat completions, initialize and manage weight update groups, and update
+     model weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
+
+     Args:
+         host (`str`, *optional*, defaults to `"0.0.0.0"`):
+             IP address of the vLLM server.
+         port (`int`, *optional*, defaults to `8000`):
+             Port number of the vLLM server.
+         group_port (`int`, *optional*, defaults to `51216`):
+             Port number for the weight update group.
+         connection_timeout (`float`, *optional*, defaults to `0.0`):
+             Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
+             timeout, a `ConnectionError` is raised.
+
+     Examples:
+         Run the vLLM server with the model `Qwen/Qwen2.5-7B`:
+
+         ```
+         $ trl vllm-serve --model Qwen/Qwen2.5-7B
+         ...
+         INFO: Application startup complete.
+         INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+         ```
+
+         Use the client to request chat completions and update model weights:
+
+         ```python
+         >>> import asyncio
+         >>> from trl.extras.vllm_client import VLLMClient
+         >>> client = VLLMClient()
+         >>> client.init_communicator()
+
+         >>> asyncio.run(client.chat({"model": "Qwen/Qwen2.5-7B", "messages": [{"role": "user", "content": "Hello, AI!"}]}))
+
+         >>> from transformers import AutoModelForCausalLM
+         >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
+         >>> client.update_model_params(model)
+         ```
+     """
+
+     def __init__(
+         self,
+         host: str = "0.0.0.0",
+         port: int = 8000,
+         group_port: int = 51216,
+         connection_timeout: float = 0.0,
+     ):
+         if not is_requests_available():
+             raise ImportError(
+                 "requests is not installed. Please install it with `pip install requests`."
+             )
+         if not is_vllm_available():
+             raise ImportError(
+                 "vLLM is not installed. Please install it with `pip install vllm`."
+             )
+
+         self.base_url = f"http://{host}:{port}/v1"
+         self.session = requests.Session()
+         # Configure connection pooling to handle rapid requests better
+         adapter = HTTPAdapter(
+             pool_connections=10, pool_maxsize=10, max_retries=3, pool_block=False
+         )
+         self.session.mount("http://", adapter)
+         self.session.mount("https://", adapter)
+
+         self.host = host
+         self.server_port = port
+         self.group_port = group_port
+         self.check_server(connection_timeout)  # check server and fail after timeout
+
+     def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
+         """
+         Check server availability with retries on failure, within a total timeout duration. If the server is not up
+         after the total timeout duration, raise a `ConnectionError`.
+
+         Args:
+             total_timeout (`float`, *optional*, defaults to `0.0`):
+                 Total timeout duration in seconds.
+             retry_interval (`float`, *optional*, defaults to `2.0`):
+                 Interval in seconds between retries.
+         """
+         url = f"http://{self.host}:{self.server_port}/health/"
+         start_time = time.time()  # Record the start time
+
+         while True:
+             try:
+                 response = requests.get(url)  # type: ignore
+             except requests.exceptions.RequestException as exc:  # type: ignore
+                 # Check if the total timeout duration has passed
+                 elapsed_time = time.time() - start_time
+                 if elapsed_time >= total_timeout:
+                     raise ConnectionError(  # type: ignore
+                         f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                         "seconds. Make sure the server is running by running `trl vllm-serve`."
+                     ) from exc
+             else:
+                 if response.status_code == 200:
+                     logger.info("Server is up!")
+                     return None
+
+             # Retry logic: wait before trying again
+             logger.info(
+                 f"Server is not up yet. Retrying in {retry_interval} seconds..."
+             )
+             time.sleep(retry_interval)
+
+     def init_communicator(self):
+         """
+         Initializes the weight update group in a distributed setup for model synchronization.
+         """
+         logger.info("[VLLM_CLIENT] Starting init_communicator")
+
+         # Get the world size from the server
+         url = f"http://{self.host}:{self.server_port}/get_world_size/"
+         logger.info(f"[VLLM_CLIENT] Getting world size from {url}")
+         try:
+             response = requests.get(url)
+             logger.info(
+                 f"[VLLM_CLIENT] World size response: status={response.status_code}"
+             )
+         except Exception as e:
+             logger.error(f"[VLLM_CLIENT] Failed to get world size: {e}")
+             raise
+
+         if response.status_code == 200:
+             vllm_world_size = response.json()["world_size"]
+             logger.info(f"[VLLM_CLIENT] vLLM world size: {vllm_world_size}")
+         else:
+             raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+         world_size = vllm_world_size + 1  # add the client to the world
+         self.rank = vllm_world_size  # the client's rank is the last process
+         logger.info(
+             f"[VLLM_CLIENT] Client rank: {self.rank}, total world size: {world_size}"
+         )
+
+         # Initialize weight update group
+         url = f"http://{self.host}:{self.server_port}/init_communicator/"
+         logger.info(f"[VLLM_CLIENT] Sending init_communicator request to {url}")
+         # On the server side, the host is set to 0.0.0.0
+         try:
+             response = self.session.post(
+                 url,
+                 json={
+                     "host": "0.0.0.0",
+                     "port": self.group_port,
+                     "world_size": world_size,
+                 },
+             )
+             logger.info(
+                 f"[VLLM_CLIENT] Init communicator response: status={response.status_code}"
+             )
+         except Exception as e:
+             logger.error(f"[VLLM_CLIENT] Failed to init communicator: {e}")
+             raise
+
+         if response.status_code != 200:
+             raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+         # Brief delay to allow server initialization. While not strictly required (client socket will retry on
+         # connection failure), this prevents log warnings like:
+         # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
+         time.sleep(0.1)
+
+         # Set up the communication group for weight broadcasting
+         pg = StatelessProcessGroup.create(
+             host=self.host, port=self.group_port, rank=self.rank, world_size=world_size
+         )
+         # Use device 0 like the old code - this seems to work better for multi-GPU setups
+         device = 0
+         logger.info(
+             f"[VLLM_CLIENT] Initializing PyNcclCommunicator on device {device}, rank {self.rank}, world_size {world_size}"
+         )
+         self.pynccl_comm = PyNcclCommunicator(pg, device=device)
+
+         # When the client object is deleted, close the weight update group
+         atexit.register(self.close_communicator)
+
+     async def chat(self, json_body: dict) -> dict:
+         """
+         Send a chat completion request with retry logic for when inference is blocked.
+         """
+         url = f"http://{self.host}:{self.server_port}/v1/chat/completions"
+
+         retries = 0
+         while retries < MAX_INFERENCE_RETRIES:
+             try:
+                 async with httpx.AsyncClient() as client:
+                     response = await client.post(url, json=json_body, timeout=300)
+
+                     if response.status_code == INFERENCE_BLOCKED_STATUS:
+                         retries += 1
+                         if retries < MAX_INFERENCE_RETRIES:
+                             logger.warning(
+                                 f"Inference blocked (weight update in progress). Retry {retries}/{MAX_INFERENCE_RETRIES} in {INFERENCE_RETRY_DELAY}s"
+                             )
+                             await asyncio.sleep(INFERENCE_RETRY_DELAY)
+                             continue
+                         else:
+                             raise InferenceBlockedError(
+                                 "Inference blocked by weight updates after max retries"
+                             )
+
+                     response.raise_for_status()
+                     return response.json()
+
+             except httpx.TimeoutException:
+                 logger.error("Request timed out")
+                 raise
+             except InferenceBlockedError:
+                 raise
+             except Exception as e:
+                 retries += 1
+                 if retries < MAX_INFERENCE_RETRIES:
+                     logger.warning(
+                         f"Request failed. Retry {retries}/{MAX_INFERENCE_RETRIES} in {INFERENCE_RETRY_DELAY}s. Error: {e}"
+                     )
+                     await asyncio.sleep(INFERENCE_RETRY_DELAY)
+                 else:
+                     logger.error(
+                         f"Request failed after {MAX_INFERENCE_RETRIES} retries"
+                     )
+                     raise
+
+     def update_named_param(self, name: str, weights: torch.Tensor):
+         """
+         Updates a specific named parameter in the model and broadcasts it to other processes.
+
+         Args:
+             name (`str`):
+                 Name of the layer whose weights are being updated.
+             weights (`torch.Tensor`):
+                 Tensor containing the updated weights.
+         """
+         dtype, shape = str(weights.dtype), tuple(weights.shape)
+         url = f"http://{self.host}:{self.server_port}/update_named_param/"
+         logger.debug(f"[VLLM_CLIENT] Sending weight update request for {name}")
+
+         # Add timeout to prevent hanging on HTTP request
+         try:
+             response = self.session.post(
+                 url, json={"name": name, "dtype": dtype, "shape": shape}, timeout=300.0
+             )
+             if response.status_code != 200:
+                 raise Exception(
+                     f"Request failed: {response.status_code}, {response.text}"
+                 )
+         except requests.exceptions.Timeout:
+             logger.error(
+                 f"[VLLM_CLIENT] Timeout waiting for server response for {name} after 300s"
+             )
+             raise Exception(f"Request timeout for {name} after 300s")
+         except Exception as e:
+             logger.error(f"[VLLM_CLIENT] Error sending request for {name}: {e}")
+             raise
+
+         logger.debug(
+             f"[VLLM_CLIENT] Server responded, starting NCCL broadcast for {name}"
+         )
+
+         # Broadcast the weights to the other processes
+         self.pynccl_comm.broadcast(weights, src=self.rank)
+         logger.debug(
+             f"[VLLM_CLIENT] NCCL broadcast complete, waiting at barrier for {name}"
+         )
+         self.pynccl_comm.group.barrier()
+         logger.debug(f"[VLLM_CLIENT] Barrier passed for {name}")
+
+     def update_model_params(self, model: nn.Module):
+         """
+         Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.
+
+         Args:
+             model (`nn.Module`):
+                 Model whose parameters (weights/biases) are to be updated.
+         """
+         for name, param in model.named_parameters():
+             # Update each parameter individually
+             self.update_named_param(name, param.data)
+
+     def batch_update_model_params(self, model: nn.Module, batch_size: int = 50):
+         """
+         Updates all parameters of the given model in batches to reduce overhead and prevent overwhelming the server.
+
+         This method coordinates with the server to ensure proper NCCL synchronization:
+         1. Send batch of parameter metadata to server
+         2. Server notifies workers for each parameter
+         3. Client broadcasts each parameter via NCCL after server confirmation
+
+         Args:
+             model (`nn.Module`):
+                 Model whose parameters (weights/biases) are to be updated.
+             batch_size (`int`, *optional*, defaults to 50):
+                 Number of parameters to update in each batch.
+         """
+         # Collect all parameters
+         all_params = list(model.named_parameters())
+         total_params = len(all_params)
+
+         logger.info(
+             f"[VLLM_CLIENT] Starting batch update of {total_params} parameters in batches of {batch_size}"
+         )
+
+         # Process in batches
+         for batch_idx, i in enumerate(range(0, total_params, batch_size)):
+             batch_params = all_params[i : i + batch_size]
+
+             # Prepare batch update request
+             batch_updates = []
+             for name, param in batch_params:
+                 batch_updates.append(
+                     {
+                         "name": name,
+                         "dtype": str(param.data.dtype),
+                         "shape": list(param.data.shape),
+                     }
+                 )
+
+             # Send batch update request
+             url = f"http://{self.host}:{self.server_port}/batch_update_named_params/"
+             logger.debug(
+                 f"[VLLM_CLIENT] Sending batch {batch_idx + 1} with {len(batch_updates)} parameters"
+             )
+
+             try:
+                 response = self.session.post(
+                     url, json={"updates": batch_updates}, timeout=600.0
+                 )
+                 if response.status_code not in [200, 207]:  # 207 is Multi-Status
+                     raise Exception(
+                         f"Batch request failed: {response.status_code}, {response.text}"
+                     )
+
+                 result = response.json()
+
+                 # Check for partial failures
+                 if response.status_code == 207:
+                     logger.warning(
+                         f"[VLLM_CLIENT] Batch had errors: {result.get('errors', [])}"
+                     )
+
+                 # Get list of successfully notified parameters
+                 successful_params = result.get("successful", [])
+                 if not successful_params:
+                     logger.error(
+                         "[VLLM_CLIENT] No successful parameters in batch response"
+                     )
+                     continue
+
+             except requests.exceptions.Timeout:
+                 logger.error(
+                     "[VLLM_CLIENT] Timeout waiting for batch response after 600s"
+                 )
+                 raise Exception("Batch request timeout after 600s")
+             except Exception as e:
+                 logger.error(f"[VLLM_CLIENT] Error sending batch request: {e}")
+                 raise
+
+             # Now broadcast weights for successfully notified parameters
+             logger.debug(
+                 f"[VLLM_CLIENT] Broadcasting weights for {len(successful_params)} parameters in batch {batch_idx + 1}"
+             )
+
+             for name, param in batch_params:
+                 if name in successful_params:
+                     try:
+                         # Broadcast this specific parameter
+                         self.pynccl_comm.broadcast(param.data, src=self.rank)
+                         self.pynccl_comm.group.barrier()
+                         logger.debug(f"[VLLM_CLIENT] Broadcast complete for {name}")
+                     except Exception as e:
+                         logger.error(f"[VLLM_CLIENT] Failed to broadcast {name}: {e}")
+                         raise
+                 else:
+                     logger.warning(
+                         f"[VLLM_CLIENT] Skipping broadcast for {name} - not in successful list"
+                     )
+
+             logger.debug(f"[VLLM_CLIENT] Completed batch {batch_idx + 1}")
+
+         logger.info(
+             f"[VLLM_CLIENT] Batch update complete for {total_params} parameters"
+         )
+
+     def reset_prefix_cache(self):
+         """
+         Resets the prefix cache for the model.
+         """
+         url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+         response = self.session.post(url)
+         if response.status_code != 200:
+             raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+     def close_communicator(self):
+         """
+         Closes the weight update group and cleans up the communication group.
+         """
+         url = f"http://{self.host}:{self.server_port}/close_communicator/"
+
+         try:
+             response = self.session.post(url)
+         except ConnectionError:
+             # The server might already be down, so there is no communicator left to close
+             pass
+         else:
+             if response.status_code != 200:
+                 raise Exception(
+                     f"Request failed: {response.status_code}, {response.text}"
+                 )
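
For orientation, the sketch below shows how the pieces in this new module fit together: join the weight update group, push fresh weights to the vLLM server, and issue a chat request that tolerates the temporary 503 responses returned while a weight update is in progress. It is an illustration, not code from the package: the imports of `VLLMClient` and `InferenceBlockedError` are left implicit because the module path is not visible in this diff, the model name and request body are placeholders, and the response is assumed to follow the OpenAI chat-completions format implied by the `/v1/chat/completions` route.

```python
import asyncio

from transformers import AutoModelForCausalLM

# VLLMClient and InferenceBlockedError are assumed to be imported from the
# module shown in this diff; its import path is not visible here.


async def sync_and_chat() -> None:
    # connection_timeout > 0 lets check_server() retry until the server is up.
    client = VLLMClient(host="0.0.0.0", port=8000, connection_timeout=120.0)
    client.init_communicator()  # join the NCCL weight-update group

    # Broadcast current weights to the server, then drop cached prefixes that
    # were computed with the old weights.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
    client.batch_update_model_params(model, batch_size=50)
    client.reset_prefix_cache()

    # chat() retries on the 503 "inference blocked" status and only raises
    # InferenceBlockedError after MAX_INFERENCE_RETRIES attempts.
    try:
        reply = await client.chat(
            {
                "model": "Qwen/Qwen2.5-7B",
                "messages": [{"role": "user", "content": "Hello, AI!"}],
            }
        )
        print(reply["choices"][0]["message"]["content"])
    except InferenceBlockedError:
        print("Server was still applying a weight update after all retries.")


if __name__ == "__main__":
    asyncio.run(sync_and_chat())
```

Note that `init_communicator` registers `close_communicator` with `atexit`, so the weight update group is torn down automatically when the client process exits.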