speedy-utils 1.0.4__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. llm_utils/__init__.py +31 -0
  2. llm_utils/chat_format/__init__.py +34 -0
  3. llm_utils/chat_format/display.py +274 -0
  4. llm_utils/chat_format/transform.py +149 -0
  5. llm_utils/chat_format/utils.py +43 -0
  6. llm_utils/group_messages.py +120 -0
  7. llm_utils/lm/__init__.py +8 -0
  8. llm_utils/lm/lm.py +304 -0
  9. llm_utils/lm/utils.py +130 -0
  10. llm_utils/scripts/vllm_load_balancer.py +435 -0
  11. llm_utils/scripts/vllm_serve.py +416 -0
  12. speedy_utils/__init__.py +85 -0
  13. speedy_utils/all.py +159 -0
  14. {speedy → speedy_utils}/common/__init__.py +0 -0
  15. speedy_utils/common/clock.py +215 -0
  16. speedy_utils/common/function_decorator.py +66 -0
  17. speedy_utils/common/logger.py +207 -0
  18. speedy_utils/common/report_manager.py +112 -0
  19. speedy_utils/common/utils_cache.py +264 -0
  20. {speedy → speedy_utils}/common/utils_io.py +66 -19
  21. {speedy → speedy_utils}/common/utils_misc.py +25 -11
  22. speedy_utils/common/utils_print.py +216 -0
  23. speedy_utils/multi_worker/__init__.py +0 -0
  24. speedy_utils/multi_worker/process.py +198 -0
  25. speedy_utils/multi_worker/thread.py +327 -0
  26. speedy_utils/scripts/mpython.py +108 -0
  27. speedy_utils-1.0.9.dist-info/METADATA +287 -0
  28. speedy_utils-1.0.9.dist-info/RECORD +30 -0
  29. {speedy_utils-1.0.4.dist-info → speedy_utils-1.0.9.dist-info}/WHEEL +1 -2
  30. speedy_utils-1.0.9.dist-info/entry_points.txt +5 -0
  31. speedy/__init__.py +0 -53
  32. speedy/common/clock.py +0 -68
  33. speedy/common/utils_cache.py +0 -170
  34. speedy/common/utils_print.py +0 -138
  35. speedy/multi_worker.py +0 -121
  36. speedy_utils-1.0.4.dist-info/METADATA +0 -22
  37. speedy_utils-1.0.4.dist-info/RECORD +0 -12
  38. speedy_utils-1.0.4.dist-info/top_level.txt +0 -1
@@ -0,0 +1,435 @@
1
+ import asyncio
2
+ import random
3
+ from collections import defaultdict
4
+ import time
5
+ from tabulate import tabulate
6
+ import contextlib
7
+ import aiohttp # <-- Import aiohttp
8
+ from loguru import logger
9
+
10
+ # --- Configuration ---
11
+ LOAD_BALANCER_HOST = "0.0.0.0"
12
+ LOAD_BALANCER_PORT = 8008
13
+
14
+ SCAN_TARGET_HOST = "localhost"
15
+ SCAN_PORT_START = 8150
16
+ SCAN_PORT_END = 8170 # Inclusive
17
+ SCAN_INTERVAL = 30
18
+ # Timeout applies to the HTTP health check request now
19
+ HEALTH_CHECK_TIMEOUT = 2.0 # Increased slightly for HTTP requests
20
+
21
+ STATUS_PRINT_INTERVAL = 5
22
+ BUFFER_SIZE = 4096
23
+
24
+ # --- Global Shared State ---
25
+ available_servers = []
26
+ connection_counts = defaultdict(int)
27
+ state_lock = asyncio.Lock()
28
+
29
+
30
+ # --- Helper Functions --- (relay_data and safe_close_writer remain the same)
31
async def relay_data(reader, writer, direction):
    """Pump bytes from *reader* into *writer* until EOF or a stream error.

    ``direction`` is a human-readable label (e.g. "client -> backend") used
    only in log messages. The writer is closed on the way out in every case;
    cancellation is re-raised so the owning task observes it.
    """
    try:
        # Copy fixed-size chunks until the peer signals EOF (empty read).
        while chunk := await reader.read(BUFFER_SIZE):
            writer.write(chunk)
            await writer.drain()
        logger.debug(f"EOF received on {direction} stream.")
    except ConnectionResetError:
        logger.warning(f"Connection reset on {direction} stream.")
    except asyncio.CancelledError:
        # Propagate cancellation to the caller after noting it.
        logger.debug(f"Relay task cancelled for {direction}.")
        raise
    except Exception as e:
        logger.warning(f"Error during data relay ({direction}): {e}")
    finally:
        # Best-effort close of our writing side; errors here are expected
        # when the peer already tore the connection down.
        if not writer.is_closing():
            try:
                writer.close()
                await writer.wait_closed()
                logger.debug(f"Closed writer for {direction}")
            except Exception as close_err:
                logger.debug(
                    f"Error closing writer for {direction} (might be expected): {close_err}"
                )
