kubetorch-0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. kubetorch/__init__.py +59 -0
  2. kubetorch/cli.py +1939 -0
  3. kubetorch/cli_utils.py +967 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +269 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +159 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +140 -0
  30. kubetorch/resources/callables/module.py +1315 -0
  31. kubetorch/resources/callables/utils.py +203 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +253 -0
  34. kubetorch/resources/compute/compute.py +2414 -0
  35. kubetorch/resources/compute/decorators.py +137 -0
  36. kubetorch/resources/compute/utils.py +1026 -0
  37. kubetorch/resources/compute/websocket.py +135 -0
  38. kubetorch/resources/images/__init__.py +1 -0
  39. kubetorch/resources/images/image.py +412 -0
  40. kubetorch/resources/images/images.py +64 -0
  41. kubetorch/resources/secrets/__init__.py +2 -0
  42. kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
  43. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  44. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  45. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  46. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  47. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  48. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  49. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  50. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  51. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  52. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  53. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  54. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  55. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  56. kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
  57. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  58. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  59. kubetorch/resources/secrets/secret.py +224 -0
  60. kubetorch/resources/secrets/secret_factory.py +64 -0
  61. kubetorch/resources/secrets/utils.py +222 -0
  62. kubetorch/resources/volumes/__init__.py +0 -0
  63. kubetorch/resources/volumes/volume.py +340 -0
  64. kubetorch/servers/__init__.py +0 -0
  65. kubetorch/servers/http/__init__.py +0 -0
  66. kubetorch/servers/http/distributed_utils.py +2968 -0
  67. kubetorch/servers/http/http_client.py +802 -0
  68. kubetorch/servers/http/http_server.py +1622 -0
  69. kubetorch/servers/http/server_metrics.py +255 -0
  70. kubetorch/servers/http/utils.py +722 -0
  71. kubetorch/serving/__init__.py +0 -0
  72. kubetorch/serving/autoscaling.py +153 -0
  73. kubetorch/serving/base_service_manager.py +344 -0
  74. kubetorch/serving/constants.py +77 -0
  75. kubetorch/serving/deployment_service_manager.py +431 -0
  76. kubetorch/serving/knative_service_manager.py +487 -0
  77. kubetorch/serving/raycluster_service_manager.py +526 -0
  78. kubetorch/serving/service_manager.py +18 -0
  79. kubetorch/serving/templates/deployment_template.yaml +17 -0
  80. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  81. kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
  82. kubetorch/serving/templates/pod_template.yaml +198 -0
  83. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  84. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  85. kubetorch/serving/templates/service_template.yaml +21 -0
  86. kubetorch/serving/templates/workerset_template.yaml +36 -0
  87. kubetorch/serving/utils.py +344 -0
  88. kubetorch/utils.py +263 -0
  89. kubetorch-0.2.5.dist-info/METADATA +75 -0
  90. kubetorch-0.2.5.dist-info/RECORD +92 -0
  91. kubetorch-0.2.5.dist-info/WHEEL +4 -0
  92. kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/distributed_utils.py
@@ -0,0 +1,2968 @@
+ import copy
+ import multiprocessing
+ import os
+ import queue
+ import subprocess
+ import threading
+ import time
+ import uuid
+ from bdb import BdbQuit
+ from concurrent.futures import as_completed, ThreadPoolExecutor
+ from typing import Dict, Optional
+
+ import httpx
+ from starlette.responses import JSONResponse
+
+ from kubetorch.servers.http.http_server import (
+     load_callable,
+     logger,
+     package_exception,
+     patch_sys_path,
+     request_id_ctx_var,
+     run_callable_internal_sync,
+ )
+
+ from .utils import clear_debugging_sessions, is_running_in_kubernetes
+
+ # Try to import Monarch components at module level if available
+ # This helps avoid threading issues with Monarch's Rust bindings
+ try:
+     from monarch._src.actor.allocator import RemoteAllocator, StaticRemoteAllocInitializer
+
+     MONARCH_AVAILABLE = True
+ except ImportError:
+     MONARCH_AVAILABLE = False
+     RemoteAllocator = None
+     StaticRemoteAllocInitializer = None
+ except Exception:
+     # Catch any other exceptions during import
+     MONARCH_AVAILABLE = False
+     RemoteAllocator = None
+     StaticRemoteAllocInitializer = None
+
+
44
+ class RemoteWorkerPool:
45
+ """Manages async HTTP calls to remote workers in a separate process."""
46
+
47
+ def __init__(self, quorum_timeout=300):
48
+ self.quorum_timeout = quorum_timeout
49
+ self.request_queue = multiprocessing.Queue()
50
+ self.response_queue = multiprocessing.Queue()
51
+ self.process = None
52
+ self._running = False
53
+ # Thread-safe response routing for concurrent calls
54
+ self._response_events = {}
55
+ self._response_lock = threading.Lock()
56
+ self._router_thread = None
57
+
58
+ def start(self, max_workers=2000):
59
+ """Start the worker process and router thread."""
60
+ if self.process:
61
+ raise RuntimeError("WorkerPool already started")
62
+
63
+ self._running = True
64
+ # Pass necessary data as arguments to avoid pickling issues
65
+ self.process = multiprocessing.Process(
66
+ target=self._run_async_worker,
67
+ args=(
68
+ self.request_queue,
69
+ self.response_queue,
70
+ self.quorum_timeout,
71
+ max_workers,
72
+ ),
73
+ daemon=True,
74
+ )
75
+ self.process.start()
76
+
77
+ # Start router thread for handling responses
78
+ self._router_thread = threading.Thread(target=self._response_router, daemon=True)
79
+ self._router_thread.start()
80
+
81
+ logger.debug("Started RemoteWorkerPool process and router thread")
82
+
83
+ def stop(self):
84
+ """Stop the worker process and router thread."""
85
+ self._running = False
86
+
87
+ # Stop router thread
88
+ if self._router_thread:
89
+ self.response_queue.put({"request_id": "STOP_ROUTER"})
90
+ self._router_thread.join(timeout=1)
91
+
92
+ # Stop worker process
93
+ if self.process:
94
+ self.request_queue.put(("SHUTDOWN", None))
95
+ self.process.join(timeout=5)
96
+ if self.process and self.process.is_alive():
97
+ self.process.terminate()
98
+ self.process.join(timeout=1)
99
+ if self.process and self.process.is_alive():
100
+ self.process.kill()
101
+ self.process = None
102
+
103
+ # Clear response events
104
+ with self._response_lock:
105
+ self._response_events.clear()
106
+
107
+ logger.debug("Stopped RemoteWorkerPool process and router thread")
108
+
109
+ def call_workers(
110
+ self,
111
+ worker_ips,
112
+ cls_or_fn_name,
113
+ method_name,
114
+ params,
115
+ request_headers,
116
+ workers_arg="all",
117
+ ):
118
+ """Call remote workers and return responses."""
119
+ if not self.process or not self.process.is_alive():
120
+ raise RuntimeError("RemoteWorkerPool not running")
121
+
122
+ # Generate unique request ID
123
+ request_id = str(uuid.uuid4())
124
+
125
+ # Submit request
126
+ request_data = {
127
+ "request_id": request_id,
128
+ "worker_ips": worker_ips,
129
+ "cls_or_fn_name": cls_or_fn_name,
130
+ "method_name": method_name,
131
+ "params": params,
132
+ "request_headers": request_headers,
133
+ "workers_arg": workers_arg,
134
+ }
135
+
136
+ # Register event for this request
137
+ event = threading.Event()
138
+ with self._response_lock:
139
+ self._response_events[request_id] = (event, None)
140
+
141
+ logger.debug(f"RemoteWorkerPool: Submitting request {request_id} to queue for {len(worker_ips)} workers")
142
+ self.request_queue.put(("CALL", request_data))
143
+
144
+ logger.debug(f"RemoteWorkerPool: Waiting for response for request {request_id} from {len(worker_ips)} workers")
+ # Wait indefinitely for the response (no timeout, so long-running jobs are not cut off)
+ event.wait()
148
+
149
+ # Get and cleanup response
150
+ with self._response_lock:
151
+ _, result = self._response_events.pop(request_id)
152
+
153
+ logger.debug(f"RemoteWorkerPool: Got response for request {request_id}, type: {type(result).__name__}")
154
+
155
+ if isinstance(result, Exception):
156
+ raise result
157
+ return result
158
+
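The call path above is easiest to see in isolation: call_workers registers a threading.Event under a fresh UUID, pushes the request to the worker process, and blocks until the router thread (below) matches a response to that UUID and sets the event. A minimal, self-contained sketch of that pattern (illustrative only, not code from this package):

import queue
import threading
import uuid

events = {}
lock = threading.Lock()
responses = queue.Queue()

def router():
    # Single thread that hands each response back to the caller waiting on its UUID
    while True:
        msg = responses.get()
        with lock:
            if msg["id"] in events:
                event, _ = events[msg["id"]]
                events[msg["id"]] = (event, msg["result"])
                event.set()

def call(payload):
    rid = str(uuid.uuid4())
    event = threading.Event()
    with lock:
        events[rid] = (event, None)
    responses.put({"id": rid, "result": payload.upper()})  # stand-in for a real worker reply
    event.wait()
    with lock:
        _, result = events.pop(rid)
    return result

threading.Thread(target=router, daemon=True).start()
print(call("hello"))  # -> HELLO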
159
+ def _response_router(self):
160
+ """Router thread that distributes responses to waiting threads."""
161
+ logger.debug("RemoteWorkerPool response router thread started")
162
+ while self._running:
163
+ try:
164
+ response = self.response_queue.get(timeout=0.1)
165
+ if response.get("request_id") == "STOP_ROUTER":
166
+ break
167
+
168
+ request_id = response.get("request_id")
169
+ logger.debug(f"Response router received response for request {request_id}")
170
+ with self._response_lock:
171
+ if request_id in self._response_events:
172
+ event, _ = self._response_events[request_id]
173
+ # Store the result (either results list or exception)
174
+ if "error" in response:
175
+ logger.debug(f"Response router: Setting error for request {request_id}")
176
+ self._response_events[request_id] = (
177
+ event,
178
+ response["error"],
179
+ )
180
+ else:
181
+ logger.debug(
182
+ f"Response router: Setting {len(response.get('results', []))} results for request {request_id}"
183
+ )
184
+ self._response_events[request_id] = (
185
+ event,
186
+ response["results"],
187
+ )
188
+ event.set()
189
+ logger.debug(f"Response router: Event set for request {request_id}")
190
+ else:
191
+ logger.warning(
192
+ f"Response router: No event found for request {request_id}, registered events: {list(self._response_events.keys())}"
193
+ )
194
+ except queue.Empty:
195
+ continue # Queue timeout, continue checking
196
+ except Exception as e:
197
+ logger.error(f"Error in response router: {e}")
198
+ continue
199
+
200
+ @staticmethod
201
+ def _run_async_worker(request_queue, response_queue, quorum_timeout, max_workers=2000):
202
+ """Worker process that handles async HTTP calls.
203
+
204
+ Architecture:
205
+ - Runs in a separate process with its own event loop
206
+ - Maintains a single shared httpx.AsyncClient for connection pooling
207
+ - Processes requests from queue concurrently (multiple requests in flight)
208
+ - Each request involves parallel async HTTP calls to multiple workers
209
+
210
+ Nested functions (share queue/timeout state):
211
+ - wait_for_worker_health: Health check with retries
212
+ - call_single_worker: Make HTTP call to one worker
213
+ - call_workers_async: Orchestrate health checks + calls for all workers
214
+ - process_requests: Main loop pulling from queue and dispatching tasks
215
+ """
216
+ import asyncio
217
+ import queue
218
+
219
+ import httpx
220
+
221
+ async def wait_for_worker_health(client, worker_ip, workers_arg, quorum_timeout):
222
+ """Wait for a worker to become healthy within timeout."""
223
+ port = os.environ["KT_SERVER_PORT"]
224
+ worker_url = f"http://{worker_ip}:{port}"
225
+
226
+ start_time = asyncio.get_event_loop().time()
227
+
228
+ # Keep trying until timeout
229
+ while (asyncio.get_event_loop().time() - start_time) < quorum_timeout:
230
+ try:
231
+ resp = await client.get(f"{worker_url}/health", timeout=10.0)
232
+ if resp.status_code == 200:
233
+ return (worker_ip, True) # Return IP and success
234
+ except (httpx.RequestError, httpx.TimeoutException):
235
+ pass
236
+
237
+ # No waiting for quorum if user just wants to call ready workers
238
+ if workers_arg == "ready":
239
+ return worker_ip, False
240
+
241
+ # Wait before retry (1 second, same as original)
242
+ await asyncio.sleep(1.0)
243
+
244
+ # Timeout reached
245
+ return worker_ip, False
246
+
247
+ async def call_single_worker(
248
+ client,
249
+ worker_ip,
250
+ cls_or_fn_name,
251
+ method_name,
252
+ params,
253
+ request_headers,
254
+ workers_arg,
255
+ ):
256
+ """Call a single worker (assumes health already checked)."""
257
+ port = os.environ["KT_SERVER_PORT"]
258
+ worker_url = f"http://{worker_ip}:{port}"
259
+
260
+ # Make the actual call
261
+ call_url = f"{worker_url}/{cls_or_fn_name}"
262
+ if method_name:
263
+ call_url += f"/{method_name}"
264
+ call_url += "?distributed_subcall=true"
265
+
266
+ logger.debug(f"Async worker: Making POST to {call_url}")
267
+
268
+ # Retry logic for transient failures
269
+ max_retries = 3
270
+ for attempt in range(max_retries):
271
+ try:
272
+ resp = await client.post(call_url, json=params, headers=request_headers)
273
+ result = resp.json()
274
+ break # Success, exit retry loop
275
+ except httpx.ReadError as e:
276
+ # Check if this is due to server shutdown (connection reset)
277
+ if "Connection reset" in str(e) or "Connection closed" in str(e):
278
+ logger.warning(f"Worker {worker_ip} appears to be shutting down: {e}")
279
+ raise # Don't retry on shutdown
280
+ if attempt < max_retries - 1:
281
+ wait_time = (attempt + 1) * 2  # Backoff: 2s, then 4s
282
+ logger.warning(
283
+ f"ReadError calling {worker_ip} (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
284
+ )
285
+ await asyncio.sleep(wait_time)
286
+ else:
287
+ logger.error(
288
+ f"ReadError calling {worker_ip} after {max_retries} attempts: {e}. Worker may be crashed/overloaded."
289
+ )
290
+ raise
291
+ except httpx.TimeoutException as e:
292
+ if attempt < max_retries - 1:
293
+ logger.warning(
294
+ f"Timeout calling {worker_ip} (attempt {attempt + 1}/{max_retries}): {e}. Retrying..."
295
+ )
296
+ await asyncio.sleep(1)
297
+ else:
298
+ logger.error(f"Timeout calling {worker_ip} after {max_retries} attempts: {e}")
299
+ raise
300
+ except Exception as e:
301
+ logger.error(f"Unexpected error calling {worker_ip}: {e}")
302
+ raise
303
+ logger.debug(
304
+ f"Async worker: Got response from {worker_ip}, type: {type(result).__name__}, "
305
+ f"length: {len(result) if hasattr(result, '__len__') else 'N/A'}"
306
+ )
307
+ # In tree topology, intermediate nodes return aggregated results from their subtree
308
+ # We should preserve the flat list structure
309
+ return result
310
+
311
+ async def call_workers_async(client, data):
312
+ """Make async calls to all workers using shared client."""
313
+ worker_ips = data["worker_ips"]
314
+ cls_or_fn_name = data["cls_or_fn_name"]
315
+ method_name = data["method_name"]
316
+ params = data["params"]
317
+ request_headers = data["request_headers"]
318
+ workers_arg = data["workers_arg"]
319
+
320
+ # With tree topology limiting fanout, we don't need to batch health checks
321
+ # Each node only calls its direct children in the tree
322
+ logger.info(f"Waiting for {len(worker_ips)} workers to become ready (timeout={quorum_timeout}s)")
323
+ health_tasks = [wait_for_worker_health(client, ip, workers_arg, quorum_timeout) for ip in worker_ips]
324
+ health_results = await asyncio.gather(*health_tasks)
325
+
326
+ # Process results
327
+ healthy_workers = []
328
+ unhealthy_workers = []
329
+ for worker_ip, is_healthy in health_results:
330
+ if is_healthy:
331
+ healthy_workers.append(worker_ip)
332
+ else:
333
+ unhealthy_workers.append(worker_ip)
334
+
335
+ if unhealthy_workers:
336
+ if workers_arg == "ready":
337
+ # For "ready" mode, just skip unhealthy workers
338
+ logger.info(f"Skipping {len(unhealthy_workers)} workers that didn't respond (ready mode)")
339
+ else:
340
+ # For normal mode, fail if any worker didn't become ready
341
+ logger.error(
342
+ f"{len(unhealthy_workers)} workers failed to become ready after {quorum_timeout}s: {unhealthy_workers[:5]}..."
343
+ )
344
+ raise TimeoutError(
345
+ f"{len(unhealthy_workers)} of {len(worker_ips)} workers did not become ready within {quorum_timeout} seconds. "
346
+ f"This may indicate the pods are still starting or there's a resource constraint. "
347
+ f"Consider increasing quorum_timeout in .distribute() call."
348
+ )
349
+
350
+ logger.info(f"All {len(healthy_workers)} workers are ready, making distributed calls")
351
+
352
+ # Now make the actual calls to ready workers
353
+ # Create tasks (not just coroutines)
354
+ tasks = []
355
+ for worker_ip in healthy_workers:
356
+ coro = call_single_worker(
357
+ client,
358
+ worker_ip,
359
+ cls_or_fn_name,
360
+ method_name,
361
+ params,
362
+ request_headers,
363
+ workers_arg,
364
+ )
365
+ # Create actual task from coroutine
366
+ task = asyncio.create_task(coro)
367
+ tasks.append(task)
368
+
369
+ # Use as_completed for fast failure propagation
370
+ responses = []
371
+ pending_tasks = set(tasks)
372
+
373
+ for future in asyncio.as_completed(tasks):
374
+ try:
375
+ result = await future
376
+ if result is not None: # Skip None results
377
+ responses.append(result)
378
+ # Remove completed task from pending set
379
+ pending_tasks.discard(future)
380
+ except Exception as e:
381
+ # Fast failure - immediately propagate the exception
382
+ # Cancel remaining tasks to avoid unnecessary work
383
+ for task in pending_tasks:
384
+ if not task.done():
385
+ task.cancel()
386
+ raise e
387
+
388
+ return responses
389
+
390
+ # Create and run event loop with shared AsyncClient
391
+ async def main():
392
+ # Set up signal handler for graceful shutdown
393
+ import signal
394
+
395
+ shutdown_event = asyncio.Event()
396
+
397
+ # Save existing handlers to chain to them
398
+ original_sigterm_handler = signal.getsignal(signal.SIGTERM)
399
+ original_sigint_handler = signal.getsignal(signal.SIGINT)
400
+
401
+ def signal_handler(signum, frame):
402
+ logger.info(f"Received signal {signum}, initiating graceful shutdown")
403
+ shutdown_event.set()
404
+
405
+ # Chain to original handler if it exists
406
+ if signum == signal.SIGTERM:
407
+ if original_sigterm_handler and original_sigterm_handler not in (
408
+ signal.SIG_DFL,
409
+ signal.SIG_IGN,
410
+ ):
411
+ original_sigterm_handler(signum, frame)
412
+ elif signum == signal.SIGINT:
413
+ if original_sigint_handler and original_sigint_handler not in (
414
+ signal.SIG_DFL,
415
+ signal.SIG_IGN,
416
+ ):
417
+ original_sigint_handler(signum, frame)
418
+
419
+ signal.signal(signal.SIGTERM, signal_handler)
420
+ signal.signal(signal.SIGINT, signal_handler)
421
+
422
+ # Create a single AsyncClient to be shared across all requests
423
+ # Set limits based on max expected workers (passed from parent)
424
+ # We need exactly max_workers connections since health checks and work calls
425
+ # should reuse the same connection per worker (HTTP keep-alive)
426
+ # Add small buffer for edge cases
427
+ buffer = max(100, int(max_workers * 0.1)) # 10% buffer, min 100
428
+ max_conn = max_workers + buffer
429
+
430
+ limits = httpx.Limits(
431
+ max_keepalive_connections=max_conn, # Keep all connections alive
432
+ max_connections=max_conn, # One connection per worker + buffer
433
+ keepalive_expiry=300.0, # Keep connections alive for 5 minutes
434
+ )
435
+ timeout = httpx.Timeout(
436
+ connect=10.0,
437
+ read=None, # No read timeout for long-running jobs
438
+ write=10.0,
439
+ pool=60.0, # Time to wait for a connection from pool
440
+ )
441
+
442
+ logger.debug(
443
+ f"AsyncClient configured with max_connections={max_conn} "
444
+ f"(workers={max_workers} + buffer={buffer}) with 5min keepalive"
445
+ )
446
+
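(For scale: with the default max_workers=2000, buffer = max(100, int(2000 * 0.1)) = 200, so the shared client is created with max_connections = 2200.)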
447
+ # Note: http2=True would enable HTTP/2 if server supports it
448
+ async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
449
+ logger.debug("Async worker started with shared httpx.AsyncClient")
450
+
451
+ active_tasks = set()
452
+
453
+ async def process_call(request_data):
454
+ """Process a CALL request asynchronously."""
455
+ request_id = request_data["request_id"]
456
+ worker_ips = request_data["worker_ips"]
457
+ logger.debug(
458
+ f"Async worker: Processing request {request_id} with {len(worker_ips)} workers: {worker_ips}"
459
+ )
460
+ try:
461
+ results = await call_workers_async(client, request_data)
462
+ logger.debug(
463
+ f"Async worker: Successfully got {len(results)} results for request {request_id}, sending response"
464
+ )
465
+ response_queue.put({"request_id": request_id, "results": results})
466
+ except Exception as e:
467
+ logger.error(f"Error processing request {request_id}: {e}")
468
+ response_queue.put({"request_id": request_id, "error": e})
469
+
470
+ # Main request processing loop
471
+ while not shutdown_event.is_set():
472
+ try:
473
+ # Check queue without blocking event loop
474
+ cmd, request_data = await asyncio.to_thread(request_queue.get, block=True, timeout=0.1)
475
+
476
+ if cmd == "SHUTDOWN" or shutdown_event.is_set():
477
+ # Cancel all active tasks immediately for quick cleanup
478
+ if active_tasks:
479
+ logger.debug(f"Cancelling {len(active_tasks)} active tasks for shutdown")
480
+ for task in active_tasks:
481
+ task.cancel()
482
+ break
483
+
484
+ elif cmd == "CALL":
485
+ # Create a task to handle this request concurrently
486
+ task = asyncio.create_task(process_call(request_data))
487
+ active_tasks.add(task)
488
+
489
+ # Clean up completed tasks periodically
490
+ active_tasks = {t for t in active_tasks if not t.done()}
491
+
492
+ except queue.Empty:
493
+ # Clean up completed tasks while waiting
494
+ active_tasks = {t for t in active_tasks if not t.done()}
495
+ except Exception as e:
496
+ logger.error(f"Error in async worker loop: {e}")
497
+
498
+ loop = asyncio.new_event_loop()
499
+ asyncio.set_event_loop(loop)
500
+ try:
501
+ loop.run_until_complete(main())
502
+ finally:
503
+ loop.close()
504
+
505
+
506
+ class DistributedProcessPool:
507
+ """Unified pool managing distributed processes with single router thread."""
508
+
509
+ def __init__(self, process_class, num_processes, max_threads_per_proc=10, **process_kwargs):
510
+ self.process_class = process_class
511
+ self.num_processes = num_processes
512
+ self.max_threads_per_proc = max_threads_per_proc
513
+ self.process_kwargs = process_kwargs # Additional kwargs to pass to process constructor
514
+
515
+ # Processes and queues
516
+ self.processes = []
517
+ self.request_queues = [] # One request queue per process
518
+ self.response_queue = multiprocessing.Queue() # Single shared response queue
519
+
520
+ # Response routing with single router thread
521
+ self._router_thread = None
522
+ self._response_events = {} # Maps request_id to (threading.Event, response)
523
+ self._response_lock = threading.Lock()
524
+ self._running = False
525
+
526
+ def start(self):
527
+ """Start all processes in the pool and the single router thread."""
528
+ if self.processes:
529
+ raise RuntimeError("Pool already started")
530
+
531
+ # Create and start all processes in parallel
532
+ with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
533
+ futures = []
534
+ for i in range(self.num_processes):
535
+ future = executor.submit(self._create_and_start_process, i)
536
+ futures.append(future)
537
+
538
+ # Wait for all processes to start
539
+ for future in futures:
540
+ future.result()
541
+
542
+ # Start single response router thread for entire pool
543
+ self._running = True
544
+ self._router_thread = threading.Thread(target=self._response_router, daemon=True, name="PoolResponseRouter")
545
+ self._router_thread.start()
546
+ logger.debug(f"Started {self.num_processes} processes with single router thread")
547
+
548
+ def _create_and_start_process(self, local_rank):
549
+ """Helper to create and start a single process."""
550
+ request_queue = multiprocessing.Queue()
551
+ self.request_queues.append(request_queue)
552
+
553
+ process = self.process_class(
554
+ local_rank=local_rank,
555
+ request_queue=request_queue,
556
+ response_queue=self.response_queue, # Shared response queue
557
+ max_threads=self.max_threads_per_proc,
558
+ **self.process_kwargs, # Pass additional framework-specific settings
559
+ )
560
+ process.start()
561
+ self.processes.append(process)
562
+
563
+ def stop(self):
564
+ """Stop all processes and the router thread."""
565
+ self._running = False
566
+
567
+ # Send shutdown signal to all processes (use put_nowait to avoid blocking)
568
+ for q in self.request_queues:
569
+ try:
570
+ q.put_nowait("SHUTDOWN")
571
+ except Exception:
572
+ pass
573
+
574
+ # Stop router thread
575
+ try:
576
+ self.response_queue.put_nowait("STOP_ROUTER")
577
+ except Exception:
578
+ pass
579
+
580
+ # Wait briefly for router to stop
581
+ if self._router_thread:
582
+ self._router_thread.join(timeout=0.5)
583
+
584
+ # Terminate all processes immediately without waiting
585
+ for process in self.processes:
586
+ if process.is_alive():
587
+ process.terminate()
588
+
589
+ # Give processes a brief chance to terminate gracefully (reduced timeout)
590
+ for process in self.processes:
591
+ process.join(timeout=0.5)
592
+
593
+ # Force kill any remaining processes
594
+ for process in self.processes:
595
+ if process.is_alive():
596
+ logger.warning(f"Force killing process {process.pid}")
597
+ process.kill()
598
+ process.join(timeout=0.1) # Brief wait to confirm kill
599
+
600
+ # Clear all queues
601
+ self._clear_queues()
602
+
603
+ # Reset state
604
+ self.processes.clear()
605
+ self.request_queues.clear()
606
+ with self._response_lock:
607
+ self._response_events.clear()
608
+
609
+ def call(
610
+ self,
611
+ idx,
612
+ method_name,
613
+ params,
614
+ deployed_as_of,
615
+ request_id,
616
+ distributed_env_vars,
617
+ debug_port,
618
+ serialization,
619
+ ):
620
+ """Call a specific process by index."""
621
+ if idx >= len(self.processes):
622
+ raise ValueError(f"Process index {idx} out of range (have {len(self.processes)} processes)")
623
+
624
+ request_unique_id = str(uuid.uuid4())
625
+
626
+ # Register this request for response routing
627
+ event = threading.Event()
628
+ with self._response_lock:
629
+ self._response_events[request_unique_id] = (event, None)
630
+
631
+ try:
632
+ # Send request to specific process
633
+ self.request_queues[idx].put(
634
+ {
635
+ "request_unique_id": request_unique_id,
636
+ "method_name": method_name,
637
+ "params": params,
638
+ "deployed_as_of": deployed_as_of,
639
+ "request_id": request_id,
640
+ "distributed_env_vars": distributed_env_vars,
641
+ "debug_port": debug_port,
642
+ "serialization": serialization,
643
+ "process_idx": idx, # Include process index for debugging
644
+ }
645
+ )
646
+
647
+ # Wait for response
648
+ event.wait()
649
+
650
+ # Get and return the response
651
+ with self._response_lock:
652
+ _, result = self._response_events.pop(request_unique_id)
653
+
654
+ return result
655
+
656
+ except Exception:
657
+ # Clean up on error
658
+ with self._response_lock:
659
+ self._response_events.pop(request_unique_id, None)
660
+ raise
661
+
662
+ def call_all(
663
+ self,
664
+ method_name,
665
+ params_list,
666
+ deployed_as_of,
667
+ request_id,
668
+ distributed_env_vars_list,
669
+ debug_ports,
670
+ serialization,
671
+ ):
672
+ """Call all processes in parallel and return results."""
673
+ if len(params_list) != self.num_processes:
674
+ raise ValueError(f"Expected {self.num_processes} param sets, got {len(params_list)}")
675
+
676
+ with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
677
+ futures = []
678
+ for idx in range(self.num_processes):
679
+ future = executor.submit(
680
+ self.call,
681
+ idx=idx,
682
+ method_name=method_name,
683
+ params=params_list[idx],
684
+ deployed_as_of=deployed_as_of,
685
+ request_id=request_id,
686
+ distributed_env_vars=distributed_env_vars_list[idx],
687
+ debug_port=debug_ports[idx] if debug_ports else None,
688
+ serialization=serialization,
689
+ )
690
+ futures.append(future)
691
+
692
+ results = []
693
+ for future in futures:
694
+ results.append(future.result())
695
+
696
+ return results
697
+
698
+ def _response_router(self):
699
+ """Single router thread handling responses from all processes."""
700
+ while self._running:
701
+ try:
702
+ response = self.response_queue.get(timeout=1)
703
+ if response == "STOP_ROUTER":
704
+ break
705
+
706
+ request_id = response.get("request_unique_id")
707
+ with self._response_lock:
708
+ if request_id in self._response_events:
709
+ event, _ = self._response_events[request_id]
710
+ self._response_events[request_id] = (event, response["result"])
711
+ event.set()
712
+ else:
713
+ logger.warning(f"Received response for unknown request: {request_id}")
714
+
715
+ except Exception as e:
716
+ if "Empty" not in str(e.__class__.__name__):
717
+ logger.debug(f"Response router error: {e}")
718
+ continue
719
+
720
+ def _clear_queues(self):
721
+ """Clear all pending items from queues."""
722
+ for q in self.request_queues:
723
+ try:
724
+ while not q.empty():
725
+ q.get_nowait()
726
+ except Exception:
727
+ pass
728
+
729
+ try:
730
+ while not self.response_queue.empty():
731
+ self.response_queue.get_nowait()
732
+ except Exception:
733
+ pass
734
+
735
+ def __len__(self):
736
+ return len(self.processes)
737
+
738
+ def __getitem__(self, idx):
739
+ return self.processes[idx]
740
+
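The queue layout used here, one request queue per process plus a single shared response queue drained by one router thread, can be sketched in a few lines of plain multiprocessing (illustrative only; the real pool additionally routes responses back by unique request id as shown above):

import multiprocessing as mp
import uuid

def worker(req_q, resp_q):
    # Each process consumes its own request queue and writes to the shared response queue
    while True:
        msg = req_q.get()
        if msg == "SHUTDOWN":
            break
        resp_q.put({"id": msg["id"], "result": msg["x"] ** 2})

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    resp_q = mp.Queue()
    req_qs = [mp.Queue() for _ in range(2)]
    procs = [mp.Process(target=worker, args=(q, resp_q), daemon=True) for q in req_qs]
    for p in procs:
        p.start()
    for i, q in enumerate(req_qs):
        q.put({"id": str(uuid.uuid4()), "x": i + 2})
    print(sorted(resp_q.get()["result"] for _ in procs))  # -> [4, 9]
    for q in req_qs:
        q.put("SHUTDOWN")
    for p in procs:
        p.join()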
741
+
742
+ class DistributedSupervisor:
743
+ def __init__(self, quorum_workers=None, quorum_timeout=300, monitor_members=True):
744
+ """
745
+ Base class for distributed supervisors. This class should be subclassed for specific distributed
746
+ environments like PyTorch or Ray.
747
+
748
+ Args:
+ quorum_workers: Optional number of workers that must appear in DNS before a call proceeds.
+ quorum_timeout: Seconds to wait for that quorum (default 300).
+ monitor_members: Whether to monitor DNS for worker membership changes during a call.
+ """
751
+ # Set after creation by the factory function
752
+ self.quorum_workers = quorum_workers
753
+ self.quorum_timeout = quorum_timeout
754
+ self.monitor_members = monitor_members
755
+
756
+ self.config_hash = None
757
+
758
+ # DNS monitoring state
759
+ self._dns_monitor_thread = None
760
+ self._dns_monitor_running = False
761
+ self._current_workers = set()
762
+ self._workers_lock = threading.Lock()
763
+ self._membership_changes = queue.Queue()
764
+ self._change_subscribers = []
765
+ self._last_dns_check = 0
766
+ self._dns_check_interval = 5 # seconds
767
+
768
+ def pod_ips(self):
769
+ """Get pod IPs from DNS, waiting for quorum if specified.
770
+
771
+ Will wait up to quorum_timeout seconds for quorum_workers to appear in DNS.
772
+ If quorum_workers is not specified, returns immediately after first DNS query.
773
+ """
774
+ # Primarily for testing
775
+ if not is_running_in_kubernetes():
776
+ return os.environ["LOCAL_IPS"].split(",")
777
+
778
+ # Use DNS-based service discovery instead of Kubernetes API
779
+ # Check if pre-computed DNS name is available (should point to headless service for distributed)
780
+ service_dns = os.environ.get("KT_SERVICE_DNS")
781
+
782
+ if not service_dns:
783
+ # Fall back to computing DNS name from service and namespace
784
+ service_name = os.environ.get("KT_SERVICE_NAME")
785
+ namespace = os.environ.get("POD_NAMESPACE")
786
+
787
+ if not service_name:
788
+ raise RuntimeError("KT_SERVICE environment variable not found")
789
+ if not namespace:
790
+ raise RuntimeError("POD_NAMESPACE environment variable not found")
791
+
792
+ # Kubernetes headless service DNS name for distributed pod discovery
793
+ # Format: <service-name>-headless.<namespace>.svc.cluster.local
794
+ service_dns = f"{service_name}-headless.{namespace}.svc.cluster.local"
795
+
796
+ import socket
797
+ import time
798
+
799
+ start_time = time.time()
800
+ max_wait = self.quorum_timeout if self.quorum_timeout else 0
801
+ expected_workers = self.quorum_workers
802
+
803
+ pod_ips = []
804
+ last_count = 0
805
+
806
+ while True:
807
+ try:
808
+ # DNS lookup returns all pod IPs for the headless service
809
+ # getaddrinfo returns list of (family, type, proto, canonname, sockaddr)
810
+ addr_info = socket.getaddrinfo(service_dns, None, socket.AF_INET)
811
+
812
+ # Extract unique IP addresses from the results
813
+ pod_ips = sorted(list(set([addr[4][0] for addr in addr_info])))
814
+
815
+ if not pod_ips:
816
+ logger.debug(f"No pod IPs found for service {service_dns}")
817
+ else:
818
+ logger.debug(f"Found {len(pod_ips)} pod IPs via DNS for {service_dns}: {pod_ips}")
819
+
820
+ except socket.gaierror as e:
821
+ logger.debug(f"DNS lookup failed for {service_dns}: {e}")
822
+ pod_ips = []
823
+
824
+ # Check if we should wait for more workers
825
+ elapsed = time.time() - start_time
826
+
827
+ # If we have the expected count, we're done
828
+ if expected_workers and len(pod_ips) >= expected_workers:
829
+ logger.info(f"Found {len(pod_ips)}/{expected_workers} workers after {elapsed:.1f}s")
830
+ return pod_ips
831
+
832
+ # If we don't have expected count or timeout is reached, decide what to do
833
+ if elapsed >= max_wait:
834
+ if expected_workers:
835
+ logger.warning(f"Only found {len(pod_ips)}/{expected_workers} workers after {elapsed:.1f}s timeout")
836
+ else:
837
+ logger.info(f"Found {len(pod_ips)} workers after {elapsed:.1f}s")
838
+ return pod_ips
839
+
840
+ # Log progress if count changed
841
+ if len(pod_ips) != last_count:
842
+ if expected_workers:
843
+ logger.info(f"{len(pod_ips)}/{expected_workers} workers found, waiting for quorum...")
844
+ else:
845
+ logger.debug(f"{len(pod_ips)} workers found, no quorum set")
846
+ last_count = len(pod_ips)
847
+
848
+ # Wait before retrying
849
+ time.sleep(2)
850
+
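For example, assuming a service named my-train in namespace ml (hypothetical names), the headless-service lookup above resolves my-train-headless.ml.svc.cluster.local to one address per backing pod; a minimal sketch of that query, which only resolves inside a cluster:

import socket

service_dns = "my-train-headless.ml.svc.cluster.local"  # hypothetical service/namespace
try:
    infos = socket.getaddrinfo(service_dns, None, socket.AF_INET)
    print(sorted({info[4][0] for info in infos}))  # one IP per pod behind the headless service
except socket.gaierror:
    print("name does not resolve outside the cluster")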
851
+ def _get_pod_ips_fast(self):
852
+ """Get pod IPs from DNS without waiting for quorum - for monitoring only."""
853
+ # Primarily for testing
854
+ if not is_running_in_kubernetes():
855
+ return os.environ["LOCAL_IPS"].split(",")
856
+
857
+ # Use DNS-based service discovery
858
+ service_dns = os.environ.get("KT_SERVICE_DNS")
859
+
860
+ if not service_dns:
861
+ service_name = os.environ.get("KT_SERVICE")
862
+ namespace = os.environ.get("POD_NAMESPACE")
863
+
864
+ if not service_name or not namespace:
865
+ return []
866
+
867
+ service_dns = f"{service_name}-headless.{namespace}.svc.cluster.local"
868
+
869
+ import socket
870
+
871
+ try:
872
+ # Single DNS lookup, no retries, no waiting
873
+ addr_info = socket.getaddrinfo(service_dns, None, socket.AF_INET)
874
+ # Extract unique IP addresses
875
+ pod_ips = sorted(list(set([addr[4][0] for addr in addr_info])))
876
+ return pod_ips
877
+ except socket.gaierror:
878
+ # DNS lookup failed, return current known workers
879
+ with self._workers_lock:
880
+ return list(self._current_workers)
881
+
882
+ def start_dns_monitoring(self):
883
+ """Start DNS monitoring if not already running.
884
+ Should be called by coordinator nodes only."""
885
+ # Skip if monitoring is disabled (e.g., for Ray)
886
+ if not self.monitor_members:
887
+ logger.debug("DNS monitoring disabled for this supervisor")
888
+ return
889
+
890
+ with self._workers_lock:
891
+ if self._dns_monitor_thread and self._dns_monitor_thread.is_alive():
892
+ return # Already running
893
+
894
+ # Initialize with current workers
895
+ self._current_workers = set(self.pod_ips())
896
+ logger.debug(f"Starting DNS monitor with {len(self._current_workers)} workers")
897
+
898
+ self._dns_monitor_running = True
899
+ self._dns_monitor_thread = threading.Thread(
900
+ target=self._monitor_worker_membership, daemon=True, name="DNSMonitor"
901
+ )
902
+ self._dns_monitor_thread.start()
903
+
904
+ def stop_dns_monitoring(self):
905
+ """Stop DNS monitoring thread."""
906
+ self._dns_monitor_running = False
907
+ if self._dns_monitor_thread:
908
+ self._dns_monitor_thread.join(timeout=2)
909
+ self._dns_monitor_thread = None
910
+
911
+ def _monitor_worker_membership(self):
912
+ """Monitor DNS for worker membership changes."""
913
+ check_interval = 3 # Start with 3 second checks (faster initial detection)
914
+
915
+ while self._dns_monitor_running:
916
+ try:
917
+ # Note that we start this after the delay, because we're doing a DNS check at
918
+ # the start of call_distributed anyway. This thread is only for the recurring checks
919
+ # as the call runs.
920
+ time.sleep(check_interval)
921
+
922
+ # Query DNS for current workers - use a faster version
923
+ current_ips = set(self._get_pod_ips_fast())
924
+
925
+ with self._workers_lock:
926
+ if current_ips != self._current_workers:
927
+ added = current_ips - self._current_workers
928
+ removed = self._current_workers - current_ips
929
+
930
+ change = {
931
+ "timestamp": time.time(),
932
+ "added": added,
933
+ "removed": removed,
934
+ "previous": self._current_workers.copy(),
935
+ "current": current_ips.copy(),
936
+ }
937
+
938
+ if removed:
939
+ logger.error(f"Workers REMOVED from cluster: {removed}")
940
+ if added:
941
+ logger.warning(f"Workers ADDED to cluster: {added}")
942
+
943
+ # Queue change and notify subscribers
944
+ self._membership_changes.put(change)
945
+ for event in self._change_subscribers:
946
+ event.set()
947
+
948
+ self._current_workers = current_ips
949
+
950
+ time.sleep(check_interval)
951
+
952
+ except Exception as e:
953
+ logger.error(f"DNS monitor error: {e}")
954
+ time.sleep(3)
955
+
956
+ def subscribe_to_membership_changes(self):
957
+ """Subscribe to worker membership changes.
958
+ Returns an event that will be set when changes occur."""
959
+ event = threading.Event()
960
+ with self._workers_lock:
961
+ self._change_subscribers.append(event)
962
+ return event
963
+
964
+ def unsubscribe_from_membership_changes(self, event):
965
+ """Unsubscribe from worker membership changes."""
966
+ with self._workers_lock:
967
+ if event in self._change_subscribers:
968
+ self._change_subscribers.remove(event)
969
+
970
+ def check_for_membership_changes(self, force_dns_check=False):
971
+ """Check for membership changes and raise exception if any occurred.
972
+
973
+ Args:
974
+ force_dns_check: If True, immediately query DNS to check for changes
975
+ instead of relying on the monitoring thread
976
+ """
977
+ # Skip if monitoring is disabled (e.g., for Ray)
978
+ if not self.monitor_members:
979
+ return
980
+ # Force an immediate DNS check if requested
981
+ if force_dns_check:
982
+ # Use fast DNS query for immediate check
983
+ current_ips = set(self._get_pod_ips_fast())
984
+ with self._workers_lock:
985
+ if current_ips != self._current_workers:
986
+ added = current_ips - self._current_workers
987
+ removed = self._current_workers - current_ips
988
+
989
+ # Import here to avoid circular dependency
990
+ from kubetorch.servers.http.utils import WorkerMembershipChanged
991
+
992
+ # Update current workers
993
+ self._current_workers = current_ips
994
+
995
+ # Log the change
996
+ if removed:
997
+ logger.error(f"Workers REMOVED from cluster (forced check): {removed}")
998
+ if added:
999
+ logger.warning(f"Workers ADDED to cluster (forced check): {added}")
1000
+
1001
+ raise WorkerMembershipChanged(
1002
+ added_ips=added,
1003
+ removed_ips=removed,
1004
+ previous_ips=self._current_workers.copy(),
1005
+ current_ips=current_ips,
1006
+ )
1007
+
1008
+ # Check queued changes from monitoring thread
1009
+ try:
1010
+ change = self._membership_changes.get_nowait()
1011
+
1012
+ # Import here to avoid circular dependency
1013
+ from kubetorch.servers.http.utils import WorkerMembershipChanged
1014
+
1015
+ raise WorkerMembershipChanged(
1016
+ added_ips=change["added"],
1017
+ removed_ips=change["removed"],
1018
+ previous_ips=change["previous"],
1019
+ current_ips=change["current"],
1020
+ )
1021
+ except queue.Empty:
1022
+ pass # No changes
1023
+
1024
+ def setup(self, deployed_as_of: Optional[str] = None):
1025
+ # This method should be overridden by subclasses to set up the distributed environment
1026
+ raise NotImplementedError("setup() must be implemented by subclasses")
1027
+
1028
+ def cleanup(self):
1029
+ """Base cleanup - stop DNS monitoring. Subclasses should call super().cleanup()"""
1030
+ self.stop_dns_monitoring()
1031
+ # Subclasses should override and call super().cleanup() to add their own cleanup
1032
+
1033
+ def intercept_call(self):
1034
+ # This method should be overridden by subclasses to indicate whether to intercept calls
1035
+ raise NotImplementedError("intercept_call() must be implemented by subclasses")
1036
+
1037
+ def call_distributed(
1038
+ self,
1039
+ request,
1040
+ cls_or_fn_name: str,
1041
+ method_name: Optional[str] = None,
1042
+ params: Optional[Dict] = None,
1043
+ distributed_subcall: bool = False,
1044
+ debug_port: int = False,
1045
+ deployed_as_of: Optional[str] = None,
1046
+ ):
1047
+ # if intercept_call is True, this method should be overridden by subclasses to handle distributing and/or
1048
+ # supervising the distributed execution
1049
+ raise NotImplementedError("call_distributed() must be implemented by subclasses")
1050
+
1051
+
1052
+ class DistributedProcess(multiprocessing.Process):
1053
+ """Base class for distributed processes that run callables in subprocesses."""
1054
+
1055
+ def __init__(self, local_rank, request_queue, response_queue, max_threads=4, **kwargs):
1056
+ super().__init__()
1057
+ # We don't need the cache miss / reload here because these processes are destroyed and recreated
1058
+ # with each .to call.
1059
+ os.environ["LOCAL_RANK"] = str(local_rank)
1060
+ self._request_queue = request_queue
1061
+ self._response_queue = response_queue
1062
+ self._max_threads = max_threads
1063
+ self._executor = None
1064
+ # Store any additional framework-specific settings
1065
+ self._settings = kwargs
1066
+
1067
+ def proc_cleanup(self):
1068
+ """Override this method to provide framework-specific cleanup."""
1069
+ logger.info("Cleaning up debugging sessions...")
1070
+ clear_debugging_sessions()
1071
+ logger.info("Debugging sessions cleaned up.")
1072
+
1073
+ # Cleanup thread pool
1074
+ if self._executor:
1075
+ self._executor.shutdown(wait=False)
1076
+ self._executor = None
1077
+
1078
+ @classmethod
1079
+ def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
1080
+ """Get framework-specific distributed environment variables.
1081
+
1082
+ Args:
1083
+ worker_ips: List of all worker IPs
1084
+ node_rank: Rank of this node (0-indexed)
1085
+ local_rank: Local rank on this node (0-indexed)
1086
+ num_local_procs: Number of processes on this node
1087
+ **settings: Additional framework-specific settings (e.g., port)
1088
+
1089
+ Returns:
1090
+ Dict of environment variables to set
1091
+ """
1092
+ # Base implementation - no special env vars needed
1093
+ return {
1094
+ "WORLD_SIZE": str(len(worker_ips) * num_local_procs),
1095
+ "RANK": str(node_rank * num_local_procs + local_rank),
1096
+ "LOCAL_RANK": str(local_rank),
1097
+ "NODE_RANK": str(node_rank),
1098
+ "POD_IPS": ",".join(worker_ips),
1099
+ }
1100
+
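To make the arithmetic concrete: with 2 worker IPs and 4 local processes per node, WORLD_SIZE is 8, and the process with node_rank=1, local_rank=2 receives RANK 6. A quick check of the same formula:

def rank(node_rank, local_rank, procs_per_node):
    return node_rank * procs_per_node + local_rank

world_size = 2 * 4                               # 2 nodes x 4 local processes
print(world_size, rank(0, 0, 4), rank(1, 2, 4))  # -> 8 0 6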
1101
+ @classmethod
1102
+ def get_auto_num_processes(cls):
1103
+ """Auto-detect the number of processes to use."""
1104
+ return 1
1105
+
1106
+ def handle_request(self, request):
1107
+ """Handle a single request in a thread."""
1108
+ try:
1109
+ request_unique_id = request["request_unique_id"]
1110
+ method_name = request["method_name"]
1111
+ params = request["params"]
1112
+ deployed_as_of = request["deployed_as_of"]
1113
+ request_id = request["request_id"]
1114
+ distributed_env_vars = request["distributed_env_vars"]
1115
+ debug_port = request["debug_port"]
1116
+ serialization = request["serialization"]
1117
+
1118
+ # Set the request ID in the context for this thread
1119
+ token = request_id_ctx_var.set(request_id)
1120
+
1121
+ # Set the environment variables for this thread (note: os.environ is process-wide, might need thread-local storage)
1122
+ # For distributed PyTorch calls, these should already be set at process level
1123
+ for key, value in distributed_env_vars.items():
1124
+ os.environ[key] = value
1125
+
1126
+ try:
1127
+ # Load callable if not already loaded or if deployed_as_of changed
1128
+ callable_obj = load_callable(
1129
+ deployed_as_of=deployed_as_of,
1130
+ distributed_subprocess=True,
1131
+ reload_cleanup_fn=self.proc_cleanup,
1132
+ )
1133
+
1134
+ result = run_callable_internal_sync(
1135
+ callable_obj=callable_obj,
1136
+ cls_or_fn_name=os.environ["KT_CLS_OR_FN_NAME"],
1137
+ method_name=method_name,
1138
+ params=params,
1139
+ serialization=serialization,
1140
+ debug_port=debug_port,
1141
+ )
1142
+
1143
+ # Reset the request ID after the call is complete
1144
+ request_id_ctx_var.reset(token)
1145
+
1146
+ # Send response back with the unique ID
1147
+ self._response_queue.put({"request_unique_id": request_unique_id, "result": result})
1148
+
1149
+ except Exception as e:
1150
+ # Reset the request ID even if there was an error
1151
+ request_id_ctx_var.reset(token)
1152
+
1153
+ # Package the exception
1154
+ try:
1155
+ packaged_exception = package_exception(e)
1156
+ except Exception as f:
1157
+ packaged_exception = f
1158
+
1159
+ self._response_queue.put(
1160
+ {
1161
+ "request_unique_id": request_unique_id,
1162
+ "result": packaged_exception,
1163
+ }
1164
+ )
1165
+
1166
+ except Exception as e:
1167
+ # Last resort error handling
1168
+ logger.error(f"Fatal error handling request: {e}")
1169
+ self._response_queue.put(
1170
+ {
1171
+ "request_unique_id": request.get("request_unique_id", "unknown"),
1172
+ "result": Exception(f"Fatal error in thread: {e}"),
1173
+ }
1174
+ )
1175
+
1176
+ def run(self):
1177
+ """Main process loop with thread pool for concurrent request handling."""
1178
+ # Create thread pool for handling requests
1179
+ self._executor = ThreadPoolExecutor(max_workers=self._max_threads)
1180
+
1181
+ try:
1182
+ while True:
1183
+ try:
1184
+ # Block waiting for next request
1185
+ request = self._request_queue.get(timeout=1)
1186
+
1187
+ # Special sentinel value to signal shutdown
1188
+ if request == "SHUTDOWN":
1189
+ break
1190
+
1191
+ # Submit request to thread pool for concurrent handling
1192
+ # Check executor exists in case we're shutting down
1193
+ if self._executor:
1194
+ self._executor.submit(self.handle_request, request)
1195
+ else:
1196
+ logger.warning("Executor is None, skipping request (likely shutting down)")
1197
+
1198
+ except Exception as e:
1199
+ # Timeout is normal, continue loop
1200
+ if "Empty" not in str(e.__class__.__name__):
1201
+ logger.error(f"Error getting request from queue: {e}")
1202
+ continue
1203
+
1204
+ except (KeyboardInterrupt, BdbQuit):
1205
+ logger.info("Process interrupted, shutting down...")
1206
+
1207
+ finally:
1208
+ # Cleanup
1209
+ logger.info("Received shutdown signal, cleaning up distributed environment...")
1210
+ self.proc_cleanup()
1211
+ logger.info("Exiting gracefully.")
1212
+
1213
+
1214
+ class PyTorchProcess(DistributedProcess):
1215
+ """PyTorch-specific distributed process."""
1216
+
1217
+ def proc_cleanup(self):
1218
+ import torch.distributed as dist
1219
+
1220
+ try:
1221
+ dist.destroy_process_group()
1222
+ logger.info("Destroyed PyTorch process group.")
1223
+ except Exception as e:
+ logger.info(f"Failed to destroy PyTorch process group, it may not have been initialized: {e}")
1226
+ # Call parent cleanup for debugging sessions
1227
+ super().proc_cleanup()
1228
+
1229
+ @classmethod
1230
+ def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
1231
+ """Get PyTorch-specific distributed environment variables."""
1232
+ port = settings.get("port") or 12345
1233
+ env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
1234
+ env_vars.update(
1235
+ {
1236
+ "MASTER_ADDR": worker_ips[0],
1237
+ "MASTER_PORT": str(port),
1238
+ }
1239
+ )
1240
+ return env_vars
1241
+
1242
+ @classmethod
1243
+ def get_auto_num_processes(cls):
1244
+ """Auto-detect based on GPU availability for PyTorch."""
1245
+ try:
1246
+ import torch
1247
+
1248
+ if torch.cuda.is_available():
1249
+ return torch.cuda.device_count()
1250
+ except ImportError:
1251
+ pass
1252
+ return 1 # Could use os.cpu_count() for CPU-only training
1253
+
1254
+
1255
+ class RayProcess(DistributedProcess):
1256
+ """Ray-specific distributed process."""
1257
+
1258
+ def proc_cleanup(self):
1259
+ try:
1260
+ import ray
1261
+
1262
+ if ray.is_initialized():
1263
+ ray.shutdown()
1264
+ logger.info("Ray shutdown completed.")
1265
+ except ImportError:
1266
+ logger.info("Ray not available for cleanup")
1267
+ except Exception as e:
1268
+ logger.info(f"Failed to shutdown Ray: {e}")
1269
+ # Call parent cleanup for debugging sessions
1270
+ super().proc_cleanup()
1271
+
1272
+
1273
+ class SPMDDistributedSupervisor(DistributedSupervisor):
1274
+ """Base class for SPMD (Single Program Multiple Data) distributed supervisors.
1275
+
1276
+ This class provides common functionality for frameworks that follow the SPMD pattern
1277
+ where the same program runs on multiple processes with different data partitions.
1278
+ """
1279
+
1280
+ def __init__(
1281
+ self,
1282
+ process_class=None,
1283
+ num_proc=None,
1284
+ port=None,
1285
+ restart_procs=True,
1286
+ max_threads_per_proc=10,
1287
+ quorum_timeout=300,
1288
+ quorum_workers=None,
1289
+ monitor_members=True,
1290
+ tree_fanout=50,
1291
+ tree_minimum=100,
1292
+ **process_kwargs,
1293
+ ):
1294
+ super().__init__(
1295
+ quorum_workers=quorum_workers,
1296
+ quorum_timeout=quorum_timeout,
1297
+ monitor_members=monitor_members,
1298
+ )
1299
+ self.process_class = process_class or DistributedProcess
1300
+ self.num_proc = num_proc or "auto"
1301
+ self.port = port
1302
+ self.restart_procs = restart_procs
1303
+ self.max_threads_per_proc = max_threads_per_proc
1304
+ self.process_pool = None
1305
+ self.remote_worker_pool = None # Pool for async HTTP calls to remote workers
1306
+ self.process_kwargs = process_kwargs # Additional settings to pass to process class
1307
+ self.tree_fanout = tree_fanout
1308
+ self.tree_minimum = tree_minimum
1309
+
1310
+ def setup(self, deployed_as_of: Optional[str] = None):
1311
+ # Set multiprocessing to spawn if not already
1312
+ if multiprocessing.get_start_method() != "spawn":
1313
+ multiprocessing.set_start_method("spawn", force=True)
1314
+
1315
+ # Get number of processes
1316
+ if self.num_proc == "auto":
1317
+ num_proc = self.process_class.get_auto_num_processes()
1318
+ else:
1319
+ num_proc = self.num_proc
1320
+
1321
+ if self.restart_procs:
1322
+ logger.debug("restart_procs is True, restarting distributed processes")
1323
+ self.cleanup()
1324
+
1325
+ # If the number of processes has changed, we need to clean up the old ones and recreate them
1326
+ if self.process_pool is None or len(self.process_pool) != num_proc:
1327
+ if self.process_pool:
1328
+ logger.debug(
1329
+ f"Number of processes changed from {len(self.process_pool)} to {num_proc}, restarting processes."
1330
+ )
1331
+ self.cleanup()
1332
+
1333
+ logger.debug("Setting up distributed environment")
1334
+ self.process_pool = DistributedProcessPool(
1335
+ process_class=self.process_class,
1336
+ num_processes=num_proc,
1337
+ max_threads_per_proc=self.max_threads_per_proc,
1338
+ **self.process_kwargs, # Pass any additional settings
1339
+ )
1340
+
1341
+ # Start all processes (now handled internally by the pool)
1342
+ self.process_pool.start()
1343
+
1344
+ self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
1345
+ self.remote_worker_pool.start()
1346
+ logger.debug("Finished setting up distributed processes")
1347
+
1348
+ def cleanup(self):
1349
+ # Cleanup the processes
1350
+ logger.debug(f"Cleaning up {self.__class__.__name__} distributed processes")
1351
+
1352
+ # Stop DNS monitoring first
1353
+ super().cleanup()
1354
+
1355
+ if self.process_pool:
1356
+ self.process_pool.stop()
1357
+ self.process_pool = None
1358
+
1359
+ if self.remote_worker_pool:
1360
+ self.remote_worker_pool.stop()
1361
+ self.remote_worker_pool = None
1362
+
1363
+ logger.debug(f"Finished cleaning up {self.__class__.__name__} distributed processes")
1364
+
1365
+ @staticmethod
1366
+ def intercept_call():
1367
+ return True
1368
+
1369
+ def get_tree_children(self, sorted_ips: list, my_ip: str, fanout: int = 100):
1370
+ """Calculate children nodes in a self-organizing tree based on IP indexing.
1371
+
1372
+ Args:
1373
+ sorted_ips: List of all worker IPs sorted deterministically
1374
+ my_ip: This node's IP address
1375
+ fanout: Maximum number of children per node (default 100)
1376
+
1377
+ Returns:
1378
+ List of IP addresses that are children of this node
1379
+ """
1380
+ try:
1381
+ my_index = sorted_ips.index(my_ip)
1382
+ except ValueError:
1383
+ # If not found in list, this node has no children
1384
+ return []
1385
+
1386
+ # Calculate the range of children indices
1387
+ # In a tree with fanout F, node at index i has children at indices:
1388
+ # [i*F + 1, i*F + 2, ..., i*F + F]
1389
+ first_child_idx = my_index * fanout + 1
1390
+ last_child_idx = min(first_child_idx + fanout, len(sorted_ips))
1391
+
1392
+ if first_child_idx >= len(sorted_ips):
1393
+ # No children (leaf node)
1394
+ return []
1395
+
1396
+ children = sorted_ips[first_child_idx:last_child_idx]
1397
+ if len(children) > 0:
1398
+ logger.debug(
1399
+ f"Tree topology: Node {my_ip} (index {my_index}) has {len(children)} children "
1400
+ f"(indices {first_child_idx}-{last_child_idx-1})"
1401
+ )
1402
+ return children
1403
+
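A worked example of this indexing, assuming 1,000 sorted worker IPs and the default tree_fanout of 50: the coordinator at index 0 calls indices 1-50, the node at index 1 calls 51-100, and any node at index 20 or above is a leaf, since 20 * 50 + 1 = 1001 falls past the end of the list, so every worker is reached within two hops of the coordinator. The same rule in a few lines:

def children(index, n_workers=1000, fanout=50):
    first = index * fanout + 1
    return list(range(first, min(first + fanout, n_workers)))

print(len(children(0)), len(children(19)), children(20))  # -> 50 49 []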
1404
+ def call_distributed(
1405
+ self,
1406
+ request,
1407
+ cls_or_fn_name: str,
1408
+ method_name: Optional[str] = None,
1409
+ params: Optional[Dict] = None,
1410
+ distributed_subcall: bool = False,
1411
+ debug_port: int = False,
1412
+ deployed_as_of: Optional[str] = None,
1413
+ ):
1414
+ # Get the request ID from the headers
1415
+ request_id = request.headers.get("X-Request-ID", "-")
1416
+ serialization = request.headers.get("X-Serialization", "json")
1417
+ params = params or {}
1418
+
1419
+ # If deployed_as_of is None and we're the coordinator, generate a consistent timestamp
1420
+ # to use across all workers to prevent reload inconsistencies
1421
+ if not distributed_subcall and deployed_as_of is None:
1422
+ from datetime import datetime, timezone
1423
+
1424
+ deployed_as_of = datetime.now(timezone.utc).isoformat()
1425
+
1426
+ # Get all the pods in the service, and use the first one as the master.
1427
+ # Set the env vars based on whether this is a master or worker
1428
+ logger.debug(f"Configuring distributed environment, distributed_subcall={distributed_subcall}")
1429
+ this_pod_ip = os.environ["POD_IP"]
1430
+ logger.debug(f"This pod IP: {this_pod_ip}")
1431
+
1432
+ # Start DNS monitoring for coordinator nodes
1433
+ change_event = None
1434
+ if not distributed_subcall:
1435
+ # First wait for quorum before starting monitoring
1436
+ worker_ips = self.pod_ips()
1437
+ # sort the worker IPs to ensure generally consistent tree ordering
1438
+ # (avoids thrashing connections due to reordering)
1439
+ worker_ips.sort()
1440
+ # For tree topology, coordinator is always root (index 0)
1441
+ # Move coordinator to front if not already there
1442
+ if this_pod_ip in worker_ips:
1443
+ worker_ips.remove(this_pod_ip)
1444
+ worker_ips.insert(0, this_pod_ip) # Move coordinator to root of tree
1445
+ logger.debug(f"Acting as COORDINATOR - discovered worker IPs: {worker_ips}")
1446
+ logger.debug(f"Pod IPs: {worker_ips}")
1447
+
1448
+ # Check if this call uses flexible worker selection (workers="ready")
1449
+ # If so, don't start DNS monitoring as the worker set is expected to be flexible
1450
+ workers_arg = params.get("workers") if params else None
1451
+ should_monitor = workers_arg not in ["ready", "any"]
1452
+
1453
+ if should_monitor:
1454
+ # Now that we have quorum, start DNS monitoring
1455
+ # Start monitoring (idempotent - won't start if already running)
1456
+ self.start_dns_monitoring()
1457
+
1458
+ # Subscribe to membership changes
1459
+ change_event = self.subscribe_to_membership_changes()
1460
+
1461
+ # Check for any pending changes after starting monitor
1462
+ self.check_for_membership_changes(force_dns_check=True)
1463
+ else:
1464
+ logger.debug("Skipping DNS monitoring for workers='ready' call")
1465
+
1466
+ # Update distributed env vars to use the tree-ordered IPs
1467
+ distributed_env_vars = {
1468
+ "POD_IPS": ",".join(worker_ips),
1469
+ }
1470
+ else:
1471
+ logger.debug(f"Acting as WORKER (distributed_subcall=True) at {this_pod_ip}")
1472
+ logger.debug(f"Worker received params keys: {list(params.keys()) if params else 'None'}")
1473
+ distributed_env_vars = params.pop("distributed_env_vars", None) if params else None
1474
+ logger.debug(f"Using distributed_env_vars: {distributed_env_vars}")
1475
+ if not distributed_env_vars:
1476
+ logger.error(f"No distributed_env_vars found in params: {params}")
1477
+ raise RuntimeError("distributed_env_vars must be provided for distributed subcalls")
1478
+ worker_ips = distributed_env_vars["POD_IPS"].split(",")
1479
+
1480
+ # Don't debug for subcalls; we only want to debug one process
1481
+ debug_port = None
1482
+
1483
+ # Decide topology based on cluster size
1484
+ subcall_ips = []
1485
+ num_workers = len(worker_ips)
1486
+ tree_mode = num_workers >= self.tree_minimum
1487
+ if tree_mode:
1488
+ # Use tree topology for large clusters
1489
+ if distributed_subcall:
1490
+ logger.debug(
1491
+ f"Using tree topology for {num_workers} workers (> {self.tree_minimum} threshold) "
1492
+ f"with fanout {self.tree_fanout}"
1493
+ )
1494
+
1495
+ # Calculate direct children in the tree
1496
+ subcall_ips = self.get_tree_children(
1497
+ sorted_ips=worker_ips,
1498
+ my_ip=this_pod_ip,
1499
+ fanout=self.tree_fanout,  # Each node can have up to self.tree_fanout children
1500
+ )
1501
+ logger.debug(f"Tree node {this_pod_ip} will call {len(subcall_ips)} direct children")
1502
+ elif not distributed_subcall:
1503
+ # Use the worker IP list as-is for the coordinator node in flat topology
1504
+ # Leave subcall_ips = [] for workers in flat topology
1505
+ subcall_ips = copy.deepcopy(worker_ips)
1506
+ if this_pod_ip in subcall_ips:
1507
+ subcall_ips.remove(this_pod_ip)
1508
+ logger.debug(f"Removed self ({this_pod_ip}) from subcall list, will call: {subcall_ips}")
1509
+ else:
1510
+ # This can happen with headless services where POD_IP might not exactly match DNS results
1511
+ # Try to match by partial IP or hostname
1512
+ logger.warning(
1513
+ f"This pod IP {this_pod_ip} not found in DNS-discovered IPs {worker_ips}. "
1514
+ f"This may indicate DNS propagation delay or hostname/IP mismatch."
1515
+ )
1516
+ # Still proceed as coordinator but will call all discovered pods
1517
+ logger.debug(f"Will call all discovered workers: {subcall_ips}")
1518
+
1519
+ # Always pass the distributed environment variables to workers
1520
+ params["distributed_env_vars"] = distributed_env_vars
1521
+
1522
+ # "workers" is passed through as a regular kwarg, not a special extracted one like pdb or serialization
1523
+ # because it only applies to distributed calls, not regular ones.
1524
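# Editor's note, for illustration: params["workers"] may be a list mixing IPs and indices,
# e.g. ["10.0.0.4", 2] targets that IP plus the worker at index 2 of the sorted list;
# "any" runs only on the receiving pod; "ready" calls every worker that answers /health and
# skips the rest; any other non-empty string filters the remote subcall IPs by substring match.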
+ call_local_procs = True
1525
+ workers_arg = params.get("workers", None)
1526
+ if workers_arg:
1527
+ logger.debug(f"Filtering workers by argument: {workers_arg}")
1528
+ if isinstance(workers_arg, list) and workers_arg:
1529
+ # Build a set of IPs to include based on the list items
1530
+ target_ips = set()
1531
+
1532
+ for item in workers_arg:
1533
+ if isinstance(item, str) and "." in item:
1534
+ # It's an IP address
1535
+ if item not in worker_ips:
1536
+ raise ValueError(f"Worker IP '{item}' not found in available workers: {worker_ips}")
1537
+ target_ips.add(item)
1538
+ elif isinstance(item, int) or (isinstance(item, str) and item.isdigit()):
1539
+ # It's an index
1540
+ idx = int(item) if isinstance(item, str) else item
1541
+ if idx < 0 or idx >= len(worker_ips):
1542
+ raise ValueError(f"Worker index {idx} out of range. Valid range: 0-{len(worker_ips)-1}")
1543
+ target_ips.add(worker_ips[idx])
1544
+ else:
1545
+ raise ValueError(
1546
+ f"Invalid worker specification: {item}. Must be an IP address, "
1547
+ f"integer index, or numeric string."
1548
+ )
1549
+
1550
+ # Filter subcall_ips to only those in the target set
1551
+ subcall_ips = [ip for ip in subcall_ips if ip in target_ips]
1552
+
1553
+ # Check if current pod should participate in local processing
1554
+ if this_pod_ip not in target_ips:
1555
+ call_local_procs = False
1556
+
1557
+ elif workers_arg == "any":
1558
+ # Only call one worker (this one)
1559
+ subcall_ips = []
1560
+ elif workers_arg == "ready":
1561
+ # Filter below in call_worker to only those that respond to /health
1562
+ pass
1563
+ elif isinstance(workers_arg, str) and workers_arg:
1564
+ # Filter the subcall_ips to only those matching the workers_arg string
1565
+ subcall_ips = [ip for ip in subcall_ips if workers_arg in ip]
1566
+ logger.debug(f"Subcall IPs after filtering: {subcall_ips}")
1567
+
1568
+ # "restart_procs" is passed through as a regular kwarg, not a special extracted one like pdb or serialization
1569
+ # because it only applies to distributed calls, not regular ones.
1570
+ if params.get("restart_procs", False):
1571
+ logger.info("restart_procs parameter is True, restarting processes")
1572
+ self.cleanup()
1573
+ self.setup(deployed_as_of)
1574
+
1575
+ try:
1576
+ node_rank = worker_ips.index(this_pod_ip)
1577
+ except ValueError:
1578
+ # This pod IP not found in DNS results - may be external service IP
1579
+ # Fall back to using POD_IP directly and assume node_rank based on position
1580
+ logger.warning(f"Pod IP {this_pod_ip} not found in DNS results {worker_ips}. Using fallback logic.")
1581
+ # For now, assume we're the first worker if not found
1582
+ node_rank = 0
1583
+
1584
+ # Call the workers using RemoteWorkerPool for async operations
1585
+ def call_worker(worker_ip):
1586
+ # Keep this function for backward compatibility but it won't be used
1587
+ # when RemoteWorkerPool is available
1588
+ with httpx.Client(timeout=None) as client:
1589
+ port = os.environ["KT_SERVER_PORT"]
1590
+ worker_url = f"http://{worker_ip}:{port}"
1591
+ # First check that the worker is alive; replicas don't finish setup at exactly the same moment
1592
+ # Use quorum_timeout to control how long to wait for workers
1593
+ for i in range(int(self.quorum_timeout)):
1594
+ try:
1595
+ resp = client.get(f"{worker_url}/health")
1596
+ if resp.status_code == 200:
1597
+ break
1598
+ except httpx.RequestError:
1599
+ if workers_arg == "ready":
1600
+ logger.debug(f"Worker {worker_ip} not ready, skipping as per 'ready' workers argument")
1601
+ return None
1602
+ time.sleep(1)
1603
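# Editor's note: the else branch below belongs to the for loop, not the try; it runs only when
# the health-check loop exhausts all attempts without hitting break.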
+ else:
1604
+ # Timeout reached without successful health check
1605
+ logger.warning(f"Worker {worker_ip} failed to respond after {self.quorum_timeout}s timeout")
1606
+ if workers_arg != "ready":
1607
+ raise TimeoutError(
1608
+ f"Worker {worker_ip} did not become ready within {self.quorum_timeout} seconds. "
1609
+ "This may indicate the pod is still starting or there's a resource constraint. "
1610
+ "Consider increasing quorum_timeout in .distribute() call."
1611
+ )
1612
+
1613
+ call_url = (
1614
+ f"{worker_url}/{cls_or_fn_name}/{method_name}?distributed_subcall=true"
1615
+ if method_name is not None
1616
+ else f"{worker_url}/{cls_or_fn_name}?distributed_subcall=true"
1617
+ )
1618
+
1619
+ # Clean headers to avoid potential Content-Length issues
1620
+ clean_headers = {}
1621
+ if request.headers:
1622
+ for key, value in request.headers.items():
1623
+ # Skip headers that could interfere with httpx's automatic handling
1624
+ if key.lower() not in [
1625
+ "content-length",
1626
+ "transfer-encoding",
1627
+ "connection",
1628
+ ]:
1629
+ clean_headers[key] = value
1630
+
1631
+ try:
1632
+ logger.debug(f"Making distributed call to {worker_url}")
1633
+ resp = client.post(
1634
+ url=call_url,
1635
+ json=params,
1636
+ headers=clean_headers, # Includes deployed_as_of and request_id
1637
+ )
1638
+ return resp
1639
+ except (httpx.RequestError, httpx.HTTPError) as e:
1640
+ logger.error(f"Failed to call worker {worker_url}: {e}")
1641
+ raise
1642
+
1643
+ # Prepare per-process parameters
1644
+ num_procs = len(self.process_pool)
1645
+ params_list = [params] * num_procs
1646
+ distributed_env_vars_list = []
1647
+ debug_ports = []
1648
+
1649
+ for idx in range(num_procs):
1650
+ # Get framework-specific env vars from the process class
1651
+ env_vars = self.process_class.get_distributed_env_vars(
1652
+ worker_ips=worker_ips
1653
+ if "worker_ips" in locals()
1654
+ else distributed_env_vars.get("POD_IPS", "").split(","),
1655
+ node_rank=node_rank,
1656
+ local_rank=idx,
1657
+ num_local_procs=num_procs,
1658
+ port=self.port,
1659
+ )
1660
+ # Add any base env vars
1661
+ env_vars.update(distributed_env_vars)
1662
+ distributed_env_vars_list.append(env_vars)
1663
+
1664
+ # Only debug one process and if debug_port is set
1665
+ debug = debug_port and idx == num_procs - 1
1666
+ if debug:
1667
+ time.sleep(0.25)
1668
+ debug_ports.append(debug_port if debug else None)
1669
+
1670
+ # Execute distributed calls in parallel with local processes
1671
+ worker_responses = []
1672
+ worker_exception = None
1673
+ local_exception = None
1674
+
1675
+ # Start both remote and local calls in parallel
1676
+ executor = ThreadPoolExecutor(max_workers=2)
1677
+
1678
+ # Submit remote worker calls if needed
1679
+ worker_future = None
1680
+ if subcall_ips:
1681
+ logger.debug(f"Have {len(subcall_ips)} remote workers to call")
1682
+ if not self.remote_worker_pool:
1683
+ raise RuntimeError("RemoteWorkerPool not initialized. This is required for distributed execution.")
1684
+ logger.debug(f"Using existing RemoteWorkerPool to call {len(subcall_ips)} workers")
1685
+
1686
+ def call_remote_workers():
1687
+ nonlocal worker_exception
1688
+ try:
1689
+ # Prepare headers for remote workers
1690
+ clean_headers = {}
1691
+ if request.headers:
1692
+ for key, value in request.headers.items():
1693
+ if key.lower() not in [
1694
+ "content-length",
1695
+ "transfer-encoding",
1696
+ "connection",
1697
+ ]:
1698
+ clean_headers[key] = value
1699
+ # Always include deployed_as_of in headers for consistency
1700
+ if deployed_as_of:
1701
+ clean_headers["X-Deployed-As-Of"] = deployed_as_of
1702
+
1703
+ # Call remote workers asynchronously through the pool
1704
+ logger.debug(f"Calling {len(subcall_ips)} remote workers via RemoteWorkerPool: {subcall_ips}")
1705
+ results = self.remote_worker_pool.call_workers(
1706
+ worker_ips=subcall_ips,
1707
+ cls_or_fn_name=cls_or_fn_name,
1708
+ method_name=method_name,
1709
+ params=params,
1710
+ request_headers=clean_headers,
1711
+ workers_arg=workers_arg,
1712
+ )
1713
+ logger.warning(
1714
+ f"RemoteWorkerPool returned {len(results) if results else 0} results from {len(subcall_ips)} workers"
1715
+ )
1716
+ return results
1717
+ except Exception as e:
1718
+ # Check if this is a connection error - might indicate worker removal
1719
+ if any(
1720
+ err_type in str(e)
1721
+ for err_type in [
1722
+ "ReadError",
1723
+ "TimeoutException",
1724
+ "RequestError",
1725
+ "HTTPError",
1726
+ "ConnectionError",
1727
+ "Connection reset",
1728
+ "Connection closed",
1729
+ ]
1730
+ ):
1731
+ logger.debug(f"Connection error detected: {e}, checking for membership changes")
1732
+ # Force DNS check to see if workers were removed
1733
+ self.check_for_membership_changes(force_dns_check=True)
1734
+ worker_exception = e
1735
+ raise
1736
+
1737
+ worker_future = executor.submit(call_remote_workers)
1738
+ logger.debug(f"Submitted worker_future for {len(subcall_ips)} remote workers")
1739
+
1740
+ else:
1741
+ logger.debug(f"No remote workers to call (subcall_ips is empty or None: {subcall_ips})")
1742
+
1743
+ # Check if we need to initialize RemoteWorkerPool for tree topology workers
1744
+ if subcall_ips:
1745
+ if not self.remote_worker_pool:
1746
+ # Initialize RemoteWorkerPool if not already done (needed for tree topology workers)
1747
+ logger.warning(
1748
+ f"INITIALIZING RemoteWorkerPool for tree worker at {this_pod_ip} to call {len(subcall_ips)} children"
1749
+ )
1750
+ self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
1751
+ self.remote_worker_pool.start(
1752
+ max_workers=min(len(subcall_ips) + 50, 200)
1753
+ ) # Size for expected children plus buffer
1754
+ logger.warning(f"RemoteWorkerPool initialized successfully for {this_pod_ip}")
1755
+ elif (
1756
+ not hasattr(self.remote_worker_pool, "process")
1757
+ or not self.remote_worker_pool.process
1758
+ or not self.remote_worker_pool.process.is_alive()
1759
+ ):
1760
+ # Pool exists but not started/alive
1761
+ logger.warning(f"RemoteWorkerPool exists but not running for {this_pod_ip}, starting it now")
1762
+ self.remote_worker_pool.start(max_workers=min(len(subcall_ips) + 50, 200))
1763
+
1764
+ if subcall_ips and not self.remote_worker_pool:
1765
+ # RemoteWorkerPool should always be initialized at this point
1766
+ raise RuntimeError(
1767
+ f"RemoteWorkerPool not available for worker at {this_pod_ip}. "
1768
+ "This is required for distributed execution with subcall_ips."
1769
+ )
1770
+
1771
+ # Submit local process calls
1772
+ def call_local_processes():
1773
+ logger.debug(f"Processing {num_procs} local process responses")
1774
+ return self.process_pool.call_all(
1775
+ method_name=method_name,
1776
+ params_list=params_list,
1777
+ deployed_as_of=deployed_as_of,
1778
+ request_id=request_id,
1779
+ distributed_env_vars_list=distributed_env_vars_list,
1780
+ debug_ports=debug_ports,
1781
+ serialization=serialization,
1782
+ )
1783
+
1784
+ # We may not be calling the local processes if the user specified workers and didn't include this node
1785
+ if call_local_procs:
1786
+ local_future = executor.submit(call_local_processes)
1787
+ else:
1788
+ local_future = None
1789
+
1790
+ # Wait for both to complete with fast failure propagation
1791
+ try:
1792
+ # Use as_completed to get results as they arrive
1793
+ futures = [f for f in [worker_future, local_future] if f is not None]
1794
+ local_responses = []
1795
+
1796
+ # If we only have local processes (no remote workers), handle that case
1797
+ if not futures:
1798
+ logger.error("No futures to wait for - this shouldn't happen")
1799
+ raise RuntimeError("No distributed work to execute")
1800
+
1801
+ logger.debug(
1802
+ f"Waiting for {len(futures)} futures to complete (worker_future={worker_future is not None}, local_future={local_future is not None})"
1803
+ )
1804
+
1805
+ # Process futures with periodic membership checks
1806
+ from concurrent.futures import FIRST_COMPLETED, wait
1807
+
1808
+ pending_futures = set(futures)
1809
+ while pending_futures:
1810
+ # Check for membership changes even while waiting
1811
+ if change_event and change_event.is_set():
1812
+ logger.debug("Membership change detected, checking...")
1813
+ try:
1814
+ self.check_for_membership_changes()
1815
+ except Exception as e:
1816
+ # Cancel all pending futures immediately
1817
+ logger.error(f"Membership change detected, cancelling futures: {e}")
1818
+ for f in pending_futures:
1819
+ if not f.done():
1820
+ f.cancel()
1821
+ raise e
1822
+ finally:
1823
+ change_event.clear() # Reset for next change
1824
+
1825
+ # Wait for next future with a short timeout to allow membership checks
1826
+ done, pending_futures = wait(pending_futures, timeout=1.0, return_when=FIRST_COMPLETED)
1827
+
1828
+ for future in done:
1829
+ logger.debug(
1830
+ f"Future completed: is_worker={worker_future and future == worker_future}, is_local={local_future and future == local_future}"
1831
+ )
1832
+ try:
1833
+ if worker_future and future == worker_future:
1834
+ logger.debug("Getting results from remote workers future")
1835
+ results = future.result()
1836
+ logger.debug(
1837
+ f"Remote worker future returned: {type(results).__name__} with {len(results) if hasattr(results, '__len__') else 'N/A'} items"
1838
+ )
1839
+ # Process results - they're already JSON-decoded
1840
+ logger.debug(f"Processing {len(results)} results from RemoteWorkerPool")
1841
+ for i, result in enumerate(results):
1842
+ logger.debug(
1843
+ f"Result {i}: type={type(result).__name__}, "
1844
+ f"length={len(result) if hasattr(result, '__len__') else 'N/A'}"
1845
+ )
1846
+ if isinstance(result, dict) and "error_type" in result:
1847
+ # Fast failure - return error immediately
1848
+ executor.shutdown(wait=False)
1849
+ return JSONResponse(status_code=500, content=result)
1850
+ # Results from RemoteWorkerPool are already lists (aggregated from subtree in tree topology)
1851
+ if isinstance(result, list):
1852
+ worker_responses.extend(result)
1853
+ else:
1854
+ worker_responses.append(result)
1855
+ logger.debug(f"Got {len(worker_responses)} total responses from remote workers")
1856
+ else: # local_future
1857
+ logger.debug("Getting results from local processes future")
1858
+ local_responses = future.result()
1859
+ logger.debug(f"Got {len(local_responses)} responses from local processes")
1860
+ # Check for errors in local responses
1861
+ for response in local_responses:
1862
+ if isinstance(response, JSONResponse):
1863
+ # Fast failure - return error immediately
1864
+ executor.shutdown(wait=False)
1865
+ return response
1866
+ except Exception as e:
1867
+ # Fast failure - propagate exception immediately
1868
+ executor.shutdown(wait=False)
1869
+ logger.error(f"Error in distributed execution: {e}")
1870
+ raise
1871
+
1872
+ finally:
1873
+ # Unsubscribe from membership changes
1874
+ if change_event:
1875
+ self.unsubscribe_from_membership_changes(change_event)
1876
+
1877
+ logger.debug("Shutting down executor")
1878
+ executor.shutdown(wait=False)
1879
+
1880
+ total = len(local_responses) + len(worker_responses)
1881
+ if tree_mode and total > 10: # Only log for tree mode with significant results
1882
+ logger.debug(
1883
+ f"TREE RESULT AGGREGATION at {this_pod_ip}: {len(local_responses)} local + {len(worker_responses)} remote = {total} total"
1884
+ )
1885
+ else:
1886
+ logger.debug(
1887
+ f"Combining {len(local_responses)} local + {len(worker_responses)} remote = {total} total responses"
1888
+ )
1889
+ logger.debug(f"Distributed_subcall={distributed_subcall}, tree topology={tree_mode}")
1890
+ # Log sample of what we're returning for debugging
1891
+ if worker_responses:
1892
+ logger.debug(f"Sample worker response type: {type(worker_responses[0]).__name__}")
1893
+ responses = local_responses + worker_responses
1894
+ for response in responses:
1895
+ # If the response is a JSONResponse, we need to check if it contains an exception,
1896
+ # and "raise" it if so - essentially just returning it immediately rather than the full result list.
1897
+ if isinstance(response, JSONResponse):
1898
+ return response
1899
+ # This primarily handles exceptions raised while packaging an exception, which would otherwise cause the server to hang.
1900
+ if isinstance(response, Exception):
1901
+ raise response
1902
+ logger.debug(f"Returning {len(responses)} responses from execute_call")
1903
+ return responses
1904
+
1905
+
1906
+ class JaxProcess(DistributedProcess):
1907
+ """JAX-specific distributed process."""
1908
+
1909
+ @classmethod
1910
+ def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
1911
+ """Get JAX-specific distributed environment variables.
1912
+
1913
+ JAX uses a coordinator address and process ID for distributed setup.
1914
+ """
1915
+ port = settings.get("port") or 1234 # JAX default coordinator port
1916
+ env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
1917
+
1918
+ # JAX distributed environment variables
1919
+ env_vars.update(
1920
+ {
1921
+ # Coordinator is the first worker
1922
+ "JAX_COORDINATOR_ADDRESS": f"{worker_ips[0]}:{port}",
1923
+ # Process ID is global rank
1924
+ "JAX_PROCESS_ID": str(node_rank * num_local_procs + local_rank),
1925
+ # Total number of processes
1926
+ "JAX_NUM_PROCESSES": str(len(worker_ips) * num_local_procs),
1927
+ # Local device IDs (for GPU/TPU)
1928
+ "JAX_LOCAL_DEVICE_IDS": str(local_rank),
1929
+ }
1930
+ )
1931
+ return env_vars
1932
+
1933
+ @classmethod
1934
+ def get_auto_num_processes(cls):
1935
+ """Auto-detect based on available accelerators for JAX."""
1936
+ try:
1937
+ import jax
1938
+
1939
+ # JAX can use TPUs, GPUs, or CPUs
1940
+ devices = jax.devices()
1941
+ return len(devices)
1942
+ except Exception:
1943
+ return 1
1944
+
1945
+ # JAX doesn't have a global process group to destroy like PyTorch
1946
+ # Cleanup is mostly handled automatically
1947
+ # def proc_cleanup(self):
1948
+
1949
+
1950
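The variables set above are meant to be consumed by the user's code; JAX does not necessarily pick them up on its own, so a typical entrypoint forwards them to jax.distributed.initialize(). The process id is the global rank (node_rank * num_local_procs + local_rank, e.g. node 1 with 4 local procs and local rank 2 gives process id 6). A minimal sketch (editor's example, assuming JAX is installed; the helper name is hypothetical):

# Editor's sketch: consuming the JAX_* variables set by JaxProcess.get_distributed_env_vars.
import os

import jax


def init_jax_from_env():
    # Forward the kubetorch-provided variables explicitly to JAX's distributed runtime.
    jax.distributed.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
        local_device_ids=[int(os.environ["JAX_LOCAL_DEVICE_IDS"])],
    )
    return jax.process_index(), jax.process_count()
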
+ class TensorflowProcess(DistributedProcess):
1951
+ """TensorFlow-specific distributed process."""
1952
+
1953
+ def proc_cleanup(self):
1954
+ """TensorFlow-specific cleanup."""
1955
+ try:
1956
+ import tensorflow as tf
1957
+
1958
+ # Clear the default graph and reset the session
1959
+ tf.keras.backend.clear_session()
1960
+ logger.info("TensorFlow process cleanup completed.")
1961
+ except ImportError:
1962
+ logger.info("TensorFlow not available for cleanup")
1963
+ except Exception as e:
1964
+ logger.info(f"Failed during TensorFlow cleanup: {e}")
1965
+ # Call parent cleanup for debugging sessions
1966
+ super().proc_cleanup()
1967
+
1968
+ @classmethod
1969
+ def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
1970
+ """Get TensorFlow-specific distributed environment variables.
1971
+
1972
+ TensorFlow uses TF_CONFIG for distributed training configuration.
1973
+ """
1974
+ import json
1975
+
1976
+ port = settings.get("port") or 2222 # TensorFlow default port
1977
+ env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
1978
+
1979
+ # Build TF_CONFIG for MultiWorkerMirroredStrategy
1980
+ worker_addresses = [f"{ip}:{port}" for ip in worker_ips]
1981
+
1982
+ tf_config = {
1983
+ "cluster": {"worker": worker_addresses},
1984
+ "task": {"type": "worker", "index": node_rank},
1985
+ }
1986
+
1987
+ env_vars.update(
1988
+ {
1989
+ "TF_CONFIG": json.dumps(tf_config),
1990
+ # Additional TF env vars for performance
1991
+ "TF_FORCE_GPU_ALLOW_GROWTH": "true",
1992
+ "TF_GPU_THREAD_MODE": "gpu_private",
1993
+ }
1994
+ )
1995
+ return env_vars
1996
+
1997
+ @classmethod
1998
+ def get_auto_num_processes(cls):
1999
+ """Auto-detect based on available GPUs for TensorFlow."""
2000
+ try:
2001
+ import tensorflow as tf
2002
+
2003
+ gpus = tf.config.list_physical_devices("GPU")
2004
+ if gpus:
2005
+ return len(gpus)
2006
+ except Exception:
2007
+ pass
2008
+ return 1
2009
+
2010
+
2011
+ class MonarchProcess(DistributedProcess):
2012
+ """Monarch-specific distributed process for single-controller actor framework.
2013
+
2014
+ Similar to Ray, Monarch uses a single controller (rank 0) that manages distributed
2015
+ actors across worker nodes. Each node runs a process_allocator service.
2016
+ """
2017
+
2018
+ def __init__(self, local_rank, request_queue, response_queue, max_threads=4, **kwargs):
2019
+ super().__init__(local_rank, request_queue, response_queue, max_threads, **kwargs)
2020
+ self.allocator = None
2021
+
2022
+ # Monarch imports will be done in run() on the main thread
2023
+ self.RemoteAllocator = None
2024
+ self.StaticRemoteAllocInitializer = None
2025
+
2026
+ def _create_allocator_for_controller(self):
2027
+ """Create a RemoteAllocator for the controller (rank 0)."""
2028
+
2029
+ try:
2030
+ # Try to import if not already available
2031
+ if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
2032
+ try:
2033
+ from monarch._src.actor.allocator import RemoteAllocator, StaticRemoteAllocInitializer
2034
+
2035
+ self.RemoteAllocator = RemoteAllocator
2036
+ self.StaticRemoteAllocInitializer = StaticRemoteAllocInitializer
2037
+ logger.debug("Monarch components imported")
2038
+ except ImportError as e:
2039
+ logger.error(f"Failed to import Monarch: {e}")
2040
+ logger.error("Make sure torchmonarch is installed: pip install torchmonarch")
2041
+ import traceback
2042
+
2043
+ logger.error(traceback.format_exc())
2044
+ return None
2045
+ except Exception as e:
2046
+ logger.error(f"Unexpected error importing Monarch: {e}")
2047
+ import traceback
2048
+
2049
+ logger.error(traceback.format_exc())
2050
+ return None
2051
+
2052
+ if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
2053
+ logger.error("Monarch components not available. Cannot create allocator.")
2054
+ return None
2055
+
2056
+ # Get worker addresses from POD_IPS
2057
+ pod_ips = os.environ.get("POD_IPS", "").split(",")
2058
+ if not pod_ips or pod_ips == [""]:
2059
+ logger.warning("No POD_IPS found, using localhost")
2060
+ pod_ips = ["127.0.0.1"]
2061
+
2062
+ # Use tcp! format for channel addresses (Monarch's format)
2063
+ # Format: tcp!{ip}:{port} not tcp://{ip}:{port}
2064
+ worker_addresses = [f"tcp!{ip}:26600" for ip in pod_ips]
2065
+ logger.info(f"Creating Monarch allocator with {len(worker_addresses)} workers")
2066
+ logger.debug(f"Worker addresses type: {type(worker_addresses)}")
2067
+ logger.debug(f"First address: {worker_addresses[0] if worker_addresses else 'none'}")
2068
+ logger.debug(f"First address type: {type(worker_addresses[0]) if worker_addresses else 'none'}")
2069
+
2070
+ # Simple check - don't add complex waiting logic
2071
+
2072
+ # Create initializer with all workers using pre-imported classes
2073
+ # StaticRemoteAllocInitializer takes addresses as positional args
2074
+ logger.debug(f"About to create StaticRemoteAllocInitializer with args: {worker_addresses}")
2075
+ try:
2076
+ initializer = self.StaticRemoteAllocInitializer(*worker_addresses)
2077
+ except Exception as e:
2078
+ logger.error(f"Failed to create StaticRemoteAllocInitializer: {e}")
2079
+ import traceback
2080
+
2081
+ logger.debug(f"Traceback: {traceback.format_exc()}")
2082
+ raise
2083
+
2084
+ # Return configured allocator using pre-imported class
2085
+ # RemoteAllocator takes world_id and initializer
2086
+ # Use stable world_id based on service name
2087
+ # This allows coordinator failover and process restarts to work correctly
2088
+ service_name = os.environ.get("KT_SERVICE_NAME", "monarch-default")
2089
+ world_id = service_name
2090
+ try:
2091
+ allocator = self.RemoteAllocator(world_id=world_id, initializer=initializer)
2092
+ logger.info(f"Created allocator with world_id={world_id}")
2093
+ return allocator
2094
+ except Exception as e:
2095
+ logger.error(f"Failed to create RemoteAllocator: {e}")
2096
+ import traceback
2097
+
2098
+ logger.debug(f"Traceback: {traceback.format_exc()}")
2099
+ raise
2100
+
2101
+ except ImportError as e:
2102
+ logger.error(f"Could not import Monarch for allocator creation: {e}")
2103
+ import traceback
2104
+
2105
+ logger.error(traceback.format_exc())
2106
+ return None
2107
+ except Exception as e:
2108
+ logger.error(f"Failed to create Monarch allocator: {e}")
2109
+ import traceback
2110
+
2111
+ logger.error(traceback.format_exc())
2112
+ return None
2113
+
2114
+ def run_user_function(self, callable_obj, method_name, params):
2115
+ """Run user function with Monarch-specific setup."""
2116
+ import asyncio
2117
+ import inspect
2118
+
2119
+ # Get the rank from environment
2120
+ rank = int(os.environ.get("NODE_RANK", "0"))
2121
+ logger.debug(f"Running user function on rank {rank}")
2122
+
2123
+ # Get the method to call
2124
+ if method_name and hasattr(callable_obj, method_name):
2125
+ user_method = getattr(callable_obj, method_name)
2126
+ else:
2127
+ user_method = callable_obj
2128
+
2129
+ logger.info(f"User method: {user_method}")
2130
+
2131
+ # Prepare arguments
2132
+ args = params.get("args", []) if params else []
2133
+ kwargs = params.get("kwargs", {}) if params else {}
2134
+
2135
+ # Only create and inject allocator for controller (rank 0)
2136
+ # Workers will run the user function with allocator=None
2137
+ if rank == 0:
2138
+ logger.info("Rank 0 (controller) - will create allocator")
2139
+ # Controller (rank 0) - create allocator if needed
2140
+ if self.allocator is None:
2141
+ logger.debug("Creating allocator...")
2142
+ self.allocator = self._create_allocator_for_controller()
2143
+ if self.allocator is None:
2144
+ logger.error("Failed to create allocator - returned None!")
2145
+ else:
2146
+ logger.debug("Allocator created successfully")
2147
+
2148
+ # Inject allocator if function accepts it
2149
+ try:
2150
+ sig = inspect.signature(user_method)
2151
+ if "allocator" in sig.parameters:
2152
+ logger.debug("Injecting allocator into controller function")
2153
+ kwargs["allocator"] = self.allocator
2154
+ except Exception as e:
2155
+ logger.warning(f"Could not inspect function signature: {e}")
2156
+ else:
2157
+ # Workers get None for allocator parameter if the function expects it
2158
+ try:
2159
+ sig = inspect.signature(user_method)
2160
+ if "allocator" in sig.parameters:
2161
+ logger.info(f"Worker {rank}: Setting allocator=None")
2162
+ kwargs["allocator"] = None
2163
+ except Exception:
2164
+ pass
2165
+
2166
+ # Run the function (we're already on the main thread of this process)
2167
+ logger.info(f"Rank {rank}: Running user function with args={args}, kwargs keys={list(kwargs.keys())}")
2168
+
2169
+ try:
2170
+ if asyncio.iscoroutinefunction(user_method):
2171
+ result = asyncio.run(user_method(*args, **kwargs))
2172
+ else:
2173
+ result = user_method(*args, **kwargs)
2174
+ if inspect.isawaitable(result):
2175
+ result = asyncio.run(result)
2176
+
2177
+ # If the result is an exception dict (from the user's try/except),
2178
+ # convert it back to a proper exception and raise it
2179
+ if isinstance(result, dict) and "status" in result and result["status"] == "error":
2180
+ error_msg = result.get("error", "Unknown error")
2181
+ traceback_str = result.get("traceback", "")
2182
+ logger.error(f"User function returned error dict: {error_msg}")
2183
+ logger.error(f"Traceback from user function: {traceback_str}")
2184
+ # Raise a RuntimeError with the original error message
2185
+ raise RuntimeError(f"Monarch execution failed: {error_msg}\n{traceback_str}")
2186
+
2187
+ return result
2188
+ except Exception as e:
2189
+ logger.error(f"Exception in run_user_function: {e}")
2190
+ import traceback
2191
+
2192
+ logger.error(f"Full traceback: {traceback.format_exc()}")
2193
+ # Re-raise the exception to be handled by the caller
2194
+ raise
2195
+
2196
+ def run(self):
2197
+ """Override run to handle requests on main thread for Monarch."""
2198
+ logger.debug("MonarchProcess starting on main thread")
2199
+
2200
+ # Import Monarch on the main thread of this subprocess
2201
+ # This is the right place since run() executes on the main thread
2202
+ if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
2203
+ try:
2204
+ from monarch._src.actor.allocator import RemoteAllocator, StaticRemoteAllocInitializer
2205
+
2206
+ self.RemoteAllocator = RemoteAllocator
2207
+ self.StaticRemoteAllocInitializer = StaticRemoteAllocInitializer
2208
+ except Exception as e:
2209
+ logger.error(f"Failed to import Monarch in run(): {e}")
2210
+ import traceback
2211
+
2212
+ logger.error(traceback.format_exc())
2213
+
2214
+ # Monarch requires main thread execution, so we don't use ThreadPoolExecutor
2215
+ try:
2216
+ while True:
2217
+ try:
2218
+ # Block waiting for next request
2219
+ request = self._request_queue.get(timeout=1)
2220
+
2221
+ # Special sentinel value to signal shutdown
2222
+ if request == "SHUTDOWN":
2223
+ break
2224
+
2225
+ # Handle request directly on main thread (not in thread pool)
2226
+ self.handle_request(request)
2227
+
2228
+ except queue.Empty:
2229
+ continue
2230
+ except Exception as e:
2231
+ if "Empty" not in str(e.__class__.__name__):
2232
+ logger.error(f"Error getting request from queue: {e}")
2233
+ continue
2234
+
2235
+ except (KeyboardInterrupt, BdbQuit):
2236
+ logger.debug("MonarchProcess interrupted")
2237
+ finally:
2238
+ logger.debug("MonarchProcess shutting down")
2239
+ self.proc_cleanup()
2240
+
2241
+ def handle_request(self, request):
2242
+ """Handle request using Monarch-specific logic."""
2243
+ try:
2244
+ # Use parent's handle_request but override the actual function execution
2245
+ request_unique_id = request["request_unique_id"]
2246
+ method_name = request["method_name"]
2247
+ params = request["params"]
2248
+ deployed_as_of = request["deployed_as_of"]
2249
+ request_id = request["request_id"]
2250
+ distributed_env_vars = request["distributed_env_vars"]
2251
+
2252
+ # Set environment variables
2253
+ for key, value in distributed_env_vars.items():
2254
+ os.environ[key] = value
2255
+
2256
+ # Set request context
2257
+ token = request_id_ctx_var.set(request_id)
2258
+
2259
+ try:
2260
+ # Load callable
2261
+ callable_obj = load_callable(
2262
+ deployed_as_of=deployed_as_of,
2263
+ distributed_subprocess=True,
2264
+ reload_cleanup_fn=self.proc_cleanup,
2265
+ )
2266
+
2267
+ # Run with our simplified Monarch logic
2268
+ result = self.run_user_function(callable_obj, method_name, params)
2269
+
2270
+ # Send response
2271
+ self._response_queue.put({"request_unique_id": request_unique_id, "result": result})
2272
+
2273
+ except Exception as e:
2274
+ # Package and send error
2275
+ packaged_exception = package_exception(e)
2276
+ self._response_queue.put(
2277
+ {
2278
+ "request_unique_id": request_unique_id,
2279
+ "result": packaged_exception,
2280
+ }
2281
+ )
2282
+ finally:
2283
+ request_id_ctx_var.reset(token)
2284
+
2285
+ except Exception as e:
2286
+ logger.error(f"Error in Monarch request handling: {e}")
2287
+ self._response_queue.put(
2288
+ {
2289
+ "request_unique_id": request.get("request_unique_id", "unknown"),
2290
+ "result": Exception(f"Fatal error: {e}"),
2291
+ }
2292
+ )
2293
+
2294
+ def proc_cleanup(self):
2295
+ """Monarch-specific cleanup."""
2296
+ try:
2297
+ # Stop allocator service
2298
+ if getattr(self, "allocator_proc", None):
2299
+ self.allocator_proc.terminate()
2300
+ try:
2301
+ self.allocator_proc.wait(timeout=5)
2302
+ except subprocess.TimeoutExpired:
2303
+ self.allocator_proc.kill()
2304
+ self.allocator_proc = None
2305
+ logger.info("Stopped process_allocator service")
2306
+
2307
+ # Cleanup any Monarch resources
2308
+ # Monarch doesn't have a global shutdown like Ray
2309
+ logger.debug("Monarch process cleanup completed")
2310
+
2311
+ except Exception as e:
2312
+ logger.error(f"Error during Monarch cleanup: {e}")
2313
+
2314
+ # Call parent cleanup
2315
+ super().proc_cleanup()
2316
+
2317
+ @classmethod
2318
+ def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
2319
+ """Get Monarch-specific environment variables."""
2320
+ env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
2321
+
2322
+ # Monarch uses these for discovery
2323
+ env_vars.update(
2324
+ {
2325
+ "HYPERACTOR_MESH_BOOTSTRAP_ADDR": "tcp://localhost:26600",
2326
+ "HYPERACTOR_MESH_INDEX": str(node_rank),
2327
+ # Keep POD_IPS for allocator creation
2328
+ "POD_IPS": ",".join(worker_ips),
2329
+ }
2330
+ )
2331
+
2332
+ return env_vars
2333
+
2334
+ @classmethod
2335
+ def get_auto_num_processes(cls):
2336
+ """Monarch uses one process per node (like Ray)."""
2337
+ return 1
2338
+
2339
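Since run_user_function() injects the allocator only on rank 0 (and None on worker ranks), a user entrypoint deployed with this class would typically accept an optional allocator keyword. A minimal sketch (editor's example; the function name and body are hypothetical):

def launch_actors(allocator=None):
    # On rank 0 a RemoteAllocator is injected; worker ranks receive None and simply keep
    # their process_allocator service available for the controller to allocate against.
    if allocator is None:
        return {"role": "worker"}
    # Hypothetical controller-side work: hand the allocator to Monarch APIs that spawn
    # actor meshes across the allocated hosts.
    return {"role": "controller", "allocator": type(allocator).__name__}
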
+
2340
+ # Similar to Ray, Monarch needs special handling as a single-controller framework
2341
+ class MonarchDistributed(DistributedSupervisor):
2342
+ """Monarch distributed supervisor for single-controller actor framework."""
2343
+
2344
+ def __init__(
2345
+ self,
2346
+ restart_procs=True,
2347
+ max_threads=4,
2348
+ quorum_timeout=300,
2349
+ quorum_workers=None,
2350
+ **kwargs,
2351
+ ):
2352
+ # Monarch doesn't use DNS monitoring like SPMD frameworks
2353
+ super().__init__(
2354
+ quorum_workers=quorum_workers,
2355
+ quorum_timeout=quorum_timeout,
2356
+ monitor_members=False, # Disable DNS monitoring like Ray
2357
+ )
2358
+ self.restart_procs = restart_procs
2359
+ self.max_threads = max_threads
2360
+ self.process_pool = None
2361
+ self.remote_worker_pool = None
2362
+
2363
+ def setup(self, deployed_as_of: Optional[str] = None):
2364
+ """Setup Monarch distributed environment."""
2365
+ # Set multiprocessing to spawn
2366
+ if multiprocessing.get_start_method() != "spawn":
2367
+ multiprocessing.set_start_method("spawn", force=True)
2368
+
2369
+ # Start process_allocator service (like Ray starts its server)
2370
+ self._start_allocator_service()
2371
+
2372
+ if self.restart_procs:
2373
+ logger.debug("restart_procs is True, restarting Monarch processes")
2374
+ self.cleanup()
2375
+
2376
+ if self.process_pool is None:
2377
+ logger.debug("Setting up Monarch distributed environment")
2378
+
2379
+ # Create process pool with MonarchProcess
2380
+ logger.info("Creating DistributedProcessPool with MonarchProcess class")
2381
+ self.process_pool = DistributedProcessPool(
2382
+ process_class=MonarchProcess,
2383
+ num_processes=1, # One process per node for Monarch
2384
+ max_threads_per_proc=self.max_threads,
2385
+ )
2386
+
2387
+ # Start the process
2388
+ logger.info("Starting MonarchProcess pool...")
2389
+ self.process_pool.start()
2390
+ logger.info(f"Started MonarchProcess pool: {self.process_pool}")
2391
+
2392
+ # Create remote worker pool for coordination
2393
+ self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
2394
+ self.remote_worker_pool.start()
2395
+
2396
+ logger.debug("Finished setting up Monarch distributed processes")
2397
+
2398
+ def _start_allocator_service(self):
2399
+ """Start the process_allocator service if available."""
2400
+ try:
2401
+ # Check if process_allocator is already running
2402
+ import subprocess
2403
+
2404
+ check_result = subprocess.run(["pgrep", "-f", "process_allocator"], capture_output=True)
2405
+ if check_result.returncode == 0:
2406
+ logger.info("process_allocator already running")
2407
+ return
2408
+
2409
+ # Try to find process_allocator binary
2410
+ import shutil
2411
+
2412
+ allocator_path = shutil.which("process_allocator")
2413
+
2414
+ if not allocator_path:
2415
+ # Check common installation paths
2416
+ import sys
2417
+
2418
+ possible_paths = [
2419
+ "/opt/conda/bin/process_allocator",
2420
+ os.path.join(sys.prefix, "bin", "process_allocator"),
2421
+ ]
2422
+ for path in possible_paths:
2423
+ if os.path.exists(path) and os.access(path, os.X_OK):
2424
+ allocator_path = path
2425
+ break
2426
+
2427
+ if not allocator_path:
2428
+ logger.warning(
2429
+ "process_allocator binary not found. "
2430
+ "Please ensure torchmonarch is properly installed or "
2431
+ "start process_allocator manually in your Docker image."
2432
+ )
2433
+ return
2434
+
2435
+ # Start process_allocator with the correct arguments
2436
+ # Based on monarch/python/monarch/tools/components/hyperactor.py
2437
+ allocator_cmd = [
2438
+ allocator_path,
2439
+ "--port=26600",
2440
+ "--program=monarch_bootstrap",
2441
+ ]
2442
+
2443
+ logger.info(f"Starting process_allocator: {' '.join(allocator_cmd)}")
2444
+
2445
+ # Start in background
2446
+ self.allocator_proc = subprocess.Popen(
2447
+ allocator_cmd,
2448
+ stdout=subprocess.DEVNULL,
2449
+ stderr=subprocess.PIPE,
2450
+ start_new_session=True,
2451
+ )
2452
+
2453
+ # Give it a moment to start
2454
+ import time
2455
+
2456
+ time.sleep(2)
2457
+
2458
+ # Check if it's still running
2459
+ if self.allocator_proc.poll() is None:
2460
+ logger.info(f"process_allocator started successfully (PID: {self.allocator_proc.pid})")
2461
+ else:
2462
+ stderr = self.allocator_proc.stderr.read().decode() if self.allocator_proc.stderr else ""
2463
+ logger.error(f"process_allocator failed to start: {stderr}")
2464
+ self.allocator_proc = None
2465
+
2466
+ except Exception as e:
2467
+ logger.warning(f"Could not start process_allocator: {e}")
2468
+ # Continue anyway - user may have started it differently
2469
+
2470
+ def cleanup(self):
2471
+ """Cleanup Monarch distributed environment."""
2472
+ logger.debug("Cleaning up Monarch distributed processes")
2473
+
2474
+ # Stop DNS monitoring (though it's disabled for Monarch)
2475
+ super().cleanup()
2476
+
2477
+ # Stop process_allocator if we started it
2478
+ if hasattr(self, "allocator_proc") and self.allocator_proc:
2479
+ try:
2480
+ self.allocator_proc.terminate()
2481
+ self.allocator_proc.wait(timeout=5)
2482
+ logger.info("Stopped process_allocator")
2483
+ except Exception as e:
2484
+ logger.debug(f"Error stopping process_allocator: {e}")
2485
+
2486
+ if self.process_pool:
2487
+ self.process_pool.stop()
2488
+ self.process_pool = None
2489
+
2490
+ if self.remote_worker_pool:
2491
+ self.remote_worker_pool.stop()
2492
+ self.remote_worker_pool = None
2493
+
2494
+ logger.debug("Finished cleaning up Monarch distributed processes")
2495
+
2496
+ @staticmethod
2497
+ def intercept_call():
2498
+ """Monarch intercepts calls like Ray."""
2499
+ return True
2500
+
2501
+ def call_distributed(
2502
+ self,
2503
+ request,
2504
+ cls_or_fn_name: str,
2505
+ method_name: Optional[str] = None,
2506
+ params: Optional[Dict] = None,
2507
+ distributed_subcall: bool = False,
2508
+ debug_port: Optional[int] = None,
2509
+ deployed_as_of: Optional[str] = None,
2510
+ ):
2511
+ """Monarch distributed call - executes on controller node (rank 0)."""
2512
+ logger.info("MonarchDistributed.call_distributed called")
2513
+
2514
+ # Ensure setup has been called
2515
+ if self.process_pool is None:
2516
+ logger.info("Process pool not initialized, calling setup()")
2517
+ self.setup(deployed_as_of=deployed_as_of)
2518
+
2519
+ request_id = request.headers.get("X-Request-ID", "-")
2520
+ serialization = request.headers.get("X-Serialization", "json")
2521
+
2522
+ # If deployed_as_of is None, generate a consistent timestamp
2523
+ if deployed_as_of is None:
2524
+ from datetime import datetime, timezone
2525
+
2526
+ deployed_as_of = datetime.now(timezone.utc).isoformat()
2527
+
2528
+ # Start DNS monitoring for worker discovery
2529
+ self.start_dns_monitoring()
2530
+
2531
+ # Check for any pending changes before we start
2532
+ self.check_for_membership_changes()
2533
+
2534
+ # Get pod IPs with quorum handling
2535
+ pod_ips = self.pod_ips()
2536
+
2537
+ # Handle case where no pods are found
2538
+ if not pod_ips:
2539
+ logger.error(
2540
+ f"No pods found for service {os.environ.get('KT_SERVICE')}. "
2541
+ "This may indicate the pods aren't ready yet. Consider increasing quorum_timeout in .distribute() call."
2542
+ )
2543
+ raise RuntimeError(
2544
+ "No pods found for Monarch distributed setup. " "Consider increasing quorum_timeout parameter."
2545
+ )
2546
+
2547
+ logger.info(f"Found {len(pod_ips)} pod(s) for Monarch distributed setup: {pod_ips}")
2548
+
2549
+ # Store critical environment variables
2550
+ self.distributed_env_vars = {}
2551
+ critical_env_vars = [
2552
+ "KT_SERVICE",
2553
+ "KT_SERVICE_NAME",
2554
+ "KT_FILE_PATH",
2555
+ "KT_MODULE_NAME",
2556
+ "KT_CLS_OR_FN_NAME",
2557
+ ]
2558
+ for env_var in critical_env_vars:
2559
+ if env_var in os.environ:
2560
+ self.distributed_env_vars[env_var] = os.environ[env_var]
2561
+
2562
+ # Update distributed env vars with current cluster IPs
2563
+ self.distributed_env_vars["POD_IPS"] = ",".join(pod_ips)
2564
+ self.distributed_env_vars["WORLD_SIZE"] = str(len(pod_ips))
2565
+ self.distributed_env_vars["NODE_RANK"] = "0" # Controller is always rank 0
2566
+
2567
+ logger.debug("Sending call to Monarch subprocess (controller)")
2568
+
2569
+ # Monarch uses only one process per node, call index 0
2570
+ result = self.process_pool.call(
2571
+ idx=0,
2572
+ method_name=method_name,
2573
+ params=params,
2574
+ deployed_as_of=deployed_as_of,
2575
+ request_id=request_id,
2576
+ distributed_env_vars=self.distributed_env_vars,
2577
+ debug_port=debug_port,
2578
+ serialization=serialization,
2579
+ )
2580
+
2581
+ # Handle exceptions from subprocess
2582
+ if isinstance(result, JSONResponse):
2583
+ return result
2584
+ if isinstance(result, Exception):
2585
+ raise result
2586
+
2587
+ return result
2588
+
2589
+
2590
+ RAY_START_PROC = None
2591
+
2592
+
2593
+ class RayDistributed(DistributedSupervisor):
2594
+ def __init__(
2595
+ self,
2596
+ restart_procs=True,
2597
+ max_threads=4,
2598
+ quorum_timeout=300,
2599
+ quorum_workers=None,
2600
+ monitor_members=False,
2601
+ ):
2602
+ """Ray distributed supervisor - only runs on head node (single controller).
2603
+
2604
+ Args:
2605
+ restart_procs: Whether to restart processes on each call
2606
+ max_threads: Maximum threads per process
2607
+ quorum_timeout: Timeout in seconds for Ray cluster nodes to become ready (default 300s/5min)
2608
+ """
2609
+ # Ray manages its own membership, so we don't monitor DNS changes
2610
+ super().__init__(
2611
+ quorum_timeout=quorum_timeout,
2612
+ quorum_workers=quorum_workers,
2613
+ monitor_members=monitor_members,
2614
+ )
2615
+ self.restart_procs = restart_procs
2616
+ self.distributed_env_vars = None
2617
+ self.process_pool = None # Using pool even for single process for consistency
2618
+ self.remote_worker_pool = None # Pool for async HTTP calls to remote workers
2619
+ self.max_threads = max_threads
2620
+ self.quorum_timeout = quorum_timeout
2621
+
2622
+ def setup(self, deployed_as_of: Optional[str] = None):
2623
+ # Set multiprocessing to spawn if not already
2624
+ if multiprocessing.get_start_method() != "spawn":
2625
+ multiprocessing.set_start_method("spawn", force=True)
2626
+
2627
+ # Start the Ray server here, if we allow KubeRay to start it in the pod template
2628
+ # it's hard to wait for it to start properly and we lose the ability to restart if needed.
2629
+ global RAY_START_PROC
2630
+
2631
+ # Check if Ray is actually running, not just if our global variable is None
2632
+ # (the global variable gets reset when HTTP server restarts)
2633
+ ray_running = self._is_ray_running()
2634
+
2635
+ if not ray_running:
2636
+ patch_sys_path()
2637
+
2638
+ kuberay_start_cmd = os.environ.get("KUBERAY_GEN_RAY_START_CMD")
2639
+ if kuberay_start_cmd:
2640
+ full_cmd = f"ulimit -n 65536; {kuberay_start_cmd}"
2641
+ logger.info(f"Starting Ray server with command: {full_cmd}")
2642
+
2643
+ try:
2644
+ # Start Ray as a non-blocking subprocess
2645
+ RAY_START_PROC = subprocess.Popen(
2646
+ full_cmd,
2647
+ shell=True,
2648
+ stdout=subprocess.PIPE,
2649
+ stderr=subprocess.STDOUT,
2650
+ universal_newlines=True,
2651
+ bufsize=1,
2652
+ env=os.environ.copy(),
2653
+ )
2654
+
2655
+ # Start a thread to stream Ray logs
2656
+ def stream_ray_logs():
2657
+ try:
2658
+ for line in RAY_START_PROC.stdout:
2659
+ logger.info(f"[Ray] {line.strip()}")
2660
+ except Exception as e:
2661
+ logger.error(f"Error streaming Ray logs: {e}")
2662
+
2663
+ import threading
2664
+
2665
+ log_thread = threading.Thread(target=stream_ray_logs, daemon=True)
2666
+ log_thread.start()
2667
+
2668
+ logger.info(f"Ray server started with PID: {RAY_START_PROC.pid}")
2669
+
2670
+ # Give Ray a moment to start
2671
+ time.sleep(2)
2672
+
2673
+ except Exception as e:
2674
+ logger.error(f"Failed to start Ray server: {e}")
2675
+ RAY_START_PROC = None
2676
+ raise
2677
+ else:
2678
+ logger.warning("KUBERAY_GEN_RAY_START_CMD environment variable not found")
2679
+
2680
+ logger.debug("Ray distributed supervisor setup completed (pod discovery will be done lazily)")
2681
+
2682
+ # Only the head node runs the subprocess
2683
+ this_pod_ip = os.environ["POD_IP"]
2684
+ if not os.environ["POD_NAME"].endswith("-head"):
2685
+ logger.info(f"Ray worker node {this_pod_ip}, skipping subprocess setup")
2686
+ return
2687
+
2688
+ logger.info(f"Ray head node {this_pod_ip}, setting up subprocess")
2689
+
2690
+ # Set Ray environment variables
2691
+ self.distributed_env_vars = {"RAY_HEAD_NODE_IP": this_pod_ip}
2692
+
2693
+ # Include critical environment variables so Ray workers can find and load the callable
2694
+ critical_env_vars = [
2695
+ "PYTHONPATH",
2696
+ "KT_FILE_PATH",
2697
+ "KT_MODULE_NAME",
2698
+ "KT_CLS_OR_FN_NAME",
2699
+ ]
2700
+ for env_var in critical_env_vars:
2701
+ if env_var in os.environ:
2702
+ self.distributed_env_vars[env_var] = os.environ[env_var]
2703
+
2704
+ # Cleanup will remove the process pool if found, so we need to check if it was previously initialized
2705
+ previously_initialized = self.remote_worker_pool is not None
2706
+
2707
+ if self.restart_procs:
2708
+ logger.debug("restart_procs is True, restarting Ray distributed process")
2709
+ self.cleanup()
2710
+
2711
+ if previously_initialized:
2712
+ pod_ips = self.pod_ips()
2713
+ this_pod_ip = os.environ["POD_IP"]
2714
+
2715
+ # Send reload requests to other pods if needed
2716
+ self._reload_image_on_other_pods(pod_ips, this_pod_ip, deployed_as_of)
2717
+
2718
+ if self.process_pool is None:
2719
+ logger.debug("Setting up Ray distributed process")
2720
+ self.process_pool = DistributedProcessPool(
2721
+ process_class=RayProcess,
2722
+ num_processes=1, # Ray only needs one process
2723
+ max_threads_per_proc=self.max_threads,
2724
+ )
2725
+ self.process_pool.start()
2726
+
2727
+ # Start remote worker pool for async HTTP calls if needed
2728
+ # Use a reasonable default max_workers since we don't know cluster size yet
2729
+ self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
2730
+ self.remote_worker_pool.start(max_workers=100) # Default size
2731
+
2732
+ logger.debug("Finished setting up Ray distributed process and remote worker pool")
2733
+
2734
+ def cleanup(self):
2735
+ """Clean up Ray distributed process."""
2736
+ logger.debug("Cleaning up Ray distributed process")
2737
+
2738
+ # Stop DNS monitoring first
2739
+ super().cleanup()
2740
+
2741
+ if self.process_pool:
2742
+ self.process_pool.stop()
2743
+ self.process_pool = None
2744
+
2745
+ if self.remote_worker_pool:
2746
+ self.remote_worker_pool.stop()
2747
+ self.remote_worker_pool = None
2748
+
2749
+ logger.debug("Finished cleaning up Ray distributed process")
2750
+
2751
+ @staticmethod
2752
+ def intercept_call():
2753
+ return True
2754
+
2755
+ def call_distributed(
2756
+ self,
2757
+ request,
2758
+ cls_or_fn_name: str,
2759
+ method_name: Optional[str] = None,
2760
+ params: Optional[Dict] = None,
2761
+ distributed_subcall: bool = False,
2762
+ debug_port: Optional[int] = None,
2763
+ deployed_as_of: Optional[str] = None,
2764
+ ):
2765
+ """Ray distributed call - only executes on head node."""
2766
+ request_id = request.headers.get("X-Request-ID", "-")
2767
+ serialization = request.headers.get("X-Serialization", "json")
2768
+
2769
+ # If deployed_as_of is None, generate a consistent timestamp
2770
+ # to use across all workers to prevent reload inconsistencies
2771
+ if deployed_as_of is None:
2772
+ from datetime import datetime, timezone
2773
+
2774
+ deployed_as_of = datetime.now(timezone.utc).isoformat()
2775
+
2776
+ if not os.environ["POD_NAME"].endswith("-head"):
2777
+ # This should never happen because the service only points to the head node. Raise an error if it does.
2778
+ raise RuntimeError(
2779
+ f"Ray distributed call attempted on non-head node {os.environ['POD_NAME']}. "
2780
+ "This should only be called on the head node."
2781
+ )
2782
+
2783
+ # Start DNS monitoring for the head node
2784
+ self.start_dns_monitoring()
2785
+
2786
+ # Check for any pending changes before we start
2787
+ self.check_for_membership_changes()
2788
+
2789
+ # The pod_ips() method now handles waiting for quorum
2790
+ pod_ips = self.pod_ips()
2791
+
2792
+ # Handle case where no pods are found
2793
+ if not pod_ips:
2794
+ logger.error(
2795
+ f"No pods found for service {os.environ.get('KT_SERVICE')}. "
2796
+ "This may indicate the pods aren't ready yet. Consider increasing quorum_timeout in .distribute() call."
2797
+ )
2798
+ raise RuntimeError(
2799
+ "No pods found for Ray distributed setup. " "Consider increasing quorum_timeout parameter."
2800
+ )
2801
+
2802
+ logger.info(f"Found {len(pod_ips)} pod(s) for distributed setup: {pod_ips}")
2803
+
2804
+ # Update distributed env vars with current cluster IPs
2805
+ self.distributed_env_vars["POD_IPS"] = ",".join(pod_ips)
2806
+
2807
+ logger.debug("Sending call to Ray subprocess")
2808
+ # Ray uses only one process, so always call index 0
2809
+ result = self.process_pool.call(
2810
+ idx=0,
2811
+ method_name=method_name,
2812
+ params=params,
2813
+ deployed_as_of=deployed_as_of,
2814
+ request_id=request_id,
2815
+ distributed_env_vars=self.distributed_env_vars,
2816
+ debug_port=debug_port,
2817
+ serialization=serialization,
2818
+ )
2819
+
2820
+ # Handle exceptions from subprocess
2821
+ if isinstance(result, JSONResponse):
2822
+ return result
2823
+ if isinstance(result, Exception):
2824
+ raise result
2825
+
2826
+ return result
2827
+
2828
+ def _reload_image_on_other_pods(self, pod_ips, this_pod_ip, deployed_as_of):
2829
+ """Send /_reload_image requests to all other pods in parallel, with retries for pods that aren't ready."""
2830
+ other_pod_ips = [ip for ip in pod_ips if ip != this_pod_ip]
2831
+
2832
+ if not other_pod_ips:
2833
+ logger.debug("No other pods to reload")
2834
+ return
2835
+
2836
+ logger.info(f"Sending reload requests to {len(other_pod_ips)} other pods: {other_pod_ips}")
2837
+
2838
+ server_port = os.environ.get("KT_SERVER_PORT", "32300")
2839
+ total_timeout = self.quorum_timeout # Use configurable quorum timeout
2840
+ retry_interval = 2 # Wait 2 seconds between retry attempts
2841
+ start_time = time.time()
2842
+
2843
+ successful_pods = set()
2844
+ remaining_pods = set(other_pod_ips)
2845
+
2846
+ while remaining_pods and (time.time() - start_time) < total_timeout:
2847
+ logger.debug(f"Attempting to reload {len(remaining_pods)} remaining pods: {list(remaining_pods)}")
2848
+
2849
+ def reload_pod(pod_ip):
2850
+ """Send reload request to a single pod."""
2851
+ try:
2852
+ # Use a proper HTTP client session to avoid Content-Length issues
2853
+ with httpx.Client(timeout=None) as client:
2854
+ url = f"http://{pod_ip}:{server_port}/_reload_image"
2855
+ # First try a quick health check to see if pod is ready
2856
+ health_url = f"http://{pod_ip}:{server_port}/health"
2857
+ health_response = client.get(health_url, timeout=5)
2858
+
2859
+ if health_response.status_code != 200:
2860
+ logger.debug(f"Pod {pod_ip} health check failed, will retry later")
2861
+ return False
2862
+
2863
+ # Pod is healthy, send reload request (no timeout, installs can be long-running)
2864
+ response = client.post(url, headers={"X-Deployed-As-Of": deployed_as_of})
2865
+ if response.status_code == 200:
2866
+ logger.debug(f"Successfully reloaded image on pod {pod_ip}")
2867
+ return True
2868
+ else:
2869
+ logger.warning(f"Pod {pod_ip} reload returned status {response.status_code}")
2870
+ return False
2871
+
2872
+ except Exception as e:
2873
+ logger.debug(f"Failed to reload image on pod {pod_ip}: {e}")
2874
+ raise
2875
+
2876
+ # Try to reload all remaining pods in parallel
2877
+ current_attempt_pods = list(remaining_pods)
2878
+
2879
+ with ThreadPoolExecutor(max_workers=min(len(current_attempt_pods), 10)) as executor:
2880
+ # Submit reload tasks for remaining pods
2881
+ future_to_pod = {executor.submit(reload_pod, pod_ip): pod_ip for pod_ip in current_attempt_pods}
2882
+
2883
+ # Process completed futures
2884
+ for future in as_completed(future_to_pod, timeout=None):
2885
+ pod_ip = future_to_pod[future]
2886
+ try:
2887
+ success = future.result()
2888
+ if success:
2889
+ successful_pods.add(pod_ip)
2890
+ remaining_pods.discard(pod_ip)
2891
+ except Exception as e:
2892
+ logger.debug(f"Reload task for pod {pod_ip} failed: {e}")
2893
+
2894
+ if remaining_pods:
2895
+ elapsed = time.time() - start_time
2896
+ remaining_time = total_timeout - elapsed
2897
+ if remaining_time > retry_interval:
2898
+ logger.info(f"Waiting {retry_interval}s before retrying {len(remaining_pods)} pods...")
2899
+ time.sleep(retry_interval)
2900
+ else:
2901
+ logger.warning("Timeout approaching, stopping retry attempts")
2902
+ break
2903
+
2904
+ # Log final results
2905
+ if successful_pods:
2906
+ logger.info(f"Successfully reloaded {len(successful_pods)} pod images: {list(successful_pods)}")
2907
+
2908
+ if remaining_pods:
2909
+ logger.warning(f"Failed to reload {len(remaining_pods)} pod images after timeout: {list(remaining_pods)}")
2910
+
2911
+ def _is_ray_running(self):
2912
+ """Check if Ray is actually running by trying to connect to the Ray GCS port."""
2913
+ try:
2914
+ import socket
2915
+
2916
+ # Ray GCS runs on port 6379 by default
2917
+ ray_port = 6379
2918
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
2919
+ sock.settimeout(1) # 1 second timeout
2920
+ result = sock.connect_ex(("127.0.0.1", ray_port))
2921
+ sock.close()
2922
+
2923
+ if result == 0:
2924
+ logger.debug("Ray GCS port 6379 is accessible, Ray appears to be running")
2925
+ return True
2926
+ else:
2927
+ logger.debug("Ray GCS port 6379 is not accessible, Ray is not running")
2928
+ return False
2929
+
2930
+ except Exception as e:
2931
+ logger.debug(f"Error checking if Ray is running: {e}")
2932
+ return False
2933
+
2934
+
2935
+ def distributed_supervisor_factory(distribution_type, *args, **kwargs):
2936
+ """
2937
+ Factory function to create a distributed supervisor based on the specified type.
2938
+
2939
+ Args:
2940
+ distribution_type (str): The type of distributed supervisor to create.
2941
+ Options include 'ray', 'monarch', 'pytorch', 'jax', 'tensorflow', or None for generic SPMD.
2942
+ *args: Positional arguments to pass to the supervisor constructor.
2943
+ **kwargs: Keyword arguments to pass to the supervisor constructor.
2944
+ Common kwargs include:
2945
+ - quorum_timeout: Timeout in seconds for workers to become ready (default 30 for SPMD, 300 for Ray/Monarch)
2946
+
2947
+ Returns:
2948
+ DistributedSupervisor: An instance of the specified distributed supervisor.
2949
+ """
2950
+ if distribution_type == "ray":
2951
+ # Ray uses its own supervisor, not SPMD
2952
+ return RayDistributed(*args, **kwargs)
2953
+ elif distribution_type == "monarch":
2954
+ # Monarch is similar to Ray - single controller framework
2955
+ return MonarchDistributed(*args, **kwargs)
2956
+
2957
+ # All other types use SPMDDistributedSupervisor with different process classes
2958
+ if distribution_type == "pytorch":
2959
+ return SPMDDistributedSupervisor(process_class=PyTorchProcess, *args, **kwargs)
2960
+ elif distribution_type == "jax":
2961
+ return SPMDDistributedSupervisor(process_class=JaxProcess, *args, **kwargs)
2962
+ elif distribution_type == "tensorflow" or distribution_type == "tf":
2963
+ return SPMDDistributedSupervisor(process_class=TensorflowProcess, *args, **kwargs)
2964
+ elif distribution_type is None or distribution_type == "spmd":
2965
+ # Default to base DistributedProcess - no framework-specific dependencies
2966
+ return SPMDDistributedSupervisor(process_class=DistributedProcess, *args, **kwargs)
2967
+ else:
2968
+ raise ValueError(f"Unsupported distributed type: {distribution_type}")
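A minimal usage sketch of the factory above (editor's example; argument values are illustrative, and it assumes setup() accepts no required arguments, as the Ray and Monarch variants do):

# Editor's sketch: create a PyTorch SPMD supervisor, run setup, and clean up.
supervisor = distributed_supervisor_factory("pytorch", quorum_timeout=60)
supervisor.setup()  # spins up the local process pool and remote worker pool
assert supervisor.intercept_call()
# ... the HTTP server would route incoming requests through supervisor.call_distributed(...)
supervisor.cleanup()  # stops DNS monitoring and both pools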