arbor-ai 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,445 @@
+# adapted from trl/extras/vllm_client.py (huggingface/trl)
+
+import asyncio
+import atexit
+import logging
+import time
+from typing import Optional
+
+import httpx
+import requests
+import torch
+from openai import OpenAI
+from requests import ConnectionError
+from requests.adapters import HTTPAdapter
+from torch import nn
+from trl.import_utils import is_requests_available, is_vllm_available
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.utils import StatelessProcessGroup
+
+logger = logging.getLogger(__name__)
+
+# Retry behaviour for inference requests issued while weight updates are in progress
+MAX_INFERENCE_RETRIES = 3
+INFERENCE_RETRY_DELAY = 1.0  # seconds
+INFERENCE_BLOCKED_STATUS = 503  # HTTP status code returned when inference is blocked
+
+
+class InferenceBlockedError(Exception):
+    """Raised when inference is blocked due to weight updates in progress."""
+
+    pass
+
+
+class VLLMClient(OpenAI):
+    """
+    A client class to interact with a vLLM server.
+
+    This class provides methods to generate completions, initialize and manage weight update groups, and update model
+    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
+
+    Args:
+        host (`str`, *optional*, defaults to `"0.0.0.0"`):
+            IP address of the vLLM server.
+        port (`int`, *optional*, defaults to `8000`):
+            Port number of the vLLM server.
+        group_port (`int`, *optional*, defaults to `51216`):
+            Port number for the weight update group.
+        connection_timeout (`float`, *optional*, defaults to `0.0`):
+            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
+            timeout, a `ConnectionError` is raised.
+
+    Examples:
+        Run the vLLM server with the model `Qwen/Qwen2.5-7B`:
+
+        ```
+        $ trl vllm-serve --model Qwen/Qwen2.5-7B
+        ...
+        INFO: Application startup complete.
+        INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+        ```
+
+        Use the client to generate completions and update model weights:
+
+        ```python
+        >>> from trl.extras.vllm_client import VLLMClient
+        >>> client = VLLMClient()
+        >>> client.generate(["Hello, AI!", "Tell me a joke"])
+        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
+         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
+
+        >>> from transformers import AutoModelForCausalLM
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
+        >>> client.update_model_params(model)
+        ```
+    """
+
+    def __init__(
+        self,
+        host: str = "0.0.0.0",
+        port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0,
+    ):
+        if not is_requests_available():
+            raise ImportError(
+                "requests is not installed. Please install it with `pip install requests`."
+            )
+        if not is_vllm_available():
+            raise ImportError(
+                "vLLM is not installed. Please install it with `pip install vllm`."
+            )
+
+        super().__init__(base_url=f"http://{host}:{port}/v1", api_key="local")
+        self.session = requests.Session()
+        # Configure connection pooling to handle rapid requests better
+        adapter = HTTPAdapter(
+            pool_connections=10, pool_maxsize=10, max_retries=3, pool_block=False
+        )
+        self.session.mount("http://", adapter)
+        self.session.mount("https://", adapter)
+
+        self.host = host
+        # Constructor parameter renamed from `server_port` to `port` to match the superclass
+        # `__init__`; it is still stored as `server_port` internally.
+        self.server_port = port
+        self.group_port = group_port
+        self.check_server(connection_timeout)  # check server and fail after timeout
+
+    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
+        """
+        Check server availability with retries on failure, within a total timeout duration. If the server is not up
+        after the total timeout duration, raise a `ConnectionError`.
+
+        Args:
+            total_timeout (`float`, *optional*, defaults to `0.0`):
+                Total timeout duration in seconds.
+            retry_interval (`float`, *optional*, defaults to `2.0`):
+                Interval in seconds between retries.
+        """
+        url = f"http://{self.host}:{self.server_port}/health/"
+        start_time = time.time()  # Record the start time
+
+        while True:
+            try:
+                response = requests.get(url)  # type: ignore
+            except requests.exceptions.RequestException as exc:  # type: ignore
+                # Check if the total timeout duration has passed
+                elapsed_time = time.time() - start_time
+                if elapsed_time >= total_timeout:
+                    raise ConnectionError(  # type: ignore
+                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                        "seconds. Make sure the server is running by running `trl vllm-serve`."
+                    ) from exc
+            else:
+                if response.status_code == 200:
+                    logger.info("Server is up!")
+                    return None
+
+            # Retry logic: wait before trying again
+            logger.info(
+                f"Server is not up yet. Retrying in {retry_interval} seconds..."
+            )
+            time.sleep(retry_interval)
+
+    def init_communicator(self):
+        """
+        Initializes the weight update group in a distributed setup for model synchronization.
+        """
+        logger.info("[VLLM_CLIENT] Starting init_communicator")
+
+        # Get the world size from the server
+        url = f"http://{self.host}:{self.server_port}/get_world_size/"
+        logger.info(f"[VLLM_CLIENT] Getting world size from {url}")
+        try:
+            response = requests.get(url)
+            logger.info(
+                f"[VLLM_CLIENT] World size response: status={response.status_code}"
+            )
+        except Exception as e:
+            logger.error(f"[VLLM_CLIENT] Failed to get world size: {e}")
+            raise
+
+        if response.status_code == 200:
+            vllm_world_size = response.json()["world_size"]
+            logger.info(f"[VLLM_CLIENT] vLLM world size: {vllm_world_size}")
+        else:
+            raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+        world_size = vllm_world_size + 1  # add the client to the world
+        self.rank = vllm_world_size  # the client's rank is the last process
+        logger.info(
+            f"[VLLM_CLIENT] Client rank: {self.rank}, total world size: {world_size}"
+        )
+
+        # Initialize the weight update group
+        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        logger.info(f"[VLLM_CLIENT] Sending init_communicator request to {url}")
+        # On the server side, the host is set to 0.0.0.0
+        try:
+            response = self.session.post(
+                url,
+                json={
+                    "host": "0.0.0.0",
+                    "port": self.group_port,
+                    "world_size": world_size,
+                },
+            )
+            logger.info(
+                f"[VLLM_CLIENT] Init communicator response: status={response.status_code}"
+            )
+        except Exception as e:
+            logger.error(f"[VLLM_CLIENT] Failed to init communicator: {e}")
+            raise
+
+        if response.status_code != 200:
+            raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+        # Brief delay to allow server initialization. While not strictly required (the client socket will retry on
+        # connection failure), this prevents log warnings like:
+        # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
+        time.sleep(0.1)
+
+        # Set up the communication group for weight broadcasting
+        pg = StatelessProcessGroup.create(
+            host=self.host, port=self.group_port, rank=self.rank, world_size=world_size
+        )
+        # Use device 0 like the old code - this seems to work better for multi-GPU setups
+        device = 0
+        logger.info(
+            f"[VLLM_CLIENT] Initializing PyNcclCommunicator on device {device}, rank {self.rank}, world_size {world_size}"
+        )
+        self.pynccl_comm = PyNcclCommunicator(pg, device=device)
+
+        # When the client object is deleted, close the weight update group
+        atexit.register(self.close_communicator)
+
+    async def chat(self, json_body: dict) -> dict:
+        """
+        Send a chat completion request, retrying while inference is blocked by an in-progress weight update.
+        """
+        url = f"http://{self.host}:{self.server_port}/v1/chat/completions"
+
+        retries = 0
+        while retries < MAX_INFERENCE_RETRIES:
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.post(url, json=json_body, timeout=300)
+
+                if response.status_code == INFERENCE_BLOCKED_STATUS:
+                    retries += 1
+                    if retries < MAX_INFERENCE_RETRIES:
+                        logger.warning(
+                            f"Inference blocked (weight update in progress). Retry {retries}/{MAX_INFERENCE_RETRIES} in {INFERENCE_RETRY_DELAY}s"
+                        )
+                        await asyncio.sleep(INFERENCE_RETRY_DELAY)
+                        continue
+                    else:
+                        raise InferenceBlockedError(
+                            "Inference blocked by weight updates after max retries"
+                        )
+
+                response.raise_for_status()
+                return response.json()
+
+            except httpx.TimeoutException:
+                logger.error("Request timed out")
+                raise
+            except InferenceBlockedError:
+                raise
+            except Exception as e:
+                retries += 1
+                if retries < MAX_INFERENCE_RETRIES:
+                    logger.warning(
+                        f"Request failed. Retry {retries}/{MAX_INFERENCE_RETRIES} in {INFERENCE_RETRY_DELAY}s. Error: {e}"
+                    )
+                    await asyncio.sleep(INFERENCE_RETRY_DELAY)
+                else:
+                    logger.error(
+                        f"Request failed after {MAX_INFERENCE_RETRIES} retries"
+                    )
+                    raise
+
+    def update_named_param(self, name: str, weights: torch.Tensor):
+        """
+        Updates a specific named parameter in the model and broadcasts it to other processes.
+
+        Args:
+            name (`str`):
+                Name of the layer whose weights are being updated.
+            weights (`torch.Tensor`):
+                Tensor containing the updated weights.
+        """
+        dtype, shape = str(weights.dtype), tuple(weights.shape)
+        url = f"http://{self.host}:{self.server_port}/update_named_param/"
+        logger.debug(f"[VLLM_CLIENT] Sending weight update request for {name}")
+
+        # Add a timeout to prevent hanging on the HTTP request
+        try:
+            response = self.session.post(
+                url, json={"name": name, "dtype": dtype, "shape": shape}, timeout=300.0
+            )
+            if response.status_code != 200:
+                raise Exception(
+                    f"Request failed: {response.status_code}, {response.text}"
+                )
+        except requests.exceptions.Timeout:
+            logger.error(
+                f"[VLLM_CLIENT] Timeout waiting for server response for {name} after 300s"
+            )
+            raise Exception(f"Request timeout for {name} after 300s")
+        except Exception as e:
+            logger.error(f"[VLLM_CLIENT] Error sending request for {name}: {e}")
+            raise
+
+        logger.debug(
+            f"[VLLM_CLIENT] Server responded, starting NCCL broadcast for {name}"
+        )
+
+        # Broadcast the weights to the other processes
+        self.pynccl_comm.broadcast(weights, src=self.rank)
+        logger.debug(
+            f"[VLLM_CLIENT] NCCL broadcast complete, waiting at barrier for {name}"
+        )
+        self.pynccl_comm.group.barrier()
+        logger.debug(f"[VLLM_CLIENT] Barrier passed for {name}")
+
+    def update_model_params(self, model: nn.Module):
+        """
+        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.
+
+        Args:
+            model (`nn.Module`):
+                Model whose parameters (weights/biases) are to be updated.
+        """
+        for name, param in model.named_parameters():
+            # Update each parameter individually
+            self.update_named_param(name, param.data)
+
+    def batch_update_model_params(self, model: nn.Module, batch_size: int = 50):
+        """
+        Updates all parameters of the given model in batches to reduce overhead and prevent overwhelming the server.
+
+        This method coordinates with the server to ensure proper NCCL synchronization:
+        1. Send batch of parameter metadata to server
+        2. Server notifies workers for each parameter
+        3. Client broadcasts each parameter via NCCL after server confirmation
+
+        Args:
+            model (`nn.Module`):
+                Model whose parameters (weights/biases) are to be updated.
+            batch_size (`int`, *optional*, defaults to 50):
+                Number of parameters to update in each batch.
+        """
+        # Collect all parameters
+        all_params = list(model.named_parameters())
+        total_params = len(all_params)
+
+        logger.info(
+            f"[VLLM_CLIENT] Starting batch update of {total_params} parameters in batches of {batch_size}"
+        )
+
+        # Process in batches
+        for batch_idx, i in enumerate(range(0, total_params, batch_size)):
+            batch_params = all_params[i : i + batch_size]
+
+            # Prepare batch update request
+            batch_updates = []
+            for name, param in batch_params:
+                batch_updates.append(
+                    {
+                        "name": name,
+                        "dtype": str(param.data.dtype),
+                        "shape": list(param.data.shape),
+                    }
+                )
+
+            # Send batch update request
+            url = f"http://{self.host}:{self.server_port}/batch_update_named_params/"
+            logger.debug(
+                f"[VLLM_CLIENT] Sending batch {batch_idx + 1} with {len(batch_updates)} parameters"
+            )
+
+            try:
+                response = self.session.post(
+                    url, json={"updates": batch_updates}, timeout=600.0
+                )
+                if response.status_code not in [200, 207]:  # 207 is Multi-Status
+                    raise Exception(
+                        f"Batch request failed: {response.status_code}, {response.text}"
+                    )
+
+                result = response.json()
+
+                # Check for partial failures
+                if response.status_code == 207:
+                    logger.warning(
+                        f"[VLLM_CLIENT] Batch had errors: {result.get('errors', [])}"
+                    )
+
+                # Get list of successfully notified parameters
+                successful_params = result.get("successful", [])
+                if not successful_params:
+                    logger.error(
+                        "[VLLM_CLIENT] No successful parameters in batch response"
+                    )
+                    continue
+
+            except requests.exceptions.Timeout:
+                logger.error(
+                    "[VLLM_CLIENT] Timeout waiting for batch response after 600s"
+                )
+                raise Exception("Batch request timeout after 600s")
+            except Exception as e:
+                logger.error(f"[VLLM_CLIENT] Error sending batch request: {e}")
+                raise
+
+            # Now broadcast weights for successfully notified parameters
+            logger.debug(
+                f"[VLLM_CLIENT] Broadcasting weights for {len(successful_params)} parameters in batch {batch_idx + 1}"
+            )
+
+            for name, param in batch_params:
+                if name in successful_params:
+                    try:
+                        # Broadcast this specific parameter
+                        self.pynccl_comm.broadcast(param.data, src=self.rank)
+                        self.pynccl_comm.group.barrier()
+                        logger.debug(f"[VLLM_CLIENT] Broadcast complete for {name}")
+                    except Exception as e:
+                        logger.error(f"[VLLM_CLIENT] Failed to broadcast {name}: {e}")
+                        raise
+                else:
+                    logger.warning(
+                        f"[VLLM_CLIENT] Skipping broadcast for {name} - not in successful list"
+                    )
+
+            logger.debug(f"[VLLM_CLIENT] Completed batch {batch_idx + 1}")
+
+        logger.info(
+            f"[VLLM_CLIENT] Batch update complete for {total_params} parameters"
+        )
+
+    def reset_prefix_cache(self):
+        """
+        Resets the prefix cache for the model.
+        """
+        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+        response = self.session.post(url)
+        if response.status_code != 200:
+            raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+    def close_communicator(self):
+        """
+        Closes the weight update group and cleans up the communication group.
+        """
+        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+
+        try:
+            response = self.session.post(url)
+        except ConnectionError:
+            # The server might already be down, so there is no communicator left to close
+            pass
+        else:
+            if response.status_code != 200:
+                raise Exception(
+                    f"Request failed: {response.status_code}, {response.text}"
+                )
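
For orientation, here is a minimal end-to-end sketch of how the client added above could be driven. It is illustrative only: it assumes a compatible vLLM server exposing the endpoints referenced in the file (`/init_communicator/`, `/batch_update_named_params/`, `/v1/chat/completions`, ...) is already reachable on `127.0.0.1:8000`, that a CUDA device is available for the NCCL broadcast, and that `transformers` is installed; the model name and chat payload are placeholders.

```python
import asyncio

from transformers import AutoModelForCausalLM

# `VLLMClient` is the class defined in the diff above.
client = VLLMClient(host="127.0.0.1", port=8000, connection_timeout=120.0)

# One-time NCCL setup: the client joins the server's workers as the last rank.
client.init_communicator()

# Push updated weights: metadata goes over HTTP in batches, tensors over NCCL.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
client.batch_update_model_params(model, batch_size=50)

# Drop cached prefixes so new completions reflect the new weights.
client.reset_prefix_cache()

# `chat` retries up to MAX_INFERENCE_RETRIES times when the server answers 503
# because a weight update is still in flight; the payload follows the OpenAI chat schema.
result = asyncio.run(
    client.chat(
        {
            "model": "Qwen/Qwen2.5-7B",
            "messages": [{"role": "user", "content": "Hello!"}],
        }
    )
)
print(result["choices"][0]["message"]["content"])

client.close_communicator()  # also registered via atexit, so optional here
```

Routing completions through `chat` rather than the inherited OpenAI client methods is what provides the 503-aware retry behaviour while weights are being updated.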