58
+
59
+
60
@contextlib.asynccontextmanager
async def safe_close_writer(writer):
    """Yield *writer*, guaranteeing a best-effort close on exit.

    Any exception raised during teardown is logged at debug level and
    suppressed, so closing never masks the caller's own error.
    """
    try:
        yield writer
    finally:
        # Only close a writer that exists and is still open.
        if writer and not writer.is_closing():
            try:
                writer.close()
                await writer.wait_closed()
            except Exception as e:
                logger.debug(f"Error closing writer in context manager: {e}")
72
+
73
+
74
+ # --- Server Scanning and Health Check (Modified) ---
75
+
76
+
77
async def check_server_health(session, host, port):
    """Probe ``http://{host}:{port}/health`` and report whether it is healthy.

    Args:
        session: Shared ``aiohttp.ClientSession`` used for the request.
        host: Backend hostname or IP address.
        port: Backend TCP port.

    Returns:
        True when the endpoint answers with a 2xx status within
        ``HEALTH_CHECK_TIMEOUT`` seconds; False otherwise. Never raises —
        every failure mode is logged and mapped to False.
    """
    url = f"http://{host}:{port}/health"
    try:
        # Passing a bare float for ``timeout`` is deprecated in aiohttp;
        # wrap it in an explicit ClientTimeout.
        timeout = aiohttp.ClientTimeout(total=HEALTH_CHECK_TIMEOUT)
        # The ``async with`` block releases the connection back to the
        # pool on exit, so no explicit response.release() is needed.
        async with session.get(url, timeout=timeout) as response:
            healthy = 200 <= response.status < 300
            if healthy:
                logger.debug(
                    f"Health check success for {url} (Status: {response.status})"
                )
            else:
                logger.debug(
                    f"Health check failed for {url} (Status: {response.status})"
                )
            return healthy
    except asyncio.TimeoutError:
        logger.debug(f"Health check HTTP request timeout for {url}")
        return False
    except aiohttp.ClientConnectorError as e:
        # Connection refused, DNS failure, etc. -- server is simply down.
        logger.debug(f"Health check connection error for {url}: {e}")
        return False
    except aiohttp.ClientError as e:
        # Other client-side problems (invalid URL, redirect loops, ...).
        logger.warning(f"Health check client error for {url}: {e}")
        return False
    except Exception as e:
        # Anything unexpected: log loudly but still report unhealthy.
        logger.error(f"Unexpected health check error for {url}: {e}")
        return False
112
+
113
+
114
async def scan_and_update_servers():
    """Periodically scan ports via HTTP /health and update available servers.

    Runs forever: every ``SCAN_INTERVAL`` seconds it probes each port in
    [SCAN_PORT_START, SCAN_PORT_END] with :func:`check_server_health`, then
    atomically (under ``state_lock``) replaces the global
    ``available_servers`` list and prunes ``connection_counts`` entries for
    servers that disappeared.
    """
    global available_servers
    logger.debug(
        f"Starting server scan task (HTTP GET /health on Ports {SCAN_PORT_START}-{SCAN_PORT_END} every {SCAN_INTERVAL}s)"
    )
    while True:
        try:
            ports_to_scan = range(SCAN_PORT_START, SCAN_PORT_END + 1)

            # One shared session per scan cycle keeps connection pooling
            # efficient across all per-port probes.
            async with aiohttp.ClientSession() as session:
                # gather with return_exceptions=True returns one entry per
                # port, in order: True/False from the health check, or the
                # exception object itself. This replaces the previous
                # done()/cancelled()/result() polling of individual tasks.
                results = await asyncio.gather(
                    *(
                        check_server_health(session, SCAN_TARGET_HOST, port)
                        for port in ports_to_scan
                    ),
                    return_exceptions=True,
                )

            current_scan_results = []
            for port, result in zip(ports_to_scan, results):
                if result is True:
                    current_scan_results.append((SCAN_TARGET_HOST, port))
                elif isinstance(result, Exception):
                    logger.error(
                        f"Error retrieving health check result for port {port}: {result}"
                    )

            # --- Update Shared State (Locked) ---
            async with state_lock:
                previous_servers = set(available_servers)
                current_set = set(current_scan_results)

                added = current_set - previous_servers
                removed = previous_servers - current_set

                if added:
                    logger.info(
                        f"Servers added (passed /health check): {sorted(list(added))}"
                    )
                if removed:
                    logger.info(
                        f"Servers removed (failed /health check or stopped): {sorted(list(removed))}"
                    )
                    # Drop stale connection counters for vanished servers.
                    for server in removed:
                        if server in connection_counts:
                            del connection_counts[server]
                            logger.debug(
                                f"Removed connection count entry for unavailable server {server}"
                            )

                available_servers = sorted(list(current_set))
                # Seed a zero count for newly discovered servers so the
                # least-connections selection sees them immediately.
                for server in available_servers:
                    if server not in connection_counts:
                        connection_counts[server] = 0

                logger.debug(f"Scan complete. Active servers: {available_servers}")

        except asyncio.CancelledError:
            logger.info("Server scan task cancelled.")
            break
        except Exception as e:
            logger.error(f"Error in scan_and_update_servers loop: {e}")
            await asyncio.sleep(SCAN_INTERVAL / 2)  # Avoid tight loop on error

        await asyncio.sleep(SCAN_INTERVAL)
193
+
194
+
195
+ # --- Core Load Balancer Logic (handle_client remains the same) ---
196
async def handle_client(client_reader, client_writer):
    """Proxy one client connection to the least-loaded healthy backend.

    Selects the backend with the fewest active connections (ties broken at
    random), opens a TCP connection to it, then relays bytes in both
    directions until either side closes. The backend's connection count is
    incremented on selection and decremented exactly once in ``finally``.
    """
    client_addr = client_writer.get_extra_info("peername")
    logger.info(f"Accepted connection from {client_addr}")

    backend_server = None
    backend_reader = None
    backend_writer = None
    # True only while this handler holds an increment on the backend's
    # connection count; gates the decrement in the finally block.
    server_selected = False

    try:
        # --- Select Backend Server (Least Connections from Available) ---
        selected_server = None
        # Lock to safely access available_servers and connection_counts.
        async with (
            state_lock
        ):
            if not available_servers:
                logger.warning(
                    f"No backend servers available (failed health checks?) for client {client_addr}. Closing connection."
                )
                async with safe_close_writer(client_writer):
                    pass
                return

            # Single pass over healthy servers, keeping every server tied
            # for the minimum count so the tie can be broken randomly.
            min_connections = float("inf")
            least_used_available_servers = []
            for (
                server
            ) in (
                available_servers
            ):  # Iterate only over servers that passed health check
                count = connection_counts.get(server, 0)
                if count < min_connections:
                    min_connections = count
                    least_used_available_servers = [server]
                elif count == min_connections:
                    least_used_available_servers.append(server)

            if least_used_available_servers:
                selected_server = random.choice(least_used_available_servers)
                # Increment while still holding the lock so concurrent
                # handlers see this connection when they pick a backend.
                connection_counts[selected_server] += 1
                backend_server = selected_server
                server_selected = True
                logger.info(
                    f"Routing {client_addr} to {backend_server} (Current connections: {connection_counts[backend_server]})"
                )
            else:
                # Should be unreachable: available_servers was non-empty.
                logger.error(
                    f"Logic error: No server chosen despite available servers list not being empty for {client_addr}."
                )
                async with safe_close_writer(client_writer):
                    pass
                return

        # --- Connect to Backend Server ---
        if not backend_server:
            logger.error(
                f"No backend server selected for {client_addr} before connection attempt."
            )
            async with safe_close_writer(client_writer):
                pass
            server_selected = False
            return
        try:
            logger.debug(
                f"Attempting connection to backend {backend_server} for {client_addr}"
            )
            backend_reader, backend_writer = await asyncio.open_connection(
                backend_server[0], backend_server[1]
            )
            logger.debug(
                f"Successfully connected to backend {backend_server} for {client_addr}"
            )

        # Handle connection failure AFTER selection (server might go down
        # between health check and selection).
        except ConnectionRefusedError:
            logger.error(
                f"Connection refused by selected backend server {backend_server} for {client_addr}"
            )
            async with state_lock:  # Roll back the count increment under lock
                if (
                    backend_server in connection_counts
                    and connection_counts[backend_server] > 0
                ):
                    connection_counts[backend_server] -= 1
            server_selected = False  # Mark failure
            async with safe_close_writer(client_writer):
                pass
            return
        except Exception as e:
            logger.error(
                f"Failed to connect to selected backend server {backend_server} for {client_addr}: {e}"
            )
            async with state_lock:  # Decrement count under lock
                if (
                    backend_server in connection_counts
                    and connection_counts[backend_server] > 0
                ):
                    connection_counts[backend_server] -= 1
            server_selected = False  # Mark failure
            async with safe_close_writer(client_writer):
                pass
            return

        # --- Relay Data Bidirectionally ---
        async with safe_close_writer(backend_writer):  # Ensure backend writer is closed
            client_to_backend = asyncio.create_task(
                relay_data(
                    client_reader, backend_writer, f"{client_addr} -> {backend_server}"
                )
            )
            backend_to_client = asyncio.create_task(
                relay_data(
                    backend_reader, client_writer, f"{backend_server} -> {client_addr}"
                )
            )
            # Whichever direction finishes first (EOF or error) ends the
            # session; the opposite relay is then cancelled.
            done, pending = await asyncio.wait(
                [client_to_backend, backend_to_client],
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in pending:
                task.cancel()
            for task in done:
                with contextlib.suppress(asyncio.CancelledError):
                    if task.exception():
                        logger.warning(
                            f"Relay task finished with error: {task.exception()}"
                        )

    except asyncio.CancelledError:
        logger.info(f"Client handler for {client_addr} cancelled.")
    except Exception as e:
        logger.error(f"Error handling client {client_addr}: {e}")
    finally:
        logger.info(f"Closing connection for {client_addr}")
        # Decrement connection count only if we successfully selected/incremented
        if backend_server and server_selected:
            async with state_lock:
                if backend_server in connection_counts:
                    if connection_counts[backend_server] > 0:
                        connection_counts[backend_server] -= 1
                        logger.info(
                            f"Connection closed for {client_addr}. Backend {backend_server} connections: {connection_counts[backend_server]}"
                        )
                    else:
                        # Defensive: never let the counter go negative.
                        logger.warning(
                            f"Attempted to decrement count below zero for {backend_server} on close"
                        )
                        connection_counts[backend_server] = 0
345
+
346
+
347
+ # --- Status Reporting Task (print_status_periodically remains the same) ---
348
async def print_status_periodically():
    """Every STATUS_PRINT_INTERVAL seconds, render a status table of backends.

    Takes a snapshot of shared state under ``state_lock`` and prints either
    a "no servers" notice or a tabulated per-backend connection summary.
    """
    while True:
        await asyncio.sleep(STATUS_PRINT_INTERVAL)
        async with state_lock:
            headers = ["Backend Server", "Host", "Port", "Active Connections", "Status"]
            # Snapshot the shared structures while holding the lock.
            current_available = available_servers[:]
            current_counts = connection_counts.copy()

            if not current_available:
                # clear terminal and print status
                print("\033[H\033[J", end="")  # Clear terminal
                print("\n----- Load Balancer Status -----")
                print("No backend servers currently available (failed /health check).")
                print("------------------------------\n")
                continue

            table_data = [
                [
                    f"{host}:{port}",
                    host,
                    port,
                    current_counts.get((host, port), 0),
                    "Available",
                ]
                for host, port in current_available
            ]
            total_connections = sum(row[3] for row in table_data)
            table_data.sort(key=lambda row: (row[1], row[2]))

            try:
                table = tabulate(table_data, headers=headers, tablefmt="grid")
                print("\n----- Load Balancer Status -----")
                print(
                    f"Scanning Ports: {SCAN_PORT_START}-{SCAN_PORT_END} on {SCAN_TARGET_HOST} (using /health endpoint)"
                )
                print(
                    f"Scan Interval: {SCAN_INTERVAL}s | Health Check Timeout: {HEALTH_CHECK_TIMEOUT}s"
                )
                print(table)
                print(
                    f"Total Active Connections (on available servers): {total_connections}"
                )
                print("------------------------------\n")
            except Exception as e:
                logger.error(f"Error printing status table: {e}")
391
+
392
+
393
+ # --- Main Execution (main remains the same) ---
394
async def main():
    """Start the background tasks and serve the load balancer until cancelled.

    Spawns the port-scanning and status-printing tasks, binds the listening
    socket, and on shutdown cancels both background tasks and waits for them
    to finish.
    """
    background_tasks = [
        asyncio.create_task(scan_and_update_servers()),
        asyncio.create_task(print_status_periodically()),
    ]

    server = await asyncio.start_server(
        handle_client, LOAD_BALANCER_HOST, LOAD_BALANCER_PORT
    )

    addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets)
    logger.info(f"Load balancer serving on {addrs}")
    logger.info(
        f"Dynamically discovering servers via HTTP /health on {SCAN_TARGET_HOST}:{SCAN_PORT_START}-{SCAN_PORT_END}"
    )

    async with server:
        try:
            await server.serve_forever()
        except asyncio.CancelledError:
            logger.info("Load balancer server shutting down.")
        finally:
            # Tear down the scanner and status printer before exiting.
            logger.info("Cancelling background tasks...")
            for task in background_tasks:
                task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await asyncio.gather(*background_tasks, return_exceptions=True)
            logger.info("Background tasks finished.")
422
+
423
+
424
def run_load_balancer():
    """Synchronous entry point for the load balancer.

    Requires aiohttp (``pip install aiohttp``). Runs the asyncio event loop
    until interrupted; a Ctrl-C is logged as a clean shutdown, anything else
    as a critical failure.
    """
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Shutdown requested by user.")
    except Exception as e:
        logger.critical(f"Critical error in main execution: {e}")
432
+
433
+
434
+ if __name__ == "__main__":
435
+ run_load_balancer()