kubetorch 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kubetorch has been flagged as possibly problematic.

Files changed (93)
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/distributed_utils.py (new file)
@@ -0,0 +1,3223 @@
1
+ import copy
2
+ import multiprocessing
3
+ import os
4
+ import queue
5
+ import subprocess
6
+ import threading
7
+ import time
8
+ import uuid
9
+ from bdb import BdbQuit
10
+ from concurrent.futures import as_completed, ThreadPoolExecutor
11
+ from typing import Dict, Optional
12
+
13
+ import httpx
14
+ from starlette.responses import JSONResponse
15
+
16
+ from kubetorch.servers.http.http_server import (
17
+ load_callable,
18
+ logger,
19
+ package_exception,
20
+ patch_sys_path,
21
+ request_id_ctx_var,
22
+ run_callable_internal_sync,
23
+ )
24
+
25
+ from .utils import clear_debugging_sessions, is_running_in_kubernetes
26
+
27
+ # Try to import Monarch components at module level if available
28
+ # This helps avoid threading issues with Monarch's Rust bindings
29
+ try:
30
+ from monarch._src.actor.allocator import (
31
+ RemoteAllocator,
32
+ StaticRemoteAllocInitializer,
33
+ )
34
+
35
+ MONARCH_AVAILABLE = True
36
+ except ImportError:
37
+ MONARCH_AVAILABLE = False
38
+ RemoteAllocator = None
39
+ StaticRemoteAllocInitializer = None
40
+ except Exception:
41
+ # Catch any other exceptions during import
42
+ MONARCH_AVAILABLE = False
43
+ RemoteAllocator = None
44
+ StaticRemoteAllocInitializer = None
45
+
46
+
47
+ class RemoteWorkerPool:
48
+ """Manages async HTTP calls to remote workers in a separate process."""
49
+
50
+ def __init__(self, quorum_timeout=300):
51
+ self.quorum_timeout = quorum_timeout
52
+ self.request_queue = multiprocessing.Queue()
53
+ self.response_queue = multiprocessing.Queue()
54
+ self.process = None
55
+ self._running = False
56
+ # Thread-safe response routing for concurrent calls
57
+ self._response_events = {}
58
+ self._response_lock = threading.Lock()
59
+ self._router_thread = None
60
+
61
+ def start(self, max_workers=2000):
62
+ """Start the worker process and router thread."""
63
+ if self.process:
64
+ raise RuntimeError("WorkerPool already started")
65
+
66
+ self._running = True
67
+ # Pass necessary data as arguments to avoid pickling issues
68
+ self.process = multiprocessing.Process(
69
+ target=self._run_async_worker,
70
+ args=(
71
+ self.request_queue,
72
+ self.response_queue,
73
+ self.quorum_timeout,
74
+ max_workers,
75
+ ),
76
+ daemon=True,
77
+ )
78
+ self.process.start()
79
+
80
+ # Start router thread for handling responses
81
+ self._router_thread = threading.Thread(
82
+ target=self._response_router, daemon=True
83
+ )
84
+ self._router_thread.start()
85
+
86
+ logger.debug("Started RemoteWorkerPool process and router thread")
87
+
88
+ def stop(self):
89
+ """Stop the worker process and router thread."""
90
+ self._running = False
91
+
92
+ # Stop router thread
93
+ if self._router_thread:
94
+ self.response_queue.put({"request_id": "STOP_ROUTER"})
95
+ self._router_thread.join(timeout=1)
96
+
97
+ # Stop worker process
98
+ if self.process:
99
+ self.request_queue.put(("SHUTDOWN", None))
100
+ self.process.join(timeout=5)
101
+ if self.process and self.process.is_alive():
102
+ self.process.terminate()
103
+ self.process.join(timeout=1)
104
+ if self.process and self.process.is_alive():
105
+ self.process.kill()
106
+ self.process = None
107
+
108
+ # Clear response events
109
+ with self._response_lock:
110
+ self._response_events.clear()
111
+
112
+ logger.debug("Stopped RemoteWorkerPool process and router thread")
113
+
114
+ def call_workers(
115
+ self,
116
+ worker_ips,
117
+ cls_or_fn_name,
118
+ method_name,
119
+ params,
120
+ request_headers,
121
+ workers_arg="all",
122
+ ):
123
+ """Call remote workers and return responses."""
124
+ if not self.process or not self.process.is_alive():
125
+ raise RuntimeError("RemoteWorkerPool not running")
126
+
127
+ # Generate unique request ID
128
+ request_id = str(uuid.uuid4())
129
+
130
+ # Submit request
131
+ request_data = {
132
+ "request_id": request_id,
133
+ "worker_ips": worker_ips,
134
+ "cls_or_fn_name": cls_or_fn_name,
135
+ "method_name": method_name,
136
+ "params": params,
137
+ "request_headers": request_headers,
138
+ "workers_arg": workers_arg,
139
+ }
140
+
141
+ # Register event for this request
142
+ event = threading.Event()
143
+ with self._response_lock:
144
+ self._response_events[request_id] = (event, None)
145
+
146
+ logger.debug(
147
+ f"RemoteWorkerPool: Submitting request {request_id} to queue for {len(worker_ips)} workers"
148
+ )
149
+ self.request_queue.put(("CALL", request_data))
150
+
151
+ # Wait for response with a timeout to prevent hanging forever
152
+ logger.debug(
153
+ f"RemoteWorkerPool: Waiting for response for request {request_id} from {len(worker_ips)} workers"
154
+ )
155
+ # Wait indefinitely for the response (no timeout for long-running jobs)
156
+ event.wait()
157
+
158
+ # Get and cleanup response
159
+ with self._response_lock:
160
+ _, result = self._response_events.pop(request_id)
161
+
162
+ logger.debug(
163
+ f"RemoteWorkerPool: Got response for request {request_id}, type: {type(result).__name__}"
164
+ )
165
+
166
+ if isinstance(result, Exception):
167
+ raise result
168
+ return result
169
+
170
+ def _response_router(self):
171
+ """Router thread that distributes responses to waiting threads."""
172
+ logger.debug("RemoteWorkerPool response router thread started")
173
+ while self._running:
174
+ try:
175
+ response = self.response_queue.get(timeout=0.1)
176
+ if response.get("request_id") == "STOP_ROUTER":
177
+ break
178
+
179
+ request_id = response.get("request_id")
180
+ logger.debug(
181
+ f"Response router received response for request {request_id}"
182
+ )
183
+ with self._response_lock:
184
+ if request_id in self._response_events:
185
+ event, _ = self._response_events[request_id]
186
+ # Store the result (either results list or exception)
187
+ if "error" in response:
188
+ logger.debug(
189
+ f"Response router: Setting error for request {request_id}"
190
+ )
191
+ self._response_events[request_id] = (
192
+ event,
193
+ response["error"],
194
+ )
195
+ else:
196
+ logger.debug(
197
+ f"Response router: Setting {len(response.get('results', []))} results for request {request_id}"
198
+ )
199
+ self._response_events[request_id] = (
200
+ event,
201
+ response["results"],
202
+ )
203
+ event.set()
204
+ logger.debug(
205
+ f"Response router: Event set for request {request_id}"
206
+ )
207
+ else:
208
+ logger.warning(
209
+ f"Response router: No event found for request {request_id}, registered events: {list(self._response_events.keys())}"
210
+ )
211
+ except queue.Empty:
212
+ continue # Queue timeout, continue checking
213
+ except Exception as e:
214
+ logger.error(f"Error in response router: {e}")
215
+ continue
216
+
217
+ @staticmethod
218
+ def _run_async_worker(
219
+ request_queue, response_queue, quorum_timeout, max_workers=2000
220
+ ):
221
+ """Worker process that handles async HTTP calls.
222
+
223
+ Architecture:
224
+ - Runs in a separate process with its own event loop
225
+ - Maintains a single shared httpx.AsyncClient for connection pooling
226
+ - Processes requests from queue concurrently (multiple requests in flight)
227
+ - Each request involves parallel async HTTP calls to multiple workers
228
+
229
+ Nested functions (share queue/timeout state):
230
+ - wait_for_worker_health: Health check with retries
231
+ - call_single_worker: Make HTTP call to one worker
232
+ - call_workers_async: Orchestrate health checks + calls for all workers
233
+ - process_requests: Main loop pulling from queue and dispatching tasks
234
+ """
235
+ import asyncio
236
+ import queue
237
+
238
+ import httpx
239
+
240
+ async def wait_for_worker_health(
241
+ client, worker_ip, workers_arg, quorum_timeout
242
+ ):
243
+ """Wait for a worker to become healthy within timeout."""
244
+ port = os.environ["KT_SERVER_PORT"]
245
+ worker_url = f"http://{worker_ip}:{port}"
246
+
247
+ start_time = asyncio.get_event_loop().time()
248
+
249
+ # Keep trying until timeout
250
+ while (asyncio.get_event_loop().time() - start_time) < quorum_timeout:
251
+ try:
252
+ resp = await client.get(f"{worker_url}/health", timeout=10.0)
253
+ if resp.status_code == 200:
254
+ return (worker_ip, True) # Return IP and success
255
+ except (httpx.RequestError, httpx.TimeoutException):
256
+ pass
257
+
258
+ # No waiting for quorum if user just wants to call ready workers
259
+ if workers_arg == "ready":
260
+ return worker_ip, False
261
+
262
+ # Wait before retry (1 second, same as original)
263
+ await asyncio.sleep(1.0)
264
+
265
+ # Timeout reached
266
+ return worker_ip, False
267
+
268
+ async def call_single_worker(
269
+ client,
270
+ worker_ip,
271
+ cls_or_fn_name,
272
+ method_name,
273
+ params,
274
+ request_headers,
275
+ workers_arg,
276
+ ):
277
+ """Call a single worker (assumes health already checked)."""
278
+ port = os.environ["KT_SERVER_PORT"]
279
+ worker_url = f"http://{worker_ip}:{port}"
280
+
281
+ # Make the actual call
282
+ call_url = f"{worker_url}/{cls_or_fn_name}"
283
+ if method_name:
284
+ call_url += f"/{method_name}"
285
+ call_url += "?distributed_subcall=true"
286
+
287
+ logger.debug(f"Async worker: Making POST to {call_url}")
288
+
289
+ # Retry logic for transient failures
290
+ max_retries = 3
291
+ for attempt in range(max_retries):
292
+ try:
293
+ resp = await client.post(
294
+ call_url, json=params, headers=request_headers
295
+ )
296
+ result = resp.json()
297
+ break # Success, exit retry loop
298
+ except httpx.ReadError as e:
299
+ # Check if this is due to server shutdown (connection reset)
300
+ if "Connection reset" in str(e) or "Connection closed" in str(e):
301
+ logger.warning(
302
+ f"Worker {worker_ip} appears to be shutting down: {e}"
303
+ )
304
+ raise # Don't retry on shutdown
305
+ if attempt < max_retries - 1:
306
+ wait_time = (attempt + 1) * 2 # Linear backoff: 2s, then 4s
307
+ logger.warning(
308
+ f"ReadError calling {worker_ip} (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
309
+ )
310
+ await asyncio.sleep(wait_time)
311
+ else:
312
+ logger.error(
313
+ f"ReadError calling {worker_ip} after {max_retries} attempts: {e}. Worker may be crashed/overloaded."
314
+ )
315
+ raise
316
+ except httpx.TimeoutException as e:
317
+ if attempt < max_retries - 1:
318
+ logger.warning(
319
+ f"Timeout calling {worker_ip} (attempt {attempt + 1}/{max_retries}): {e}. Retrying..."
320
+ )
321
+ await asyncio.sleep(1)
322
+ else:
323
+ logger.error(
324
+ f"Timeout calling {worker_ip} after {max_retries} attempts: {e}"
325
+ )
326
+ raise
327
+ except Exception as e:
328
+ logger.error(f"Unexpected error calling {worker_ip}: {e}")
329
+ raise
330
+ logger.debug(
331
+ f"Async worker: Got response from {worker_ip}, type: {type(result).__name__}, "
332
+ f"length: {len(result) if hasattr(result, '__len__') else 'N/A'}"
333
+ )
334
+ # In tree topology, intermediate nodes return aggregated results from their subtree
335
+ # We should preserve the flat list structure
336
+ return result
337
+
338
+ async def call_workers_async(client, data):
339
+ """Make async calls to all workers using shared client."""
340
+ worker_ips = data["worker_ips"]
341
+ cls_or_fn_name = data["cls_or_fn_name"]
342
+ method_name = data["method_name"]
343
+ params = data["params"]
344
+ request_headers = data["request_headers"]
345
+ workers_arg = data["workers_arg"]
346
+
347
+ # With tree topology limiting fanout, we don't need to batch health checks
348
+ # Each node only calls its direct children in the tree
349
+ logger.info(
350
+ f"Waiting for {len(worker_ips)} workers to become ready (timeout={quorum_timeout}s)"
351
+ )
352
+ health_tasks = [
353
+ wait_for_worker_health(client, ip, workers_arg, quorum_timeout)
354
+ for ip in worker_ips
355
+ ]
356
+ health_results = await asyncio.gather(*health_tasks)
357
+
358
+ # Process results
359
+ healthy_workers = []
360
+ unhealthy_workers = []
361
+ for worker_ip, is_healthy in health_results:
362
+ if is_healthy:
363
+ healthy_workers.append(worker_ip)
364
+ else:
365
+ unhealthy_workers.append(worker_ip)
366
+
367
+ if unhealthy_workers:
368
+ if workers_arg == "ready":
369
+ # For "ready" mode, just skip unhealthy workers
370
+ logger.info(
371
+ f"Skipping {len(unhealthy_workers)} workers that didn't respond (ready mode)"
372
+ )
373
+ else:
374
+ # For normal mode, fail if any worker didn't become ready
375
+ logger.error(
376
+ f"{len(unhealthy_workers)} workers failed to become ready after {quorum_timeout}s: {unhealthy_workers[:5]}..."
377
+ )
378
+ raise TimeoutError(
379
+ f"{len(unhealthy_workers)} of {len(worker_ips)} workers did not become ready within {quorum_timeout} seconds. "
380
+ f"This may indicate the pods are still starting or there's a resource constraint. "
381
+ f"Consider increasing quorum_timeout in .distribute() call."
382
+ )
383
+
384
+ logger.info(
385
+ f"All {len(healthy_workers)} workers are ready, making distributed calls"
386
+ )
387
+
388
+ # Now make the actual calls to ready workers
389
+ # Create tasks (not just coroutines)
390
+ tasks = []
391
+ for worker_ip in healthy_workers:
392
+ coro = call_single_worker(
393
+ client,
394
+ worker_ip,
395
+ cls_or_fn_name,
396
+ method_name,
397
+ params,
398
+ request_headers,
399
+ workers_arg,
400
+ )
401
+ # Create actual task from coroutine
402
+ task = asyncio.create_task(coro)
403
+ tasks.append(task)
404
+
405
+ # Use as_completed for fast failure propagation
406
+ responses = []
407
+ pending_tasks = set(tasks)
408
+
409
+ for future in asyncio.as_completed(tasks):
410
+ try:
411
+ result = await future
412
+ if result is not None: # Skip None results
413
+ responses.append(result)
414
+ # Remove completed task from pending set
415
+ pending_tasks.discard(future)
416
+ except Exception as e:
417
+ # Fast failure - immediately propagate the exception
418
+ # Cancel remaining tasks to avoid unnecessary work
419
+ for task in pending_tasks:
420
+ if not task.done():
421
+ task.cancel()
422
+ raise e
423
+
424
+ return responses
425
+
426
+ # Create and run event loop with shared AsyncClient
427
+ async def main():
428
+ # Set up signal handler for graceful shutdown
429
+ import signal
430
+
431
+ shutdown_event = asyncio.Event()
432
+
433
+ # Save existing handlers to chain to them
434
+ original_sigterm_handler = signal.getsignal(signal.SIGTERM)
435
+ original_sigint_handler = signal.getsignal(signal.SIGINT)
436
+
437
+ def signal_handler(signum, frame):
438
+ logger.info(f"Received signal {signum}, initiating graceful shutdown")
439
+ shutdown_event.set()
440
+
441
+ # Chain to original handler if it exists
442
+ if signum == signal.SIGTERM:
443
+ if original_sigterm_handler and original_sigterm_handler not in (
444
+ signal.SIG_DFL,
445
+ signal.SIG_IGN,
446
+ ):
447
+ original_sigterm_handler(signum, frame)
448
+ elif signum == signal.SIGINT:
449
+ if original_sigint_handler and original_sigint_handler not in (
450
+ signal.SIG_DFL,
451
+ signal.SIG_IGN,
452
+ ):
453
+ original_sigint_handler(signum, frame)
454
+
455
+ signal.signal(signal.SIGTERM, signal_handler)
456
+ signal.signal(signal.SIGINT, signal_handler)
457
+
458
+ # Create a single AsyncClient to be shared across all requests
459
+ # Set limits based on max expected workers (passed from parent)
460
+ # We need exactly max_workers connections since health checks and work calls
461
+ # should reuse the same connection per worker (HTTP keep-alive)
462
+ # Add small buffer for edge cases
463
+ buffer = max(100, int(max_workers * 0.1)) # 10% buffer, min 100
464
+ max_conn = max_workers + buffer
465
+
466
+ limits = httpx.Limits(
467
+ max_keepalive_connections=max_conn, # Keep all connections alive
468
+ max_connections=max_conn, # One connection per worker + buffer
469
+ keepalive_expiry=300.0, # Keep connections alive for 5 minutes
470
+ )
471
+ timeout = httpx.Timeout(
472
+ connect=10.0,
473
+ read=None, # No read timeout for long-running jobs
474
+ write=10.0,
475
+ pool=60.0, # Time to wait for a connection from pool
476
+ )
477
+
478
+ logger.debug(
479
+ f"AsyncClient configured with max_connections={max_conn} "
480
+ f"(workers={max_workers} + buffer={buffer}) with 5min keepalive"
481
+ )
482
+
483
+ # Note: http2=True would enable HTTP/2 if server supports it
484
+ async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
485
+ logger.debug("Async worker started with shared httpx.AsyncClient")
486
+
487
+ active_tasks = set()
488
+
489
+ async def process_call(request_data):
490
+ """Process a CALL request asynchronously."""
491
+ request_id = request_data["request_id"]
492
+ worker_ips = request_data["worker_ips"]
493
+ logger.debug(
494
+ f"Async worker: Processing request {request_id} with {len(worker_ips)} workers: {worker_ips}"
495
+ )
496
+ try:
497
+ results = await call_workers_async(client, request_data)
498
+ logger.debug(
499
+ f"Async worker: Successfully got {len(results)} results for request {request_id}, sending response"
500
+ )
501
+ response_queue.put(
502
+ {"request_id": request_id, "results": results}
503
+ )
504
+ except Exception as e:
505
+ logger.error(f"Error processing request {request_id}: {e}")
506
+ response_queue.put({"request_id": request_id, "error": e})
507
+
508
+ # Main request processing loop
509
+ while not shutdown_event.is_set():
510
+ try:
511
+ # Check queue without blocking event loop
512
+ cmd, request_data = await asyncio.to_thread(
513
+ request_queue.get, block=True, timeout=0.1
514
+ )
515
+
516
+ if cmd == "SHUTDOWN" or shutdown_event.is_set():
517
+ # Cancel all active tasks immediately for quick cleanup
518
+ if active_tasks:
519
+ logger.debug(
520
+ f"Cancelling {len(active_tasks)} active tasks for shutdown"
521
+ )
522
+ for task in active_tasks:
523
+ task.cancel()
524
+ break
525
+
526
+ elif cmd == "CALL":
527
+ # Create a task to handle this request concurrently
528
+ task = asyncio.create_task(process_call(request_data))
529
+ active_tasks.add(task)
530
+
531
+ # Clean up completed tasks periodically
532
+ active_tasks = {t for t in active_tasks if not t.done()}
533
+
534
+ except queue.Empty:
535
+ # Clean up completed tasks while waiting
536
+ active_tasks = {t for t in active_tasks if not t.done()}
537
+ except Exception as e:
538
+ logger.error(f"Error in async worker loop: {e}")
539
+
540
+ loop = asyncio.new_event_loop()
541
+ asyncio.set_event_loop(loop)
542
+ try:
543
+ loop.run_until_complete(main())
544
+ finally:
545
+ loop.close()
546
+
547
+
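The worker process above keeps one shared httpx.AsyncClient and fans calls out to all workers concurrently. Below is a minimal, self-contained sketch of that pattern (not part of the package; the worker URLs and endpoint paths are illustrative): health-check every worker in parallel with asyncio.gather, then reuse the same client to post the payload to the workers that responded.

    import asyncio
    import httpx

    WORKER_URLS = ["http://10.0.0.1:8080", "http://10.0.0.2:8080"]  # hypothetical workers

    async def check_health(client: httpx.AsyncClient, url: str) -> bool:
        # One lightweight GET per worker; any transport failure counts as "not ready".
        try:
            resp = await client.get(f"{url}/health", timeout=10.0)
            return resp.status_code == 200
        except (httpx.RequestError, httpx.TimeoutException):
            return False

    async def fan_out(payload: dict) -> list:
        limits = httpx.Limits(max_connections=100, max_keepalive_connections=100)
        async with httpx.AsyncClient(limits=limits, timeout=None) as client:
            health = await asyncio.gather(*(check_health(client, u) for u in WORKER_URLS))
            healthy = [url for url, ok in zip(WORKER_URLS, health) if ok]
            # Reuse the same client (and its keep-alive connections) for the real calls.
            responses = await asyncio.gather(
                *(client.post(f"{url}/run", json=payload) for url in healthy)
            )
            return [r.json() for r in responses]

    # asyncio.run(fan_out({"x": 1}))
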
548
+ class DistributedProcessPool:
549
+ """Unified pool managing distributed processes with single router thread."""
550
+
551
+ def __init__(
552
+ self, process_class, num_processes, max_threads_per_proc=10, **process_kwargs
553
+ ):
554
+ self.process_class = process_class
555
+ self.num_processes = num_processes
556
+ self.max_threads_per_proc = max_threads_per_proc
557
+ self.process_kwargs = (
558
+ process_kwargs # Additional kwargs to pass to process constructor
559
+ )
560
+
561
+ # Processes and queues
562
+ self.processes = []
563
+ self.request_queues = [] # One request queue per process
564
+ self.response_queue = multiprocessing.Queue() # Single shared response queue
565
+
566
+ # Response routing with single router thread
567
+ self._router_thread = None
568
+ self._response_events = {} # Maps request_id to (threading.Event, response)
569
+ self._response_lock = threading.Lock()
570
+ self._running = False
571
+
572
+ def start(self):
573
+ """Start all processes in the pool and the single router thread."""
574
+ if self.processes:
575
+ raise RuntimeError("Pool already started")
576
+
577
+ # Create and start all processes in parallel
578
+ with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
579
+ futures = []
580
+ for i in range(self.num_processes):
581
+ future = executor.submit(self._create_and_start_process, i)
582
+ futures.append(future)
583
+
584
+ # Wait for all processes to start
585
+ for future in futures:
586
+ future.result()
587
+
588
+ # Start single response router thread for entire pool
589
+ self._running = True
590
+ self._router_thread = threading.Thread(
591
+ target=self._response_router, daemon=True, name="PoolResponseRouter"
592
+ )
593
+ self._router_thread.start()
594
+ logger.debug(
595
+ f"Started {self.num_processes} processes with single router thread"
596
+ )
597
+
598
+ def _create_and_start_process(self, local_rank):
599
+ """Helper to create and start a single process."""
600
+ request_queue = multiprocessing.Queue()
601
+ self.request_queues.append(request_queue)
602
+
603
+ process = self.process_class(
604
+ local_rank=local_rank,
605
+ request_queue=request_queue,
606
+ response_queue=self.response_queue, # Shared response queue
607
+ max_threads=self.max_threads_per_proc,
608
+ **self.process_kwargs, # Pass additional framework-specific settings
609
+ )
610
+ process.start()
611
+ self.processes.append(process)
612
+
613
+ def stop(self):
614
+ """Stop all processes and the router thread."""
615
+ self._running = False
616
+
617
+ # Send shutdown signal to all processes (use put_nowait to avoid blocking)
618
+ for q in self.request_queues:
619
+ try:
620
+ q.put_nowait("SHUTDOWN")
621
+ except Exception:
622
+ pass
623
+
624
+ # Stop router thread
625
+ try:
626
+ self.response_queue.put_nowait("STOP_ROUTER")
627
+ except Exception:
628
+ pass
629
+
630
+ # Wait briefly for router to stop
631
+ if self._router_thread:
632
+ self._router_thread.join(timeout=0.5)
633
+
634
+ # Terminate all processes immediately without waiting
635
+ for process in self.processes:
636
+ if process.is_alive():
637
+ process.terminate()
638
+
639
+ # Give processes a brief chance to terminate gracefully (reduced timeout)
640
+ for process in self.processes:
641
+ process.join(timeout=0.5)
642
+
643
+ # Force kill any remaining processes
644
+ for process in self.processes:
645
+ if process.is_alive():
646
+ logger.warning(f"Force killing process {process.pid}")
647
+ process.kill()
648
+ process.join(timeout=0.1) # Brief wait to confirm kill
649
+
650
+ # Clear all queues
651
+ self._clear_queues()
652
+
653
+ # Reset state
654
+ self.processes.clear()
655
+ self.request_queues.clear()
656
+ with self._response_lock:
657
+ self._response_events.clear()
658
+
659
+ def call(
660
+ self,
661
+ idx,
662
+ method_name,
663
+ params,
664
+ deployed_as_of,
665
+ request_id,
666
+ distributed_env_vars,
667
+ debug_port,
668
+ serialization,
669
+ ):
670
+ """Call a specific process by index."""
671
+ if idx >= len(self.processes):
672
+ raise ValueError(
673
+ f"Process index {idx} out of range (have {len(self.processes)} processes)"
674
+ )
675
+
676
+ request_unique_id = str(uuid.uuid4())
677
+
678
+ # Register this request for response routing
679
+ event = threading.Event()
680
+ with self._response_lock:
681
+ self._response_events[request_unique_id] = (event, None)
682
+
683
+ try:
684
+ # Send request to specific process
685
+ self.request_queues[idx].put(
686
+ {
687
+ "request_unique_id": request_unique_id,
688
+ "method_name": method_name,
689
+ "params": params,
690
+ "deployed_as_of": deployed_as_of,
691
+ "request_id": request_id,
692
+ "distributed_env_vars": distributed_env_vars,
693
+ "debug_port": debug_port,
694
+ "serialization": serialization,
695
+ "process_idx": idx, # Include process index for debugging
696
+ }
697
+ )
698
+
699
+ # Wait for response
700
+ event.wait()
701
+
702
+ # Get and return the response
703
+ with self._response_lock:
704
+ _, result = self._response_events.pop(request_unique_id)
705
+
706
+ return result
707
+
708
+ except Exception:
709
+ # Clean up on error
710
+ with self._response_lock:
711
+ self._response_events.pop(request_unique_id, None)
712
+ raise
713
+
714
+ def call_all(
715
+ self,
716
+ method_name,
717
+ params_list,
718
+ deployed_as_of,
719
+ request_id,
720
+ distributed_env_vars_list,
721
+ debug_ports,
722
+ serialization,
723
+ ):
724
+ """Call all processes in parallel and return results."""
725
+ if len(params_list) != self.num_processes:
726
+ raise ValueError(
727
+ f"Expected {self.num_processes} param sets, got {len(params_list)}"
728
+ )
729
+
730
+ with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
731
+ futures = []
732
+ for idx in range(self.num_processes):
733
+ future = executor.submit(
734
+ self.call,
735
+ idx=idx,
736
+ method_name=method_name,
737
+ params=params_list[idx],
738
+ deployed_as_of=deployed_as_of,
739
+ request_id=request_id,
740
+ distributed_env_vars=distributed_env_vars_list[idx],
741
+ debug_port=debug_ports[idx] if debug_ports else None,
742
+ serialization=serialization,
743
+ )
744
+ futures.append(future)
745
+
746
+ results = []
747
+ for future in futures:
748
+ results.append(future.result())
749
+
750
+ return results
751
+
752
+ def _response_router(self):
753
+ """Single router thread handling responses from all processes."""
754
+ while self._running:
755
+ try:
756
+ response = self.response_queue.get(timeout=1)
757
+ if response == "STOP_ROUTER":
758
+ break
759
+
760
+ request_id = response.get("request_unique_id")
761
+ with self._response_lock:
762
+ if request_id in self._response_events:
763
+ event, _ = self._response_events[request_id]
764
+ self._response_events[request_id] = (event, response["result"])
765
+ event.set()
766
+ else:
767
+ logger.warning(
768
+ f"Received response for unknown request: {request_id}"
769
+ )
770
+
771
+ except Exception as e:
772
+ if "Empty" not in str(e.__class__.__name__):
773
+ logger.debug(f"Response router error: {e}")
774
+ continue
775
+
776
+ def _clear_queues(self):
777
+ """Clear all pending items from queues."""
778
+ for q in self.request_queues:
779
+ try:
780
+ while not q.empty():
781
+ q.get_nowait()
782
+ except Exception:
783
+ pass
784
+
785
+ try:
786
+ while not self.response_queue.empty():
787
+ self.response_queue.get_nowait()
788
+ except Exception:
789
+ pass
790
+
791
+ def __len__(self):
792
+ return len(self.processes)
793
+
794
+ def __getitem__(self, idx):
795
+ return self.processes[idx]
796
+
797
+
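Both RemoteWorkerPool and DistributedProcessPool route results with a single router thread: the submitting thread registers a threading.Event under a unique request id before enqueueing work, and the router wakes it when the matching result arrives on the shared response queue. A stripped-down sketch of just that routing pattern, assuming nothing about the kubetorch API:

    import queue
    import threading

    class ResponseRouter:
        """One router thread dispatches results from a shared queue to waiting callers."""

        def __init__(self):
            self.responses = queue.Queue()  # workers put {"request_id": ..., "result": ...} here
            self._events = {}               # request_id -> (threading.Event, result)
            self._lock = threading.Lock()
            threading.Thread(target=self._route, daemon=True).start()

        def _route(self):
            while True:
                msg = self.responses.get()
                with self._lock:
                    entry = self._events.get(msg["request_id"])
                    if entry is None:
                        continue  # unknown or abandoned request; drop it
                    event, _ = entry
                    self._events[msg["request_id"]] = (event, msg["result"])
                    event.set()

        def register(self, request_id):
            # Register *before* enqueueing the work so the response can never race past us.
            event = threading.Event()
            with self._lock:
                self._events[request_id] = (event, None)
            return event

        def collect(self, request_id, event):
            event.wait()
            with self._lock:
                _, result = self._events.pop(request_id)
            return result

A caller generates a uuid, calls register(), enqueues the work tagged with that id, then calls collect(); that register-then-enqueue ordering mirrors the classes above.
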
798
+ class DistributedSupervisor:
799
+ def __init__(self, quorum_workers=None, quorum_timeout=300, monitor_members=True):
800
+ """
801
+ Base class for distributed supervisors. This class should be subclassed for specific distributed
802
+ environments like PyTorch or Ray.
803
+
804
+ Args:
805
+ quorum_workers: Optional number of workers expected to appear before proceeding.
+ quorum_timeout: Seconds to wait for the expected workers (default: 300).
+ monitor_members: Whether to monitor DNS for worker membership changes (default: True).
806
+ """
807
+ # Set after creation by the factory function
808
+ self.quorum_workers = quorum_workers
809
+ self.quorum_timeout = quorum_timeout
810
+ self.monitor_members = monitor_members
811
+
812
+ self.config_hash = None
813
+
814
+ # DNS monitoring state
815
+ self._dns_monitor_thread = None
816
+ self._dns_monitor_running = False
817
+ self._current_workers = set()
818
+ self._workers_lock = threading.Lock()
819
+ self._membership_changes = queue.Queue()
820
+ self._change_subscribers = []
821
+ self._last_dns_check = 0
822
+ self._dns_check_interval = 5 # seconds
823
+
824
+ def pod_ips(self):
825
+ """Get pod IPs from DNS, waiting for quorum if specified.
826
+
827
+ Will wait up to quorum_timeout seconds for quorum_workers to appear in DNS.
828
+ If quorum_workers is not specified, returns immediately after first DNS query.
829
+ """
830
+ # Primarily for testing
831
+ if not is_running_in_kubernetes():
832
+ return os.environ["LOCAL_IPS"].split(",")
833
+
834
+ # Use DNS-based service discovery instead of Kubernetes API
835
+ # Check if pre-computed DNS name is available (should point to headless service for distributed)
836
+ service_dns = os.environ.get("KT_SERVICE_DNS")
837
+
838
+ if not service_dns:
839
+ # Fall back to computing DNS name from service and namespace
840
+ service_name = os.environ.get("KT_SERVICE_NAME")
841
+ namespace = os.environ.get("POD_NAMESPACE")
842
+
843
+ if not service_name:
844
+ raise RuntimeError("KT_SERVICE environment variable not found")
845
+ if not namespace:
846
+ raise RuntimeError("POD_NAMESPACE environment variable not found")
847
+
848
+ # Kubernetes headless service DNS name for distributed pod discovery
849
+ # Format: <service-name>-headless.<namespace>.svc.cluster.local
850
+ service_dns = f"{service_name}-headless.{namespace}.svc.cluster.local"
851
+
852
+ import socket
853
+ import time
854
+
855
+ start_time = time.time()
856
+ max_wait = self.quorum_timeout if self.quorum_timeout else 0
857
+ expected_workers = self.quorum_workers
858
+
859
+ pod_ips = []
860
+ last_count = 0
861
+
862
+ while True:
863
+ try:
864
+ # DNS lookup returns all pod IPs for the headless service
865
+ # getaddrinfo returns list of (family, type, proto, canonname, sockaddr)
866
+ addr_info = socket.getaddrinfo(service_dns, None, socket.AF_INET)
867
+
868
+ # Extract unique IP addresses from the results
869
+ pod_ips = sorted(list(set([addr[4][0] for addr in addr_info])))
870
+
871
+ if not pod_ips:
872
+ logger.debug(f"No pod IPs found for service {service_dns}")
873
+ else:
874
+ logger.debug(
875
+ f"Found {len(pod_ips)} pod IPs via DNS for {service_dns}: {pod_ips}"
876
+ )
877
+
878
+ except socket.gaierror as e:
879
+ logger.debug(f"DNS lookup failed for {service_dns}: {e}")
880
+ pod_ips = []
881
+
882
+ # Check if we should wait for more workers
883
+ elapsed = time.time() - start_time
884
+
885
+ # If we have the expected count, we're done
886
+ if expected_workers and len(pod_ips) >= expected_workers:
887
+ logger.info(
888
+ f"Found {len(pod_ips)}/{expected_workers} workers after {elapsed:.1f}s"
889
+ )
890
+ return pod_ips
891
+
892
+ # If we don't have expected count or timeout is reached, decide what to do
893
+ if elapsed >= max_wait:
894
+ if expected_workers:
895
+ logger.warning(
896
+ f"Only found {len(pod_ips)}/{expected_workers} workers after {elapsed:.1f}s timeout"
897
+ )
898
+ else:
899
+ logger.info(f"Found {len(pod_ips)} workers after {elapsed:.1f}s")
900
+ return pod_ips
901
+
902
+ # Log progress if count changed
903
+ if len(pod_ips) != last_count:
904
+ if expected_workers:
905
+ logger.info(
906
+ f"{len(pod_ips)}/{expected_workers} workers found, waiting for quorum..."
907
+ )
908
+ else:
909
+ logger.debug(f"{len(pod_ips)} workers found, no quorum set")
910
+ last_count = len(pod_ips)
911
+
912
+ # Wait before retrying
913
+ time.sleep(2)
914
+
915
+ def _get_pod_ips_fast(self):
916
+ """Get pod IPs from DNS without waiting for quorum - for monitoring only."""
917
+ # Primarily for testing
918
+ if not is_running_in_kubernetes():
919
+ return os.environ["LOCAL_IPS"].split(",")
920
+
921
+ # Use DNS-based service discovery
922
+ service_dns = os.environ.get("KT_SERVICE_DNS")
923
+
924
+ if not service_dns:
925
+ service_name = os.environ.get("KT_SERVICE")
926
+ namespace = os.environ.get("POD_NAMESPACE")
927
+
928
+ if not service_name or not namespace:
929
+ return []
930
+
931
+ service_dns = f"{service_name}-headless.{namespace}.svc.cluster.local"
932
+
933
+ import socket
934
+
935
+ try:
936
+ # Single DNS lookup, no retries, no waiting
937
+ addr_info = socket.getaddrinfo(service_dns, None, socket.AF_INET)
938
+ # Extract unique IP addresses
939
+ pod_ips = sorted(list(set([addr[4][0] for addr in addr_info])))
940
+ return pod_ips
941
+ except socket.gaierror:
942
+ # DNS lookup failed, return current known workers
943
+ with self._workers_lock:
944
+ return list(self._current_workers)
945
+
946
+ def start_dns_monitoring(self):
947
+ """Start DNS monitoring if not already running.
948
+ Should be called by coordinator nodes only."""
949
+ # Skip if monitoring is disabled (e.g., for Ray)
950
+ if not self.monitor_members:
951
+ logger.debug("DNS monitoring disabled for this supervisor")
952
+ return
953
+
954
+ with self._workers_lock:
955
+ if self._dns_monitor_thread and self._dns_monitor_thread.is_alive():
956
+ return # Already running
957
+
958
+ # Initialize with current workers
959
+ self._current_workers = set(self.pod_ips())
960
+ logger.debug(
961
+ f"Starting DNS monitor with {len(self._current_workers)} workers"
962
+ )
963
+
964
+ self._dns_monitor_running = True
965
+ self._dns_monitor_thread = threading.Thread(
966
+ target=self._monitor_worker_membership, daemon=True, name="DNSMonitor"
967
+ )
968
+ self._dns_monitor_thread.start()
969
+
970
+ def stop_dns_monitoring(self):
971
+ """Stop DNS monitoring thread."""
972
+ self._dns_monitor_running = False
973
+ if self._dns_monitor_thread:
974
+ self._dns_monitor_thread.join(timeout=2)
975
+ self._dns_monitor_thread = None
976
+
977
+ def _monitor_worker_membership(self):
978
+ """Monitor DNS for worker membership changes."""
979
+ check_interval = 3 # Start with 3 second checks (faster initial detection)
980
+
981
+ while self._dns_monitor_running:
982
+ try:
983
+ # Note that we start this after the delay, because we're doing a DNS check at
984
+ # the start of call_distributed anyway. This thread is only for the recurring checks
985
+ # as the call runs.
986
+ time.sleep(check_interval)
987
+
988
+ # Query DNS for current workers - use a faster version
989
+ current_ips = set(self._get_pod_ips_fast())
990
+
991
+ with self._workers_lock:
992
+ if current_ips != self._current_workers:
993
+ added = current_ips - self._current_workers
994
+ removed = self._current_workers - current_ips
995
+
996
+ change = {
997
+ "timestamp": time.time(),
998
+ "added": added,
999
+ "removed": removed,
1000
+ "previous": self._current_workers.copy(),
1001
+ "current": current_ips.copy(),
1002
+ }
1003
+
1004
+ if removed:
1005
+ logger.error(f"Workers REMOVED from cluster: {removed}")
1006
+ if added:
1007
+ logger.warning(f"Workers ADDED to cluster: {added}")
1008
+
1009
+ # Queue change and notify subscribers
1010
+ self._membership_changes.put(change)
1011
+ for event in self._change_subscribers:
1012
+ event.set()
1013
+
1014
+ self._current_workers = current_ips
1015
+
1016
+ time.sleep(check_interval)
1017
+
1018
+ except Exception as e:
1019
+ logger.error(f"DNS monitor error: {e}")
1020
+ time.sleep(3)
1021
+
1022
+ def subscribe_to_membership_changes(self):
1023
+ """Subscribe to worker membership changes.
1024
+ Returns an event that will be set when changes occur."""
1025
+ event = threading.Event()
1026
+ with self._workers_lock:
1027
+ self._change_subscribers.append(event)
1028
+ return event
1029
+
1030
+ def unsubscribe_from_membership_changes(self, event):
1031
+ """Unsubscribe from worker membership changes."""
1032
+ with self._workers_lock:
1033
+ if event in self._change_subscribers:
1034
+ self._change_subscribers.remove(event)
1035
+
1036
+ def check_for_membership_changes(self, force_dns_check=False):
1037
+ """Check for membership changes and raise exception if any occurred.
1038
+
1039
+ Args:
1040
+ force_dns_check: If True, immediately query DNS to check for changes
1041
+ instead of relying on the monitoring thread
1042
+ """
1043
+ # Skip if monitoring is disabled (e.g., for Ray)
1044
+ if not self.monitor_members:
1045
+ return
1046
+ # Force an immediate DNS check if requested
1047
+ if force_dns_check:
1048
+ # Use fast DNS query for immediate check
1049
+ current_ips = set(self._get_pod_ips_fast())
1050
+ with self._workers_lock:
1051
+ if current_ips != self._current_workers:
1052
+ added = current_ips - self._current_workers
1053
+ removed = self._current_workers - current_ips
1054
+
1055
+ # Import here to avoid circular dependency
1056
+ from kubetorch.servers.http.utils import WorkerMembershipChanged
1057
+
1058
+ # Update current workers
1059
+ self._current_workers = current_ips
1060
+
1061
+ # Log the change
1062
+ if removed:
1063
+ logger.error(
1064
+ f"Workers REMOVED from cluster (forced check): {removed}"
1065
+ )
1066
+ if added:
1067
+ logger.warning(
1068
+ f"Workers ADDED to cluster (forced check): {added}"
1069
+ )
1070
+
1071
+ raise WorkerMembershipChanged(
1072
+ added_ips=added,
1073
+ removed_ips=removed,
1074
+ previous_ips=previous_workers,
1075
+ current_ips=current_ips,
1076
+ )
1077
+
1078
+ # Check queued changes from monitoring thread
1079
+ try:
1080
+ change = self._membership_changes.get_nowait()
1081
+
1082
+ # Import here to avoid circular dependency
1083
+ from kubetorch.servers.http.utils import WorkerMembershipChanged
1084
+
1085
+ raise WorkerMembershipChanged(
1086
+ added_ips=change["added"],
1087
+ removed_ips=change["removed"],
1088
+ previous_ips=change["previous"],
1089
+ current_ips=change["current"],
1090
+ )
1091
+ except queue.Empty:
1092
+ pass # No changes
1093
+
1094
+ def setup(self, deployed_as_of: Optional[str] = None):
1095
+ # This method should be overridden by subclasses to set up the distributed environment
1096
+ raise NotImplementedError("setup() must be implemented by subclasses")
1097
+
1098
+ def cleanup(self):
1099
+ """Base cleanup - stop DNS monitoring. Subclasses should call super().cleanup()"""
1100
+ self.stop_dns_monitoring()
1101
+ # Subclasses should override and call super().cleanup() to add their own cleanup
1102
+
1103
+ def intercept_call(self):
1104
+ # This method should be overridden by subclasses to indicate whether to intercept calls
1105
+ raise NotImplementedError("intercept_call() must be implemented by subclasses")
1106
+
1107
+ def call_distributed(
1108
+ self,
1109
+ request,
1110
+ cls_or_fn_name: str,
1111
+ method_name: Optional[str] = None,
1112
+ params: Optional[Dict] = None,
1113
+ distributed_subcall: bool = False,
1114
+ debug_port: int = False,
1115
+ deployed_as_of: Optional[str] = None,
1116
+ ):
1117
+ # if intercept_call is True, this method should be overridden by subclasses to handle distributing and/or
1118
+ # supervising the distributed execution
1119
+ raise NotImplementedError(
1120
+ "call_distributed() must be implemented by subclasses"
1121
+ )
1122
+
1123
+
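The supervisor discovers its peers purely through DNS: a Kubernetes headless service resolves to one A record per ready pod, so a single getaddrinfo call returns the current worker set. A self-contained sketch of that lookup (the service and namespace names are placeholders):

    import socket

    def resolve_headless_service(service_name: str, namespace: str) -> list:
        # Kubernetes headless services expose one A record per ready pod.
        dns_name = f"{service_name}-headless.{namespace}.svc.cluster.local"
        try:
            addr_info = socket.getaddrinfo(dns_name, None, socket.AF_INET)
        except socket.gaierror:
            return []  # service not resolvable yet (e.g. pods still starting)
        # Each entry is (family, type, proto, canonname, sockaddr); sockaddr[0] is the pod IP.
        return sorted({info[4][0] for info in addr_info})

    # resolve_headless_service("my-training-job", "default")  # hypothetical names
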
1124
+ class DistributedProcess(multiprocessing.Process):
1125
+ """Base class for distributed processes that run callables in subprocesses."""
1126
+
1127
+ def __init__(
1128
+ self, local_rank, request_queue, response_queue, max_threads=4, **kwargs
1129
+ ):
1130
+ super().__init__()
1131
+ # We don't need the cache miss / reload here because these processes are destroyed and recreated
1132
+ # with each .to call.
1133
+ os.environ["LOCAL_RANK"] = str(local_rank)
1134
+ self._request_queue = request_queue
1135
+ self._response_queue = response_queue
1136
+ self._max_threads = max_threads
1137
+ self._executor = None
1138
+ # Store any additional framework-specific settings
1139
+ self._settings = kwargs
1140
+
1141
+ def proc_cleanup(self):
1142
+ """Override this method to provide framework-specific cleanup."""
1143
+ logger.info("Cleaning up debugging sessions...")
1144
+ clear_debugging_sessions()
1145
+ logger.info("Debugging sessions cleaned up.")
1146
+
1147
+ # Cleanup thread pool
1148
+ if self._executor:
1149
+ self._executor.shutdown(wait=False)
1150
+ self._executor = None
1151
+
1152
+ @classmethod
1153
+ def get_distributed_env_vars(
1154
+ cls, worker_ips, node_rank, local_rank, num_local_procs, **settings
1155
+ ):
1156
+ """Get framework-specific distributed environment variables.
1157
+
1158
+ Args:
1159
+ worker_ips: List of all worker IPs
1160
+ node_rank: Rank of this node (0-indexed)
1161
+ local_rank: Local rank on this node (0-indexed)
1162
+ num_local_procs: Number of processes on this node
1163
+ **settings: Additional framework-specific settings (e.g., port)
1164
+
1165
+ Returns:
1166
+ Dict of environment variables to set
1167
+ """
1168
+ # Base implementation - env vars common to all frameworks
1169
+ return {
1170
+ "WORLD_SIZE": str(len(worker_ips) * num_local_procs),
1171
+ "RANK": str(node_rank * num_local_procs + local_rank),
1172
+ "LOCAL_RANK": str(local_rank),
1173
+ "NODE_RANK": str(node_rank),
1174
+ "POD_IPS": ",".join(worker_ips),
1175
+ }
1176
+
1177
+ @classmethod
1178
+ def get_auto_num_processes(cls):
1179
+ """Auto-detect the number of processes to use."""
1180
+ return 1
1181
+
1182
+ def handle_request(self, request):
1183
+ """Handle a single request in a thread."""
1184
+ try:
1185
+ request_unique_id = request["request_unique_id"]
1186
+ method_name = request["method_name"]
1187
+ params = request["params"]
1188
+ deployed_as_of = request["deployed_as_of"]
1189
+ request_id = request["request_id"]
1190
+ distributed_env_vars = request["distributed_env_vars"]
1191
+ debug_port = request["debug_port"]
1192
+ serialization = request["serialization"]
1193
+
1194
+ # Set the request ID in the context for this thread
1195
+ token = request_id_ctx_var.set(request_id)
1196
+
1197
+ # Set the environment variables for this thread (note: os.environ is process-wide, might need thread-local storage)
1198
+ # For distributed PyTorch calls, these should already be set at process level
1199
+ for key, value in distributed_env_vars.items():
1200
+ os.environ[key] = value
1201
+
1202
+ try:
1203
+ # Load callable if not already loaded or if deployed_as_of changed
1204
+ callable_obj = load_callable(
1205
+ deployed_as_of=deployed_as_of,
1206
+ distributed_subprocess=True,
1207
+ reload_cleanup_fn=self.proc_cleanup,
1208
+ )
1209
+
1210
+ result = run_callable_internal_sync(
1211
+ callable_obj=callable_obj,
1212
+ cls_or_fn_name=os.environ["KT_CLS_OR_FN_NAME"],
1213
+ method_name=method_name,
1214
+ params=params,
1215
+ serialization=serialization,
1216
+ debug_port=debug_port,
1217
+ )
1218
+
1219
+ # Reset the request ID after the call is complete
1220
+ request_id_ctx_var.reset(token)
1221
+
1222
+ # Send response back with the unique ID
1223
+ self._response_queue.put(
1224
+ {"request_unique_id": request_unique_id, "result": result}
1225
+ )
1226
+
1227
+ except Exception as e:
1228
+ # Reset the request ID even if there was an error
1229
+ request_id_ctx_var.reset(token)
1230
+
1231
+ # Package the exception
1232
+ try:
1233
+ packaged_exception = package_exception(e)
1234
+ except Exception as f:
1235
+ packaged_exception = f
1236
+
1237
+ self._response_queue.put(
1238
+ {
1239
+ "request_unique_id": request_unique_id,
1240
+ "result": packaged_exception,
1241
+ }
1242
+ )
1243
+
1244
+ except Exception as e:
1245
+ # Last resort error handling
1246
+ logger.error(f"Fatal error handling request: {e}")
1247
+ self._response_queue.put(
1248
+ {
1249
+ "request_unique_id": request.get("request_unique_id", "unknown"),
1250
+ "result": Exception(f"Fatal error in thread: {e}"),
1251
+ }
1252
+ )
1253
+
1254
+ def run(self):
1255
+ """Main process loop with thread pool for concurrent request handling."""
1256
+ # Create thread pool for handling requests
1257
+ self._executor = ThreadPoolExecutor(max_workers=self._max_threads)
1258
+
1259
+ try:
1260
+ while True:
1261
+ try:
1262
+ # Block waiting for next request
1263
+ request = self._request_queue.get(timeout=1)
1264
+
1265
+ # Special sentinel value to signal shutdown
1266
+ if request == "SHUTDOWN":
1267
+ break
1268
+
1269
+ # Submit request to thread pool for concurrent handling
1270
+ # Check executor exists in case we're shutting down
1271
+ if self._executor:
1272
+ self._executor.submit(self.handle_request, request)
1273
+ else:
1274
+ logger.warning(
1275
+ "Executor is None, skipping request (likely shutting down)"
1276
+ )
1277
+
1278
+ except Exception as e:
1279
+ # Timeout is normal, continue loop
1280
+ if "Empty" not in str(e.__class__.__name__):
1281
+ logger.error(f"Error getting request from queue: {e}")
1282
+ continue
1283
+
1284
+ except (KeyboardInterrupt, BdbQuit):
1285
+ logger.info("Process interrupted, shutting down...")
1286
+
1287
+ finally:
1288
+ # Cleanup
1289
+ logger.info(
1290
+ "Received shutdown signal, cleaning up distributed environment..."
1291
+ )
1292
+ self.proc_cleanup()
1293
+ logger.info("Exiting gracefully.")
1294
+
1295
+
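get_distributed_env_vars derives the world size and global rank from the node layout. A worked example with hypothetical numbers (4 nodes, 8 local processes per node):

    # 4 nodes x 8 processes per node, looking at the 4th process on the 3rd node.
    worker_ips = [f"10.0.0.{i}" for i in range(1, 5)]
    num_local_procs = 8
    node_rank, local_rank = 2, 3

    world_size = len(worker_ips) * num_local_procs          # 4 * 8 = 32
    global_rank = node_rank * num_local_procs + local_rank  # 2 * 8 + 3 = 19

    env = {
        "WORLD_SIZE": str(world_size),
        "RANK": str(global_rank),
        "LOCAL_RANK": str(local_rank),
        "NODE_RANK": str(node_rank),
        "POD_IPS": ",".join(worker_ips),
    }
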
1296
+ class PyTorchProcess(DistributedProcess):
1297
+ """PyTorch-specific distributed process."""
1298
+
1299
+ def proc_cleanup(self):
1300
+ import torch.distributed as dist
1301
+
1302
+ try:
1303
+ dist.destroy_process_group()
1304
+ logger.info("Destroyed PyTorch process group.")
1305
+ except Exception as e:
1306
+ logger.info(
1307
+ "Failed to destroy PyTorch process group, it may not have been initialized: {e}"
1308
+ )
1309
+ pass
1310
+ # Call parent cleanup for debugging sessions
1311
+ super().proc_cleanup()
1312
+
1313
+ @classmethod
1314
+ def get_distributed_env_vars(
1315
+ cls, worker_ips, node_rank, local_rank, num_local_procs, **settings
1316
+ ):
1317
+ """Get PyTorch-specific distributed environment variables."""
1318
+ port = settings.get("port") or 12345
1319
+ env_vars = super().get_distributed_env_vars(
1320
+ worker_ips, node_rank, local_rank, num_local_procs, **settings
1321
+ )
1322
+ env_vars.update(
1323
+ {
1324
+ "MASTER_ADDR": worker_ips[0],
1325
+ "MASTER_PORT": str(port),
1326
+ }
1327
+ )
1328
+ return env_vars
1329
+
1330
+ @classmethod
1331
+ def get_auto_num_processes(cls):
1332
+ """Auto-detect based on GPU availability for PyTorch."""
1333
+ try:
1334
+ import torch
1335
+
1336
+ if torch.cuda.is_available():
1337
+ return torch.cuda.device_count()
1338
+ except ImportError:
1339
+ pass
1340
+ return 1 # Could use os.cpu_count() for CPU-only training
1341
+
1342
+
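The MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE variables that PyTorchProcess sets are what torch.distributed's env:// initialization reads, so user code running under these processes typically only needs something like the following sketch (illustrative, not part of kubetorch):

    import os

    import torch
    import torch.distributed as dist

    def init_from_env():
        # env:// picks up MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE automatically.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")
        if torch.cuda.is_available():
            torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        return dist.get_rank(), dist.get_world_size()
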
1343
+ class RayProcess(DistributedProcess):
1344
+ """Ray-specific distributed process."""
1345
+
1346
+ def proc_cleanup(self):
1347
+ try:
1348
+ import ray
1349
+
1350
+ if ray.is_initialized():
1351
+ ray.shutdown()
1352
+ logger.info("Ray shutdown completed.")
1353
+ except ImportError:
1354
+ logger.info("Ray not available for cleanup")
1355
+ except Exception as e:
1356
+ logger.info(f"Failed to shutdown Ray: {e}")
1357
+ # Call parent cleanup for debugging sessions
1358
+ super().proc_cleanup()
1359
+
1360
+
1361
+ class SPMDDistributedSupervisor(DistributedSupervisor):
1362
+ """Base class for SPMD (Single Program Multiple Data) distributed supervisors.
1363
+
1364
+ This class provides common functionality for frameworks that follow the SPMD pattern
1365
+ where the same program runs on multiple processes with different data partitions.
1366
+ """
1367
+
1368
+ def __init__(
1369
+ self,
1370
+ process_class=None,
1371
+ num_proc=None,
1372
+ port=None,
1373
+ restart_procs=True,
1374
+ max_threads_per_proc=10,
1375
+ quorum_timeout=300,
1376
+ quorum_workers=None,
1377
+ monitor_members=True,
1378
+ tree_fanout=50,
1379
+ tree_minimum=100,
1380
+ **process_kwargs,
1381
+ ):
1382
+ super().__init__(
1383
+ quorum_workers=quorum_workers,
1384
+ quorum_timeout=quorum_timeout,
1385
+ monitor_members=monitor_members,
1386
+ )
1387
+ self.process_class = process_class or DistributedProcess
1388
+ self.num_proc = num_proc or "auto"
1389
+ self.port = port
1390
+ self.restart_procs = restart_procs
1391
+ self.max_threads_per_proc = max_threads_per_proc
1392
+ self.process_pool = None
1393
+ self.remote_worker_pool = None # Pool for async HTTP calls to remote workers
1394
+ self.process_kwargs = (
1395
+ process_kwargs # Additional settings to pass to process class
1396
+ )
1397
+ self.tree_fanout = tree_fanout
1398
+ self.tree_minimum = tree_minimum
1399
+
1400
+ def setup(self, deployed_as_of: Optional[str] = None):
1401
+ # Set multiprocessing to spawn if not already
1402
+ if multiprocessing.get_start_method() != "spawn":
1403
+ multiprocessing.set_start_method("spawn", force=True)
1404
+
1405
+ # Get number of processes
1406
+ if self.num_proc == "auto":
1407
+ num_proc = self.process_class.get_auto_num_processes()
1408
+ else:
1409
+ num_proc = self.num_proc
1410
+
1411
+ if self.restart_procs:
1412
+ logger.debug("restart_procs is True, restarting distributed processes")
1413
+ self.cleanup()
1414
+
1415
+ # If the number of processes has changed, we need to clean up the old ones and recreate them
1416
+ if self.process_pool is None or len(self.process_pool) != num_proc:
1417
+ if self.process_pool:
1418
+ logger.debug(
1419
+ f"Number of processes changed from {len(self.process_pool)} to {num_proc}, restarting processes."
1420
+ )
1421
+ self.cleanup()
1422
+
1423
+ logger.debug("Setting up distributed environment")
1424
+ self.process_pool = DistributedProcessPool(
1425
+ process_class=self.process_class,
1426
+ num_processes=num_proc,
1427
+ max_threads_per_proc=self.max_threads_per_proc,
1428
+ **self.process_kwargs, # Pass any additional settings
1429
+ )
1430
+
1431
+ # Start all processes (now handled internally by the pool)
1432
+ self.process_pool.start()
1433
+
1434
+ self.remote_worker_pool = RemoteWorkerPool(
1435
+ quorum_timeout=self.quorum_timeout
1436
+ )
1437
+ self.remote_worker_pool.start()
1438
+ logger.debug("Finished setting up distributed processes")
1439
+
1440
+ def cleanup(self):
1441
+ # Cleanup the processes
1442
+ logger.debug(f"Cleaning up {self.__class__.__name__} distributed processes")
1443
+
1444
+ # Stop DNS monitoring first
1445
+ super().cleanup()
1446
+
1447
+ if self.process_pool:
1448
+ self.process_pool.stop()
1449
+ self.process_pool = None
1450
+
1451
+ if self.remote_worker_pool:
1452
+ self.remote_worker_pool.stop()
1453
+ self.remote_worker_pool = None
1454
+
1455
+ logger.debug(
1456
+ f"Finished cleaning up {self.__class__.__name__} distributed processes"
1457
+ )
1458
+
1459
+ @staticmethod
1460
+ def intercept_call():
1461
+ return True
1462
+
1463
+ def get_tree_children(self, sorted_ips: list, my_ip: str, fanout: int = 100):
1464
+ """Calculate children nodes in a self-organizing tree based on IP indexing.
1465
+
1466
+ Args:
1467
+ sorted_ips: List of all worker IPs sorted deterministically
1468
+ my_ip: This node's IP address
1469
+ fanout: Maximum number of children per node (default 100)
1470
+
1471
+ Returns:
1472
+ List of IP addresses that are children of this node
1473
+ """
1474
+ try:
1475
+ my_index = sorted_ips.index(my_ip)
1476
+ except ValueError:
1477
+ # If not found in list, this node has no children
1478
+ return []
1479
+
1480
+ # Calculate the range of children indices
1481
+ # In a tree with fanout F, node at index i has children at indices:
1482
+ # [i*F + 1, i*F + 2, ..., i*F + F]
1483
+ first_child_idx = my_index * fanout + 1
1484
+ last_child_idx = min(first_child_idx + fanout, len(sorted_ips))
1485
+
1486
+ if first_child_idx >= len(sorted_ips):
1487
+ # No children (leaf node)
1488
+ return []
1489
+
1490
+ children = sorted_ips[first_child_idx:last_child_idx]
1491
+ if len(children) > 0:
1492
+ logger.debug(
1493
+ f"Tree topology: Node {my_ip} (index {my_index}) has {len(children)} children "
1494
+ f"(indices {first_child_idx}-{last_child_idx-1})"
1495
+ )
1496
+ return children
1497
+
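The tree indexing above keeps fan-out bounded: with fanout F, the node at index i calls the nodes at indices i*F + 1 through i*F + F. A small worked example with hypothetical sizes:

    # 2,500 workers with fanout 50: the root reaches every worker within two hops.
    sorted_ips = [f"10.0.{i // 256}.{i % 256}" for i in range(2500)]
    fanout = 50

    root_children = sorted_ips[0 * fanout + 1 : 0 * fanout + 1 + fanout]   # indices 1..50
    node1_children = sorted_ips[1 * fanout + 1 : 1 * fanout + 1 + fanout]  # indices 51..100

    assert len(root_children) == 50 and len(node1_children) == 50
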
1498
+ def call_distributed(
1499
+ self,
1500
+ request,
1501
+ cls_or_fn_name: str,
1502
+ method_name: Optional[str] = None,
1503
+ params: Optional[Dict] = None,
1504
+ distributed_subcall: bool = False,
1505
+ debug_port: int = False,
1506
+ deployed_as_of: Optional[str] = None,
1507
+ ):
1508
+ # Get the request ID from the headers
1509
+ request_id = request.headers.get("X-Request-ID", "-")
1510
+ serialization = request.headers.get("X-Serialization", "json")
1511
+ params = params or {}
1512
+
1513
+ # If deployed_as_of is None and we're the coordinator, generate a consistent timestamp
1514
+ # to use across all workers to prevent reload inconsistencies
1515
+ if not distributed_subcall and deployed_as_of is None:
1516
+ from datetime import datetime, timezone
1517
+
1518
+ deployed_as_of = datetime.now(timezone.utc).isoformat()
1519
+
1520
+ # Get all the pods in the service, and use the first one as the master.
1521
+ # Set the env vars based on whether this is a master or worker
1522
+ logger.debug(
1523
+ f"Configuring distributed environment, distributed_subcall={distributed_subcall}"
1524
+ )
1525
+ this_pod_ip = os.environ["POD_IP"]
1526
+ logger.debug(f"This pod IP: {this_pod_ip}")
1527
+
1528
+ # Start DNS monitoring for coordinator nodes
1529
+ change_event = None
1530
+ if not distributed_subcall:
1531
+ # First wait for quorum before starting monitoring
1532
+ worker_ips = self.pod_ips()
1533
+ # sort the worker IPs to ensure generally consistent tree ordering
1534
+ # (avoids thrashing connections due to reordering)
1535
+ worker_ips.sort()
1536
+ # For tree topology, coordinator is always root (index 0)
1537
+ # Move coordinator to front if not already there
1538
+ if this_pod_ip in worker_ips:
1539
+ worker_ips.remove(this_pod_ip)
1540
+ worker_ips.insert(0, this_pod_ip) # Move coordinator to root of tree
1541
+ logger.debug(f"Acting as COORDINATOR - discovered worker IPs: {worker_ips}")
1542
+ logger.debug(f"Pod IPs: {worker_ips}")
1543
+
1544
+ # Check if this call uses flexible worker selection (workers="ready")
1545
+ # If so, don't start DNS monitoring as the worker set is expected to be flexible
1546
+ workers_arg = params.get("workers") if params else None
1547
+ should_monitor = workers_arg not in ["ready", "any"]
1548
+
1549
+ if should_monitor:
1550
+ # Now that we have quorum, start DNS monitoring
1551
+ # Start monitoring (idempotent - won't start if already running)
1552
+ self.start_dns_monitoring()
1553
+
1554
+ # Subscribe to membership changes
1555
+ change_event = self.subscribe_to_membership_changes()
1556
+
1557
+ # Check for any pending changes after starting monitor
1558
+ self.check_for_membership_changes(force_dns_check=True)
1559
+ else:
1560
+ logger.debug("Skipping DNS monitoring for workers='ready' call")
1561
+
1562
+ # Update distributed env vars to use the tree-ordered IPs
1563
+ distributed_env_vars = {
1564
+ "POD_IPS": ",".join(worker_ips),
1565
+ }
1566
+ else:
1567
+ logger.debug(
1568
+ f"Acting as WORKER (distributed_subcall=True) at {this_pod_ip}"
1569
+ )
1570
+ logger.debug(
1571
+ f"Worker received params keys: {list(params.keys()) if params else 'None'}"
1572
+ )
1573
+ distributed_env_vars = (
1574
+ params.pop("distributed_env_vars", None) if params else None
1575
+ )
1576
+ logger.debug(f"Using distributed_env_vars: {distributed_env_vars}")
1577
+ if not distributed_env_vars:
1578
+ logger.error(f"No distributed_env_vars found in params: {params}")
1579
+ raise RuntimeError(
1580
+ "distributed_env_vars must be provided for distributed subcalls"
1581
+ )
1582
+ worker_ips = distributed_env_vars["POD_IPS"].split(",")
1583
+
1584
+ # Don't debug for subcalls; we only want to debug one process
1585
+ debug_port = None
1586
+
1587
+ # Decide topology based on cluster size
1588
+ subcall_ips = []
1589
+ num_workers = len(worker_ips)
1590
+ tree_mode = num_workers >= self.tree_minimum
1591
+ if tree_mode:
1592
+ # Use tree topology for large clusters
1593
+ if distributed_subcall:
1594
+ logger.debug(
1595
+ f"Using tree topology for {num_workers} workers (> {self.tree_minimum} threshold) "
1596
+ f"with fanout {self.tree_fanout}"
1597
+ )
1598
+
1599
+ # Calculate direct children in the tree
1600
+ subcall_ips = self.get_tree_children(
1601
+ sorted_ips=worker_ips,
1602
+ my_ip=this_pod_ip,
1603
+ fanout=self.tree_fanout,  # Each node can have up to self.tree_fanout children
1604
+ )
1605
+ logger.debug(
1606
+ f"Tree node {this_pod_ip} will call {len(subcall_ips)} direct children"
1607
+ )
1608
+ elif not distributed_subcall:
1609
+ # Use the worker IP list as-is for the coordinator node in flat topology
1610
+ # Leave subcall_ips = [] for workers in flat topology
1611
+ subcall_ips = copy.deepcopy(worker_ips)
1612
+ if this_pod_ip in subcall_ips:
1613
+ subcall_ips.remove(this_pod_ip)
1614
+ logger.debug(
1615
+ f"Removed self ({this_pod_ip}) from subcall list, will call: {subcall_ips}"
1616
+ )
1617
+ else:
1618
+ # This can happen with headless services where POD_IP might not exactly match DNS results
1619
+ # No partial IP/hostname matching is attempted; fall through and call all discovered pods
1620
+ logger.warning(
1621
+ f"This pod IP {this_pod_ip} not found in DNS-discovered IPs {worker_ips}. "
1622
+ f"This may indicate DNS propagation delay or hostname/IP mismatch."
1623
+ )
1624
+ # Still proceed as coordinator but will call all discovered pods
1625
+ logger.debug(f"Will call all discovered workers: {subcall_ips}")
1626
+
1627
+ # Always pass the distributed environment variables to workers
1628
+ params["distributed_env_vars"] = distributed_env_vars
1629
+
1630
+ # "workers" is passed through as a regular kwarg, not a special extracted one like pdb or serialization
1631
+ # because it only applies to distributed calls, not regular ones.
1632
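# Accepted forms of the "workers" kwarg, as handled below: a list mixing IP strings and
# integer (or numeric-string) indices into worker_ips; the string "any" (run only on this
# node); the string "ready" (skip workers that fail the /health check); or any other
# string, treated as a substring filter over the worker IPs.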
+ call_local_procs = True
1633
+ workers_arg = params.get("workers", None)
1634
+ if workers_arg:
1635
+ logger.debug(f"Filtering workers by argument: {workers_arg}")
1636
+ if isinstance(workers_arg, list) and workers_arg:
1637
+ # Build a set of IPs to include based on the list items
1638
+ target_ips = set()
1639
+
1640
+ for item in workers_arg:
1641
+ if isinstance(item, str) and "." in item:
1642
+ # It's an IP address
1643
+ if item not in worker_ips:
1644
+ raise ValueError(
1645
+ f"Worker IP '{item}' not found in available workers: {worker_ips}"
1646
+ )
1647
+ target_ips.add(item)
1648
+ elif isinstance(item, int) or (
1649
+ isinstance(item, str) and item.isdigit()
1650
+ ):
1651
+ # It's an index
1652
+ idx = int(item) if isinstance(item, str) else item
1653
+ if idx < 0 or idx >= len(worker_ips):
1654
+ raise ValueError(
1655
+ f"Worker index {idx} out of range. Valid range: 0-{len(worker_ips)-1}"
1656
+ )
1657
+ target_ips.add(worker_ips[idx])
1658
+ else:
1659
+ raise ValueError(
1660
+ f"Invalid worker specification: {item}. Must be an IP address, "
1661
+ f"integer index, or numeric string."
1662
+ )
1663
+
1664
+ # Filter subcall_ips to only those in the target set
1665
+ subcall_ips = [ip for ip in subcall_ips if ip in target_ips]
1666
+
1667
+ # Check if current pod should participate in local processing
1668
+ if this_pod_ip not in target_ips:
1669
+ call_local_procs = False
1670
+
1671
+ elif workers_arg == "any":
1672
+ # Only call one worker (this one)
1673
+ subcall_ips = []
1674
+ elif workers_arg == "ready":
1675
+ # Filter below in call_worker to only those that respond to /health
1676
+ pass
1677
+ elif isinstance(workers_arg, str) and workers_arg:
1678
+ # Filter the subcall_ips to only those matching the workers_arg string
1679
+ subcall_ips = [ip for ip in subcall_ips if workers_arg in ip]
1680
+ logger.debug(f"Subcall IPs after filtering: {subcall_ips}")
1681
+
1682
+ # "restart_procs" is passed through as a regular kwarg, not a special extracted one like pdb or serialization
1683
+ # because it only applies to distributed calls, not regular ones.
1684
+ if params.get("restart_procs", False):
1685
+ logger.info("restart_procs parameter is True, restarting processes")
1686
+ self.cleanup()
1687
+ self.setup(deployed_as_of)
1688
+
1689
+ try:
1690
+ node_rank = worker_ips.index(this_pod_ip)
1691
+ except ValueError:
1692
+ # This pod IP not found in DNS results - may be external service IP
1693
+ # Fall back to using POD_IP directly and assume node_rank based on position
1694
+ logger.warning(
1695
+ f"Pod IP {this_pod_ip} not found in DNS results {worker_ips}. Using fallback logic."
1696
+ )
1697
+ # For now, assume we're the first worker if not found
1698
+ node_rank = 0
1699
+
1700
+ # Call the workers using RemoteWorkerPool for async operations
1701
+ def call_worker(worker_ip):
1702
+ # Keep this function for backward compatibility but it won't be used
1703
+ # when RemoteWorkerPool is available
1704
+ with httpx.Client(timeout=None) as client:
1705
+ port = os.environ["KT_SERVER_PORT"]
1706
+ worker_url = f"http://{worker_ip}:{port}"
1707
+ # First check that the worker is alive; replicas don't finish setup at exactly the same moment
1708
+ # Use quorum_timeout to control how long to wait for workers
1709
+ for i in range(int(self.quorum_timeout)):
1710
+ try:
1711
+ resp = client.get(f"{worker_url}/health")
1712
+ if resp.status_code == 200:
1713
+ break
1714
+ except httpx.RequestError:
1715
+ if workers_arg == "ready":
1716
+ logger.debug(
1717
+ f"Worker {worker_ip} not ready, skipping as per 'ready' workers argument"
1718
+ )
1719
+ return None
1720
+ time.sleep(1)
1721
+ else:
1722
+ # Timeout reached without successful health check
1723
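# (This is the for-loop's else clause: it runs only when the loop finishes without
# hitting `break`, i.e. no health check succeeded within quorum_timeout.)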
+ logger.warning(
1724
+ f"Worker {worker_ip} failed to respond after {self.quorum_timeout}s timeout"
1725
+ )
1726
+ if workers_arg != "ready":
1727
+ raise TimeoutError(
1728
+ f"Worker {worker_ip} did not become ready within {self.quorum_timeout} seconds. "
1729
+ "This may indicate the pod is still starting or there's a resource constraint. "
1730
+ "Consider increasing quorum_timeout in .distribute() call."
1731
+ )
1732
+
1733
+ call_url = (
1734
+ f"{worker_url}/{cls_or_fn_name}/{method_name}?distributed_subcall=true"
1735
+ if method_name is not None
1736
+ else f"{worker_url}/{cls_or_fn_name}?distributed_subcall=true"
1737
+ )
1738
+
1739
+ # Clean headers to avoid potential Content-Length issues
1740
+ clean_headers = {}
1741
+ if request.headers:
1742
+ for key, value in request.headers.items():
1743
+ # Skip headers that could interfere with httpx's automatic handling
1744
+ if key.lower() not in [
1745
+ "content-length",
1746
+ "transfer-encoding",
1747
+ "connection",
1748
+ ]:
1749
+ clean_headers[key] = value
1750
+
1751
+ try:
1752
+ logger.debug(f"Making distributed call to {worker_url}")
1753
+ resp = client.post(
1754
+ url=call_url,
1755
+ json=params,
1756
+ headers=clean_headers, # Includes deployed_as_of and request_id
1757
+ )
1758
+ return resp
1759
+ except (httpx.RequestError, httpx.HTTPError) as e:
1760
+ logger.error(f"Failed to call worker {worker_url}: {e}")
1761
+ raise
1762
+
1763
+ # Prepare per-process parameters
1764
+ num_procs = len(self.process_pool)
1765
+ params_list = [params] * num_procs
1766
+ distributed_env_vars_list = []
1767
+ debug_ports = []
1768
+
1769
+ for idx in range(num_procs):
1770
+ # Get framework-specific env vars from the process class
1771
+ env_vars = self.process_class.get_distributed_env_vars(
1772
+ worker_ips=worker_ips
1773
+ if "worker_ips" in locals()
1774
+ else distributed_env_vars.get("POD_IPS", "").split(","),
1775
+ node_rank=node_rank,
1776
+ local_rank=idx,
1777
+ num_local_procs=num_procs,
1778
+ port=self.port,
1779
+ )
1780
+ # Add any base env vars
1781
+ env_vars.update(distributed_env_vars)
1782
+ distributed_env_vars_list.append(env_vars)
1783
+
1784
+ # Only debug one process and if debug_port is set
1785
+ debug = debug_port and idx == num_procs - 1
1786
+ if debug:
1787
+ time.sleep(0.25)
1788
+ debug_ports.append(debug_port if debug else None)
1789
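# e.g. with num_procs == 4 and debug_port == 5678, debug_ports ends up as
# [None, None, None, 5678], so only the last local process attaches a debugger.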
+
1790
+ # Execute distributed calls in parallel with local processes
1791
+ worker_responses = []
1792
+ worker_exception = None
1793
+ local_exception = None
1794
+
1795
+ # Start both remote and local calls in parallel
1796
+ executor = ThreadPoolExecutor(max_workers=2)
1797
+
1798
+ # Submit remote worker calls if needed
1799
+ worker_future = None
1800
+ if subcall_ips:
1801
+ logger.debug(f"Have {len(subcall_ips)} remote workers to call")
1802
+ if not self.remote_worker_pool:
1803
+ raise RuntimeError(
1804
+ "RemoteWorkerPool not initialized. This is required for distributed execution."
1805
+ )
1806
+ logger.debug(
1807
+ f"Using existing RemoteWorkerPool to call {len(subcall_ips)} workers"
1808
+ )
1809
+
1810
+ def call_remote_workers():
1811
+ nonlocal worker_exception
1812
+ try:
1813
+ # Prepare headers for remote workers
1814
+ clean_headers = {}
1815
+ if request.headers:
1816
+ for key, value in request.headers.items():
1817
+ if key.lower() not in [
1818
+ "content-length",
1819
+ "transfer-encoding",
1820
+ "connection",
1821
+ ]:
1822
+ clean_headers[key] = value
1823
+ # Always include deployed_as_of in headers for consistency
1824
+ if deployed_as_of:
1825
+ clean_headers["X-Deployed-As-Of"] = deployed_as_of
1826
+
1827
+ # Call remote workers asynchronously through the pool
1828
+ logger.debug(
1829
+ f"Calling {len(subcall_ips)} remote workers via RemoteWorkerPool: {subcall_ips}"
1830
+ )
1831
+ results = self.remote_worker_pool.call_workers(
1832
+ worker_ips=subcall_ips,
1833
+ cls_or_fn_name=cls_or_fn_name,
1834
+ method_name=method_name,
1835
+ params=params,
1836
+ request_headers=clean_headers,
1837
+ workers_arg=workers_arg,
1838
+ )
1839
+ logger.warning(
1840
+ f"RemoteWorkerPool returned {len(results) if results else 0} results from {len(subcall_ips)} workers"
1841
+ )
1842
+ return results
1843
+ except Exception as e:
1844
+ # Check if this is a connection error - might indicate worker removal
1845
+ if any(
1846
+ err_type in str(e)
1847
+ for err_type in [
1848
+ "ReadError",
1849
+ "TimeoutException",
1850
+ "RequestError",
1851
+ "HTTPError",
1852
+ "ConnectionError",
1853
+ "Connection reset",
1854
+ "Connection closed",
1855
+ ]
1856
+ ):
1857
+ logger.debug(
1858
+ f"Connection error detected: {e}, checking for membership changes"
1859
+ )
1860
+ # Force DNS check to see if workers were removed
1861
+ self.check_for_membership_changes(force_dns_check=True)
1862
+ worker_exception = e
1863
+ raise
1864
+
1865
+ worker_future = executor.submit(call_remote_workers)
1866
+ logger.debug(
1867
+ f"Submitted worker_future for {len(subcall_ips)} remote workers"
1868
+ )
1869
+
1870
+ else:
1871
+ logger.debug(
1872
+ f"No remote workers to call (subcall_ips is empty or None: {subcall_ips})"
1873
+ )
1874
+
1875
+ # Check if we need to initialize RemoteWorkerPool for tree topology workers
1876
+ if subcall_ips:
1877
+ if not self.remote_worker_pool:
1878
+ # Initialize RemoteWorkerPool if not already done (needed for tree topology workers)
1879
+ logger.warning(
1880
+ f"INITIALIZING RemoteWorkerPool for tree worker at {this_pod_ip} to call {len(subcall_ips)} children"
1881
+ )
1882
+ self.remote_worker_pool = RemoteWorkerPool(
1883
+ quorum_timeout=self.quorum_timeout
1884
+ )
1885
+ self.remote_worker_pool.start(
1886
+ max_workers=min(len(subcall_ips) + 50, 200)
1887
+ ) # Size for expected children plus buffer
1888
+ logger.warning(
1889
+ f"RemoteWorkerPool initialized successfully for {this_pod_ip}"
1890
+ )
1891
+ elif (
1892
+ not hasattr(self.remote_worker_pool, "process")
1893
+ or not self.remote_worker_pool.process
1894
+ or not self.remote_worker_pool.process.is_alive()
1895
+ ):
1896
+ # Pool exists but not started/alive
1897
+ logger.warning(
1898
+ f"RemoteWorkerPool exists but not running for {this_pod_ip}, starting it now"
1899
+ )
1900
+ self.remote_worker_pool.start(
1901
+ max_workers=min(len(subcall_ips) + 50, 200)
1902
+ )
1903
+
1904
+ if subcall_ips and not self.remote_worker_pool:
1905
+ # RemoteWorkerPool should always be initialized at this point
1906
+ raise RuntimeError(
1907
+ f"RemoteWorkerPool not available for worker at {this_pod_ip}. "
1908
+ "This is required for distributed execution with subcall_ips."
1909
+ )
1910
+
1911
+ # Submit local process calls
1912
+ def call_local_processes():
1913
+ logger.debug(f"Processing {num_procs} local process responses")
1914
+ return self.process_pool.call_all(
1915
+ method_name=method_name,
1916
+ params_list=params_list,
1917
+ deployed_as_of=deployed_as_of,
1918
+ request_id=request_id,
1919
+ distributed_env_vars_list=distributed_env_vars_list,
1920
+ debug_ports=debug_ports,
1921
+ serialization=serialization,
1922
+ )
1923
+
1924
+ # We may not call the local processes if the user specified workers and didn't include this node
1925
+ if call_local_procs:
1926
+ local_future = executor.submit(call_local_processes)
1927
+ else:
1928
+ local_future = None
1929
+
1930
+ # Wait for both to complete with fast failure propagation
1931
+ try:
1932
+ # Use as_completed to get results as they arrive
1933
+ futures = [f for f in [worker_future, local_future] if f is not None]
1934
+ local_responses = []
1935
+
1936
+ # If we only have local processes (no remote workers), handle that case
1937
+ if not futures:
1938
+ logger.error("No futures to wait for - this shouldn't happen")
1939
+ raise RuntimeError("No distributed work to execute")
1940
+
1941
+ logger.debug(
1942
+ f"Waiting for {len(futures)} futures to complete (worker_future={worker_future is not None}, local_future={local_future is not None})"
1943
+ )
1944
+
1945
+ # Process futures with periodic membership checks
1946
+ from concurrent.futures import FIRST_COMPLETED, wait
1947
+
1948
+ pending_futures = set(futures)
1949
+ while pending_futures:
1950
+ # Check for membership changes even while waiting
1951
+ if change_event and change_event.is_set():
1952
+ logger.debug("Membership change detected, checking...")
1953
+ try:
1954
+ self.check_for_membership_changes()
1955
+ except Exception as e:
1956
+ # Cancel all pending futures immediately
1957
+ logger.error(
1958
+ f"Membership change detected, cancelling futures: {e}"
1959
+ )
1960
+ for f in pending_futures:
1961
+ if not f.done():
1962
+ f.cancel()
1963
+ raise e
1964
+ finally:
1965
+ change_event.clear() # Reset for next change
1966
+
1967
+ # Wait for next future with a short timeout to allow membership checks
1968
+ done, pending_futures = wait(
1969
+ pending_futures, timeout=1.0, return_when=FIRST_COMPLETED
1970
+ )
1971
+
1972
+ for future in done:
1973
+ logger.debug(
1974
+ f"Future completed: is_worker={worker_future and future == worker_future}, is_local={local_future and future == local_future}"
1975
+ )
1976
+ try:
1977
+ if worker_future and future == worker_future:
1978
+ logger.debug("Getting results from remote workers future")
1979
+ results = future.result()
1980
+ logger.debug(
1981
+ f"Remote worker future returned: {type(results).__name__} with {len(results) if hasattr(results, '__len__') else 'N/A'} items"
1982
+ )
1983
+ # Process results - they're already JSON-decoded
1984
+ logger.debug(
1985
+ f"Processing {len(results)} results from RemoteWorkerPool"
1986
+ )
1987
+ for i, result in enumerate(results):
1988
+ logger.debug(
1989
+ f"Result {i}: type={type(result).__name__}, "
1990
+ f"length={len(result) if hasattr(result, '__len__') else 'N/A'}"
1991
+ )
1992
+ if isinstance(result, dict) and "error_type" in result:
1993
+ # Fast failure - return error immediately
1994
+ executor.shutdown(wait=False)
1995
+ return JSONResponse(status_code=500, content=result)
1996
+ # Results from RemoteWorkerPool are already lists (aggregated from subtree in tree topology)
1997
+ if isinstance(result, list):
1998
+ worker_responses.extend(result)
1999
+ else:
2000
+ worker_responses.append(result)
2001
+ logger.debug(
2002
+ f"Got {len(worker_responses)} total responses from remote workers"
2003
+ )
2004
+ else: # local_future
2005
+ logger.debug("Getting results from local processes future")
2006
+ local_responses = future.result()
2007
+ logger.debug(
2008
+ f"Got {len(local_responses)} responses from local processes"
2009
+ )
2010
+ # Check for errors in local responses
2011
+ for response in local_responses:
2012
+ if isinstance(response, JSONResponse):
2013
+ # Fast failure - return error immediately
2014
+ executor.shutdown(wait=False)
2015
+ return response
2016
+ except Exception as e:
2017
+ # Fast failure - propagate exception immediately
2018
+ executor.shutdown(wait=False)
2019
+ logger.error(f"Error in distributed execution: {e}")
2020
+ raise
2021
+
2022
+ finally:
2023
+ # Unsubscribe from membership changes
2024
+ if change_event:
2025
+ self.unsubscribe_from_membership_changes(change_event)
2026
+
2027
+ logger.debug("Shutting down executor")
2028
+ executor.shutdown(wait=False)
2029
+
2030
+ total = len(local_responses) + len(worker_responses)
2031
+ if tree_mode and total > 10: # Only log for tree mode with significant results
2032
+ logger.debug(
2033
+ f"TREE RESULT AGGREGATION at {this_pod_ip}: {len(local_responses)} local + {len(worker_responses)} remote = {total} total"
2034
+ )
2035
+ else:
2036
+ logger.debug(
2037
+ f"Combining {len(local_responses)} local + {len(worker_responses)} remote = {total} total responses"
2038
+ )
2039
+ logger.debug(
2040
+ f"Distributed_subcall={distributed_subcall}, tree topology={tree_mode}"
2041
+ )
2042
+ # Log sample of what we're returning for debugging
2043
+ if worker_responses:
2044
+ logger.debug(
2045
+ f"Sample worker response type: {type(worker_responses[0]).__name__}"
2046
+ )
2047
+ responses = local_responses + worker_responses
2048
+ for response in responses:
2049
+ # If the response is a JSONResponse, we need to check if it contains an exception,
2050
+ # and "raise" it if so - essentially just returning it immediately rather than the full result list.
2051
+ if isinstance(response, JSONResponse):
2052
+ return response
2053
+ # This primarily handles exceptions raised while packaging an exception, which would otherwise cause the server to hang.
2054
+ if isinstance(response, Exception):
2055
+ raise response
2056
+ logger.debug(f"Returning {len(responses)} responses from execute_call")
2057
+ return responses
2058
+
2059
+
2060
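# Minimal sketch (illustrative only, not the package's implementation) of the
# scatter/gather shape call_distributed builds: each node runs its local processes,
# asks its direct tree children to do the same, and returns the flattened list, so the
# coordinator ends up with one response per process across the whole tree.
def _gather_tree_results(node, children_of, run_local):
    # children_of: mapping node -> list of child nodes; run_local: node -> list of results
    results = list(run_local(node))
    for child in children_of.get(node, []):
        results.extend(_gather_tree_results(child, children_of, run_local))
    return results

# Example: a root with two children, one local result each, yields three results total:
# _gather_tree_results("w0", {"w0": ["w1", "w2"]}, lambda n: [f"{n}-result"])
# == ["w0-result", "w1-result", "w2-result"]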
+ class JaxProcess(DistributedProcess):
2061
+ """JAX-specific distributed process."""
2062
+
2063
+ @classmethod
2064
+ def get_distributed_env_vars(
2065
+ cls, worker_ips, node_rank, local_rank, num_local_procs, **settings
2066
+ ):
2067
+ """Get JAX-specific distributed environment variables.
2068
+
2069
+ JAX uses a coordinator address and process ID for distributed setup.
2070
+ """
2071
+ port = settings.get("port") or 1234 # JAX default coordinator port
2072
+ env_vars = super().get_distributed_env_vars(
2073
+ worker_ips, node_rank, local_rank, num_local_procs, **settings
2074
+ )
2075
+
2076
+ # JAX distributed environment variables
2077
+ env_vars.update(
2078
+ {
2079
+ # Coordinator is the first worker
2080
+ "JAX_COORDINATOR_ADDRESS": f"{worker_ips[0]}:{port}",
2081
+ # Process ID is global rank
2082
+ "JAX_PROCESS_ID": str(node_rank * num_local_procs + local_rank),
2083
+ # Total number of processes
2084
+ "JAX_NUM_PROCESSES": str(len(worker_ips) * num_local_procs),
2085
+ # Local device IDs (for GPU/TPU)
2086
+ "JAX_LOCAL_DEVICE_IDS": str(local_rank),
2087
+ }
2088
+ )
2089
+ return env_vars
2090
+
2091
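# A hedged sketch of how user code running in one of these processes might consume the
# variables set above; the keyword names match jax.distributed.initialize in recent JAX
# releases, but treat this as illustrative rather than a prescribed entrypoint.
def _init_jax_from_env_example():
    import os

    import jax

    jax.distributed.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )
    return jax.process_index(), jax.process_count()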
+ @classmethod
2092
+ def get_auto_num_processes(cls):
2093
+ """Auto-detect based on available accelerators for JAX."""
2094
+ try:
2095
+ import jax
2096
+
2097
+ # JAX can use TPUs, GPUs, or CPUs
2098
+ devices = jax.devices()
2099
+ return len(devices)
2100
+ except Exception:
2101
+ return 1
2102
+
2103
+ # JAX doesn't have a global process group to destroy like PyTorch
2104
+ # Cleanup is mostly handled automatically
2105
+ # def proc_cleanup(self):
2106
+
2107
+
2108
+ class TensorflowProcess(DistributedProcess):
2109
+ """TensorFlow-specific distributed process."""
2110
+
2111
+ def proc_cleanup(self):
2112
+ """TensorFlow-specific cleanup."""
2113
+ try:
2114
+ import tensorflow as tf
2115
+
2116
+ # Clear the default graph and reset the session
2117
+ tf.keras.backend.clear_session()
2118
+ logger.info("TensorFlow process cleanup completed.")
2119
+ except ImportError:
2120
+ logger.info("TensorFlow not available for cleanup")
2121
+ except Exception as e:
2122
+ logger.info(f"Failed during TensorFlow cleanup: {e}")
2123
+ # Call parent cleanup for debugging sessions
2124
+ super().proc_cleanup()
2125
+
2126
+ @classmethod
2127
+ def get_distributed_env_vars(
2128
+ cls, worker_ips, node_rank, local_rank, num_local_procs, **settings
2129
+ ):
2130
+ """Get TensorFlow-specific distributed environment variables.
2131
+
2132
+ TensorFlow uses TF_CONFIG for distributed training configuration.
2133
+ """
2134
+ import json
2135
+
2136
+ port = settings.get("port") or 2222 # TensorFlow default port
2137
+ env_vars = super().get_distributed_env_vars(
2138
+ worker_ips, node_rank, local_rank, num_local_procs, **settings
2139
+ )
2140
+
2141
+ # Build TF_CONFIG for MultiWorkerMirroredStrategy
2142
+ worker_addresses = [f"{ip}:{port}" for ip in worker_ips]
2143
+
2144
+ tf_config = {
2145
+ "cluster": {"worker": worker_addresses},
2146
+ "task": {"type": "worker", "index": node_rank},
2147
+ }
2148
+
2149
+ env_vars.update(
2150
+ {
2151
+ "TF_CONFIG": json.dumps(tf_config),
2152
+ # Additional TF env vars for performance
2153
+ "TF_FORCE_GPU_ALLOW_GROWTH": "true",
2154
+ "TF_GPU_THREAD_MODE": "gpu_private",
2155
+ }
2156
+ )
2157
+ return env_vars
2158
+
2159
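# A hedged sketch of how a training script on one of these workers might pick up the
# TF_CONFIG built above; MultiWorkerMirroredStrategy reads TF_CONFIG from the environment
# itself, so the json.loads here is only to show the expected shape.
def _tf_strategy_from_env_example():
    import json
    import os

    import tensorflow as tf

    tf_config = json.loads(os.environ["TF_CONFIG"])
    # e.g. {"cluster": {"worker": ["10.0.0.3:2222", "10.0.0.4:2222"]},
    #       "task": {"type": "worker", "index": 0}}
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    return strategy, tf_config["task"]["index"]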
+ @classmethod
2160
+ def get_auto_num_processes(cls):
2161
+ """Auto-detect based on available GPUs for TensorFlow."""
2162
+ try:
2163
+ import tensorflow as tf
2164
+
2165
+ gpus = tf.config.list_physical_devices("GPU")
2166
+ if gpus:
2167
+ return len(gpus)
2168
+ except Exception:
2169
+ pass
2170
+ return 1
2171
+
2172
+
2173
+ class MonarchProcess(DistributedProcess):
2174
+ """Monarch-specific distributed process for single-controller actor framework.
2175
+
2176
+ Similar to Ray, Monarch uses a single controller (rank 0) that manages distributed
2177
+ actors across worker nodes. Each node runs a process_allocator service.
2178
+ """
2179
+
2180
+ def __init__(
2181
+ self, local_rank, request_queue, response_queue, max_threads=4, **kwargs
2182
+ ):
2183
+ super().__init__(
2184
+ local_rank, request_queue, response_queue, max_threads, **kwargs
2185
+ )
2186
+ self.allocator = None
2187
+
2188
+ # Monarch imports will be done in run() on the main thread
2189
+ self.RemoteAllocator = None
2190
+ self.StaticRemoteAllocInitializer = None
2191
+
2192
+ def _create_allocator_for_controller(self):
2193
+ """Create a RemoteAllocator for the controller (rank 0)."""
2194
+
2195
+ try:
2196
+ # Try to import if not already available
2197
+ if (
2198
+ self.RemoteAllocator is None
2199
+ or self.StaticRemoteAllocInitializer is None
2200
+ ):
2201
+ try:
2202
+ from monarch._src.actor.allocator import (
2203
+ RemoteAllocator,
2204
+ StaticRemoteAllocInitializer,
2205
+ )
2206
+
2207
+ self.RemoteAllocator = RemoteAllocator
2208
+ self.StaticRemoteAllocInitializer = StaticRemoteAllocInitializer
2209
+ logger.debug("Monarch components imported")
2210
+ except ImportError as e:
2211
+ logger.error(f"Failed to import Monarch: {e}")
2212
+ logger.error(
2213
+ "Make sure torchmonarch is installed: pip install torchmonarch"
2214
+ )
2215
+ import traceback
2216
+
2217
+ logger.error(traceback.format_exc())
2218
+ return None
2219
+ except Exception as e:
2220
+ logger.error(f"Unexpected error importing Monarch: {e}")
2221
+ import traceback
2222
+
2223
+ logger.error(traceback.format_exc())
2224
+ return None
2225
+
2226
+ if (
2227
+ self.RemoteAllocator is None
2228
+ or self.StaticRemoteAllocInitializer is None
2229
+ ):
2230
+ logger.error(
2231
+ "Monarch components not available. Cannot create allocator."
2232
+ )
2233
+ return None
2234
+
2235
+ # Get worker addresses from POD_IPS
2236
+ pod_ips = os.environ.get("POD_IPS", "").split(",")
2237
+ if not pod_ips or pod_ips == [""]:
2238
+ logger.warning("No POD_IPS found, using localhost")
2239
+ pod_ips = ["127.0.0.1"]
2240
+
2241
+ # Use tcp! format for channel addresses (Monarch's format)
2242
+ # Format: tcp!{ip}:{port} not tcp://{ip}:{port}
2243
+ worker_addresses = [f"tcp!{ip}:26600" for ip in pod_ips]
2244
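# e.g. POD_IPS="10.0.0.3,10.0.0.4" yields ["tcp!10.0.0.3:26600", "tcp!10.0.0.4:26600"]
# (the IPs here are placeholders for illustration)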
+ logger.info(
2245
+ f"Creating Monarch allocator with {len(worker_addresses)} workers"
2246
+ )
2247
+ logger.debug(f"Worker addresses type: {type(worker_addresses)}")
2248
+ logger.debug(
2249
+ f"First address: {worker_addresses[0] if worker_addresses else 'none'}"
2250
+ )
2251
+ logger.debug(
2252
+ f"First address type: {type(worker_addresses[0]) if worker_addresses else 'none'}"
2253
+ )
2254
+
2255
+ # Simple check - don't add complex waiting logic
2256
+
2257
+ # Create initializer with all workers using pre-imported classes
2258
+ # StaticRemoteAllocInitializer takes addresses as positional args
2259
+ logger.debug(
2260
+ f"About to create StaticRemoteAllocInitializer with args: {worker_addresses}"
2261
+ )
2262
+ try:
2263
+ initializer = self.StaticRemoteAllocInitializer(*worker_addresses)
2264
+ except Exception as e:
2265
+ logger.error(f"Failed to create StaticRemoteAllocInitializer: {e}")
2266
+ import traceback
2267
+
2268
+ logger.debug(f"Traceback: {traceback.format_exc()}")
2269
+ raise
2270
+
2271
+ # Return configured allocator using pre-imported class
2272
+ # RemoteAllocator takes world_id and initializer
2273
+ # Use stable world_id based on service name
2274
+ # This allows coordinator failover and process restarts to work correctly
2275
+ service_name = os.environ.get("KT_SERVICE_NAME", "monarch-default")
2276
+ world_id = service_name
2277
+ try:
2278
+ allocator = self.RemoteAllocator(
2279
+ world_id=world_id, initializer=initializer
2280
+ )
2281
+ logger.info(f"Created allocator with world_id={world_id}")
2282
+ return allocator
2283
+ except Exception as e:
2284
+ logger.error(f"Failed to create RemoteAllocator: {e}")
2285
+ import traceback
2286
+
2287
+ logger.debug(f"Traceback: {traceback.format_exc()}")
2288
+ raise
2289
+
2290
+ except ImportError as e:
2291
+ logger.error(f"Could not import Monarch for allocator creation: {e}")
2292
+ import traceback
2293
+
2294
+ logger.error(traceback.format_exc())
2295
+ return None
2296
+ except Exception as e:
2297
+ logger.error(f"Failed to create Monarch allocator: {e}")
2298
+ import traceback
2299
+
2300
+ logger.error(traceback.format_exc())
2301
+ return None
2302
+
2303
+ def run_user_function(self, callable_obj, method_name, params):
2304
+ """Run user function with Monarch-specific setup."""
2305
+ import asyncio
2306
+ import inspect
2307
+
2308
+ # Get the rank from environment
2309
+ rank = int(os.environ.get("NODE_RANK", "0"))
2310
+ logger.debug(f"Running user function on rank {rank}")
2311
+
2312
+ # Get the method to call
2313
+ if method_name and hasattr(callable_obj, method_name):
2314
+ user_method = getattr(callable_obj, method_name)
2315
+ else:
2316
+ user_method = callable_obj
2317
+
2318
+ logger.info(f"User method: {user_method}")
2319
+
2320
+ # Prepare arguments
2321
+ args = params.get("args", []) if params else []
2322
+ kwargs = params.get("kwargs", {}) if params else {}
2323
+
2324
+ # Only create and inject allocator for controller (rank 0)
2325
+ # Workers will run the user function with allocator=None
2326
+ if rank == 0:
2327
+ logger.info("Rank 0 (controller) - will create allocator")
2328
+ # Controller (rank 0) - create allocator if needed
2329
+ if self.allocator is None:
2330
+ logger.debug("Creating allocator...")
2331
+ self.allocator = self._create_allocator_for_controller()
2332
+ if self.allocator is None:
2333
+ logger.error("Failed to create allocator - returned None!")
2334
+ else:
2335
+ logger.debug("Allocator created successfully")
2336
+
2337
+ # Inject allocator if function accepts it
2338
+ try:
2339
+ sig = inspect.signature(user_method)
2340
+ if "allocator" in sig.parameters:
2341
+ logger.debug("Injecting allocator into controller function")
2342
+ kwargs["allocator"] = self.allocator
2343
+ except Exception as e:
2344
+ logger.warning(f"Could not inspect function signature: {e}")
2345
+ else:
2346
+ # Workers get None for allocator parameter if the function expects it
2347
+ try:
2348
+ sig = inspect.signature(user_method)
2349
+ if "allocator" in sig.parameters:
2350
+ logger.info(f"Worker {rank}: Setting allocator=None")
2351
+ kwargs["allocator"] = None
2352
+ except Exception:
2353
+ pass
2354
+
2355
+ # Run the function (we're already on the main thread of this process)
2356
+ logger.info(
2357
+ f"Rank {rank}: Running user function with args={args}, kwargs keys={list(kwargs.keys())}"
2358
+ )
2359
+
2360
+ try:
2361
+ if asyncio.iscoroutinefunction(user_method):
2362
+ result = asyncio.run(user_method(*args, **kwargs))
2363
+ else:
2364
+ result = user_method(*args, **kwargs)
2365
+ if inspect.isawaitable(result):
2366
+ result = asyncio.run(result)
2367
+
2368
+ # If the result is an exception dict (from the user's try/except),
2369
+ # convert it back to a proper exception and raise it
2370
+ if (
2371
+ isinstance(result, dict)
2372
+ and "status" in result
2373
+ and result["status"] == "error"
2374
+ ):
2375
+ error_msg = result.get("error", "Unknown error")
2376
+ traceback_str = result.get("traceback", "")
2377
+ logger.error(f"User function returned error dict: {error_msg}")
2378
+ logger.error(f"Traceback from user function: {traceback_str}")
2379
+ # Raise a RuntimeError with the original error message
2380
+ raise RuntimeError(
2381
+ f"Monarch execution failed: {error_msg}\n{traceback_str}"
2382
+ )
2383
+
2384
+ return result
2385
+ except Exception as e:
2386
+ logger.error(f"Exception in run_user_function: {e}")
2387
+ import traceback
2388
+
2389
+ logger.error(f"Full traceback: {traceback.format_exc()}")
2390
+ # Re-raise the exception to be handled by the caller
2391
+ raise
2392
+
2393
+ def run(self):
2394
+ """Override run to handle requests on main thread for Monarch."""
2395
+ logger.debug("MonarchProcess starting on main thread")
2396
+
2397
+ # Import Monarch on the main thread of this subprocess
2398
+ # This is the right place since run() executes on the main thread
2399
+ if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
2400
+ try:
2401
+ from monarch._src.actor.allocator import (
2402
+ RemoteAllocator,
2403
+ StaticRemoteAllocInitializer,
2404
+ )
2405
+
2406
+ self.RemoteAllocator = RemoteAllocator
2407
+ self.StaticRemoteAllocInitializer = StaticRemoteAllocInitializer
2408
+ except Exception as e:
2409
+ logger.error(f"Failed to import Monarch in run(): {e}")
2410
+ import traceback
2411
+
2412
+ logger.error(traceback.format_exc())
2413
+
2414
+ # Monarch requires main thread execution, so we don't use ThreadPoolExecutor
2415
+ try:
2416
+ while True:
2417
+ try:
2418
+ # Block waiting for next request
2419
+ request = self._request_queue.get(timeout=1)
2420
+
2421
+ # Special sentinel value to signal shutdown
2422
+ if request == "SHUTDOWN":
2423
+ break
2424
+
2425
+ # Handle request directly on main thread (not in thread pool)
2426
+ self.handle_request(request)
2427
+
2428
+ except queue.Empty:
2429
+ continue
2430
+ except Exception as e:
2431
+ if "Empty" not in str(e.__class__.__name__):
2432
+ logger.error(f"Error getting request from queue: {e}")
2433
+ continue
2434
+
2435
+ except (KeyboardInterrupt, BdbQuit):
2436
+ logger.debug("MonarchProcess interrupted")
2437
+ finally:
2438
+ logger.debug("MonarchProcess shutting down")
2439
+ self.proc_cleanup()
2440
+
2441
+ def handle_request(self, request):
2442
+ """Handle request using Monarch-specific logic."""
2443
+ try:
2444
+ # Use parent's handle_request but override the actual function execution
2445
+ request_unique_id = request["request_unique_id"]
2446
+ method_name = request["method_name"]
2447
+ params = request["params"]
2448
+ deployed_as_of = request["deployed_as_of"]
2449
+ request_id = request["request_id"]
2450
+ distributed_env_vars = request["distributed_env_vars"]
2451
+
2452
+ # Set environment variables
2453
+ for key, value in distributed_env_vars.items():
2454
+ os.environ[key] = value
2455
+
2456
+ # Set request context
2457
+ token = request_id_ctx_var.set(request_id)
2458
+
2459
+ try:
2460
+ # Load callable
2461
+ callable_obj = load_callable(
2462
+ deployed_as_of=deployed_as_of,
2463
+ distributed_subprocess=True,
2464
+ reload_cleanup_fn=self.proc_cleanup,
2465
+ )
2466
+
2467
+ # Run with our simplified Monarch logic
2468
+ result = self.run_user_function(callable_obj, method_name, params)
2469
+
2470
+ # Send response
2471
+ self._response_queue.put(
2472
+ {"request_unique_id": request_unique_id, "result": result}
2473
+ )
2474
+
2475
+ except Exception as e:
2476
+ # Package and send error
2477
+ packaged_exception = package_exception(e)
2478
+ self._response_queue.put(
2479
+ {
2480
+ "request_unique_id": request_unique_id,
2481
+ "result": packaged_exception,
2482
+ }
2483
+ )
2484
+ finally:
2485
+ request_id_ctx_var.reset(token)
2486
+
2487
+ except Exception as e:
2488
+ logger.error(f"Error in Monarch request handling: {e}")
2489
+ self._response_queue.put(
2490
+ {
2491
+ "request_unique_id": request.get("request_unique_id", "unknown"),
2492
+ "result": Exception(f"Fatal error: {e}"),
2493
+ }
2494
+ )
2495
+
2496
+ def proc_cleanup(self):
2497
+ """Monarch-specific cleanup."""
2498
+ try:
2499
+ # Stop allocator service
2500
+ if getattr(self, "allocator_proc", None):
2501
+ self.allocator_proc.terminate()
2502
+ try:
2503
+ self.allocator_proc.wait(timeout=5)
2504
+ except subprocess.TimeoutExpired:
2505
+ self.allocator_proc.kill()
2506
+ self.allocator_proc = None
2507
+ logger.info("Stopped process_allocator service")
2508
+
2509
+ # Cleanup any Monarch resources
2510
+ # Monarch doesn't have a global shutdown like Ray
2511
+ logger.debug("Monarch process cleanup completed")
2512
+
2513
+ except Exception as e:
2514
+ logger.error(f"Error during Monarch cleanup: {e}")
2515
+
2516
+ # Call parent cleanup
2517
+ super().proc_cleanup()
2518
+
2519
+ @classmethod
2520
+ def get_distributed_env_vars(
2521
+ cls, worker_ips, node_rank, local_rank, num_local_procs, **settings
2522
+ ):
2523
+ """Get Monarch-specific environment variables."""
2524
+ env_vars = super().get_distributed_env_vars(
2525
+ worker_ips, node_rank, local_rank, num_local_procs, **settings
2526
+ )
2527
+
2528
+ # Monarch uses these for discovery
2529
+ env_vars.update(
2530
+ {
2531
+ "HYPERACTOR_MESH_BOOTSTRAP_ADDR": "tcp://localhost:26600",
2532
+ "HYPERACTOR_MESH_INDEX": str(node_rank),
2533
+ # Keep POD_IPS for allocator creation
2534
+ "POD_IPS": ",".join(worker_ips),
2535
+ }
2536
+ )
2537
+
2538
+ return env_vars
2539
+
2540
+ @classmethod
2541
+ def get_auto_num_processes(cls):
2542
+ """Monarch uses one process per node (like Ray)."""
2543
+ return 1
2544
+
2545
+
2546
+ # Similar to Ray, Monarch needs special handling as a single-controller framework
2547
+ class MonarchDistributed(DistributedSupervisor):
2548
+ """Monarch distributed supervisor for single-controller actor framework."""
2549
+
2550
+ def __init__(
2551
+ self,
2552
+ restart_procs=True,
2553
+ max_threads=4,
2554
+ quorum_timeout=300,
2555
+ quorum_workers=None,
2556
+ **kwargs,
2557
+ ):
2558
+ # Monarch doesn't use DNS monitoring like SPMD frameworks
2559
+ super().__init__(
2560
+ quorum_workers=quorum_workers,
2561
+ quorum_timeout=quorum_timeout,
2562
+ monitor_members=False, # Disable DNS monitoring like Ray
2563
+ )
2564
+ self.restart_procs = restart_procs
2565
+ self.max_threads = max_threads
2566
+ self.process_pool = None
2567
+ self.remote_worker_pool = None
2568
+
2569
+ def setup(self, deployed_as_of: Optional[str] = None):
2570
+ """Setup Monarch distributed environment."""
2571
+ # Set multiprocessing to spawn
2572
+ if multiprocessing.get_start_method() != "spawn":
2573
+ multiprocessing.set_start_method("spawn", force=True)
2574
+
2575
+ # Start process_allocator service (like Ray starts its server)
2576
+ self._start_allocator_service()
2577
+
2578
+ if self.restart_procs:
2579
+ logger.debug("restart_procs is True, restarting Monarch processes")
2580
+ self.cleanup()
2581
+
2582
+ if self.process_pool is None:
2583
+ logger.debug("Setting up Monarch distributed environment")
2584
+
2585
+ # Create process pool with MonarchProcess
2586
+ logger.info("Creating DistributedProcessPool with MonarchProcess class")
2587
+ self.process_pool = DistributedProcessPool(
2588
+ process_class=MonarchProcess,
2589
+ num_processes=1, # One process per node for Monarch
2590
+ max_threads_per_proc=self.max_threads,
2591
+ )
2592
+
2593
+ # Start the process
2594
+ logger.info("Starting MonarchProcess pool...")
2595
+ self.process_pool.start()
2596
+ logger.info(f"Started MonarchProcess pool: {self.process_pool}")
2597
+
2598
+ # Create remote worker pool for coordination
2599
+ self.remote_worker_pool = RemoteWorkerPool(
2600
+ quorum_timeout=self.quorum_timeout
2601
+ )
2602
+ self.remote_worker_pool.start()
2603
+
2604
+ logger.debug("Finished setting up Monarch distributed processes")
2605
+
2606
+ def _start_allocator_service(self):
2607
+ """Start the process_allocator service if available."""
2608
+ try:
2609
+ # Check if process_allocator is already running
2610
+ import subprocess
2611
+
2612
+ check_result = subprocess.run(
2613
+ ["pgrep", "-f", "process_allocator"], capture_output=True
2614
+ )
2615
+ if check_result.returncode == 0:
2616
+ logger.info("process_allocator already running")
2617
+ return
2618
+
2619
+ # Try to find process_allocator binary
2620
+ import shutil
2621
+
2622
+ allocator_path = shutil.which("process_allocator")
2623
+
2624
+ if not allocator_path:
2625
+ # Check common installation paths
2626
+ import sys
2627
+
2628
+ possible_paths = [
2629
+ "/opt/conda/bin/process_allocator",
2630
+ os.path.join(sys.prefix, "bin", "process_allocator"),
2631
+ ]
2632
+ for path in possible_paths:
2633
+ if os.path.exists(path) and os.access(path, os.X_OK):
2634
+ allocator_path = path
2635
+ break
2636
+
2637
+ if not allocator_path:
2638
+ logger.warning(
2639
+ "process_allocator binary not found. "
2640
+ "Please ensure torchmonarch is properly installed or "
2641
+ "start process_allocator manually in your Docker image."
2642
+ )
2643
+ return
2644
+
2645
+ # Start process_allocator with the correct arguments
2646
+ # Based on monarch/python/monarch/tools/components/hyperactor.py
2647
+ allocator_cmd = [
2648
+ allocator_path,
2649
+ "--port=26600",
2650
+ "--program=monarch_bootstrap",
2651
+ ]
2652
+
2653
+ logger.info(f"Starting process_allocator: {' '.join(allocator_cmd)}")
2654
+
2655
+ # Start in background
2656
+ self.allocator_proc = subprocess.Popen(
2657
+ allocator_cmd,
2658
+ stdout=subprocess.DEVNULL,
2659
+ stderr=subprocess.PIPE,
2660
+ start_new_session=True,
2661
+ )
2662
+
2663
+ # Give it a moment to start
2664
+ import time
2665
+
2666
+ time.sleep(2)
2667
+
2668
+ # Check if it's still running
2669
+ if self.allocator_proc.poll() is None:
2670
+ logger.info(
2671
+ f"process_allocator started successfully (PID: {self.allocator_proc.pid})"
2672
+ )
2673
+ else:
2674
+ stderr = (
2675
+ self.allocator_proc.stderr.read().decode()
2676
+ if self.allocator_proc.stderr
2677
+ else ""
2678
+ )
2679
+ logger.error(f"process_allocator failed to start: {stderr}")
2680
+ self.allocator_proc = None
2681
+
2682
+ except Exception as e:
2683
+ logger.warning(f"Could not start process_allocator: {e}")
2684
+ # Continue anyway - user may have started it differently
2685
+
2686
+ def cleanup(self):
2687
+ """Cleanup Monarch distributed environment."""
2688
+ logger.debug("Cleaning up Monarch distributed processes")
2689
+
2690
+ # Stop DNS monitoring (though it's disabled for Monarch)
2691
+ super().cleanup()
2692
+
2693
+ # Stop process_allocator if we started it
2694
+ if hasattr(self, "allocator_proc") and self.allocator_proc:
2695
+ try:
2696
+ self.allocator_proc.terminate()
2697
+ self.allocator_proc.wait(timeout=5)
2698
+ logger.info("Stopped process_allocator")
2699
+ except Exception as e:
2700
+ logger.debug(f"Error stopping process_allocator: {e}")
2701
+
2702
+ if self.process_pool:
2703
+ self.process_pool.stop()
2704
+ self.process_pool = None
2705
+
2706
+ if self.remote_worker_pool:
2707
+ self.remote_worker_pool.stop()
2708
+ self.remote_worker_pool = None
2709
+
2710
+ logger.debug("Finished cleaning up Monarch distributed processes")
2711
+
2712
+ @staticmethod
2713
+ def intercept_call():
2714
+ """Monarch intercepts calls like Ray."""
2715
+ return True
2716
+
2717
+ def call_distributed(
2718
+ self,
2719
+ request,
2720
+ cls_or_fn_name: str,
2721
+ method_name: Optional[str] = None,
2722
+ params: Optional[Dict] = None,
2723
+ distributed_subcall: bool = False,
2724
+ debug_port: int = False,
2725
+ deployed_as_of: Optional[str] = None,
2726
+ ):
2727
+ """Monarch distributed call - executes on controller node (rank 0)."""
2728
+ logger.info("MonarchDistributed.call_distributed called")
2729
+
2730
+ # Ensure setup has been called
2731
+ if self.process_pool is None:
2732
+ logger.info("Process pool not initialized, calling setup()")
2733
+ self.setup(deployed_as_of=deployed_as_of)
2734
+
2735
+ request_id = request.headers.get("X-Request-ID", "-")
2736
+ serialization = request.headers.get("X-Serialization", "json")
2737
+
2738
+ # If deployed_as_of is None, generate a consistent timestamp
2739
+ if deployed_as_of is None:
2740
+ from datetime import datetime, timezone
2741
+
2742
+ deployed_as_of = datetime.now(timezone.utc).isoformat()
2743
+
2744
+ # Start DNS monitoring for worker discovery
2745
+ self.start_dns_monitoring()
2746
+
2747
+ # Check for any pending changes before we start
2748
+ self.check_for_membership_changes()
2749
+
2750
+ # Get pod IPs with quorum handling
2751
+ pod_ips = self.pod_ips()
2752
+
2753
+ # Handle case where no pods are found
2754
+ if not pod_ips:
2755
+ logger.error(
2756
+ f"No pods found for service {os.environ.get('KT_SERVICE')}. "
2757
+ "This may indicate the pods aren't ready yet. Consider increasing quorum_timeout in .distribute() call."
2758
+ )
2759
+ raise RuntimeError(
2760
+ "No pods found for Monarch distributed setup. "
2761
+ "Consider increasing quorum_timeout parameter."
2762
+ )
2763
+
2764
+ logger.info(
2765
+ f"Found {len(pod_ips)} pod(s) for Monarch distributed setup: {pod_ips}"
2766
+ )
2767
+
2768
+ # Store critical environment variables
2769
+ self.distributed_env_vars = {}
2770
+ critical_env_vars = [
2771
+ "KT_SERVICE",
2772
+ "KT_SERVICE_NAME",
2773
+ "KT_FILE_PATH",
2774
+ "KT_MODULE_NAME",
2775
+ "KT_CLS_OR_FN_NAME",
2776
+ ]
2777
+ for env_var in critical_env_vars:
2778
+ if env_var in os.environ:
2779
+ self.distributed_env_vars[env_var] = os.environ[env_var]
2780
+
2781
+ # Update distributed env vars with current cluster IPs
2782
+ self.distributed_env_vars["POD_IPS"] = ",".join(pod_ips)
2783
+ self.distributed_env_vars["WORLD_SIZE"] = str(len(pod_ips))
2784
+ self.distributed_env_vars["NODE_RANK"] = "0" # Controller is always rank 0
2785
+
2786
+ logger.debug("Sending call to Monarch subprocess (controller)")
2787
+
2788
+ # Monarch uses only one process per node, call index 0
2789
+ result = self.process_pool.call(
2790
+ idx=0,
2791
+ method_name=method_name,
2792
+ params=params,
2793
+ deployed_as_of=deployed_as_of,
2794
+ request_id=request_id,
2795
+ distributed_env_vars=self.distributed_env_vars,
2796
+ debug_port=debug_port,
2797
+ serialization=serialization,
2798
+ )
2799
+
2800
+ # Handle exceptions from subprocess
2801
+ if isinstance(result, JSONResponse):
2802
+ return result
2803
+ if isinstance(result, Exception):
2804
+ raise result
2805
+
2806
+ return result
2807
+
2808
+
2809
+ RAY_START_PROC = None
2810
+
2811
+
2812
+ class RayDistributed(DistributedSupervisor):
2813
+ def __init__(
2814
+ self,
2815
+ restart_procs=True,
2816
+ max_threads=4,
2817
+ quorum_timeout=300,
2818
+ quorum_workers=None,
2819
+ monitor_members=False,
2820
+ ):
2821
+ """Ray distributed supervisor - only runs on head node (single controller).
2822
+
2823
+ Args:
2824
+ restart_procs: Whether to restart processes on each call
2825
+ max_threads: Maximum threads per process
2826
+ quorum_timeout: Timeout in seconds for Ray cluster nodes to become ready (default 300s/5min)
2827
+ """
2828
+ # Ray manages its own membership, so we don't monitor DNS changes
2829
+ super().__init__(
2830
+ quorum_timeout=quorum_timeout,
2831
+ quorum_workers=quorum_workers,
2832
+ monitor_members=monitor_members,
2833
+ )
2834
+ self.restart_procs = restart_procs
2835
+ self.distributed_env_vars = None
2836
+ self.process_pool = None # Using pool even for single process for consistency
2837
+ self.remote_worker_pool = None # Pool for async HTTP calls to remote workers
2838
+ self.max_threads = max_threads
2839
+ self.quorum_timeout = quorum_timeout
2840
+
2841
+ def setup(self, deployed_as_of: Optional[str] = None):
2842
+ # Set multiprocessing to spawn if not already
2843
+ if multiprocessing.get_start_method() != "spawn":
2844
+ multiprocessing.set_start_method("spawn", force=True)
2845
+
2846
+ # Start the Ray server here; if we allow KubeRay to start it in the pod template,
2847
+ # it's hard to wait for it to start properly and we lose the ability to restart it if needed.
2848
+ global RAY_START_PROC
2849
+
2850
+ # Check if Ray is actually running, not just if our global variable is None
2851
+ # (the global variable gets reset when HTTP server restarts)
2852
+ ray_running = self._is_ray_running()
2853
+
2854
+ if not ray_running:
2855
+ patch_sys_path()
2856
+
2857
+ kuberay_start_cmd = os.environ.get("KUBERAY_GEN_RAY_START_CMD")
2858
+ if kuberay_start_cmd:
2859
+ full_cmd = f"ulimit -n 65536; {kuberay_start_cmd}"
2860
+ logger.info(f"Starting Ray server with command: {full_cmd}")
2861
+
2862
+ try:
2863
+ # Start Ray as a non-blocking subprocess
2864
+ RAY_START_PROC = subprocess.Popen(
2865
+ full_cmd,
2866
+ shell=True,
2867
+ stdout=subprocess.PIPE,
2868
+ stderr=subprocess.STDOUT,
2869
+ universal_newlines=True,
2870
+ bufsize=1,
2871
+ env=os.environ.copy(),
2872
+ )
2873
+
2874
+ # Start a thread to stream Ray logs
2875
+ def stream_ray_logs():
2876
+ try:
2877
+ for line in RAY_START_PROC.stdout:
2878
+ logger.info(f"[Ray] {line.strip()}")
2879
+ except Exception as e:
2880
+ logger.error(f"Error streaming Ray logs: {e}")
2881
+
2882
+ import threading
2883
+
2884
+ log_thread = threading.Thread(target=stream_ray_logs, daemon=True)
2885
+ log_thread.start()
2886
+
2887
+ logger.info(f"Ray server started with PID: {RAY_START_PROC.pid}")
2888
+
2889
+ # Give Ray a moment to start
2890
+ time.sleep(2)
2891
+
2892
+ except Exception as e:
2893
+ logger.error(f"Failed to start Ray server: {e}")
2894
+ RAY_START_PROC = None
2895
+ raise
2896
+ else:
2897
+ logger.warning(
2898
+ "KUBERAY_GEN_RAY_START_CMD environment variable not found"
2899
+ )
2900
+
2901
+ logger.debug(
2902
+ "Ray distributed supervisor setup completed (pod discovery will be done lazily)"
2903
+ )
2904
+
2905
+ # Only the head node runs the subprocess
2906
+ this_pod_ip = os.environ["POD_IP"]
2907
+ if not os.environ["POD_NAME"].endswith("-head"):
2908
+ logger.info(f"Ray worker node {this_pod_ip}, skipping subprocess setup")
2909
+ return
2910
+
2911
+ logger.info(f"Ray head node {this_pod_ip}, setting up subprocess")
2912
+
2913
+ # Set Ray environment variables
2914
+ self.distributed_env_vars = {"RAY_HEAD_NODE_IP": this_pod_ip}
2915
+
2916
+ # Include critical environment variables so Ray workers can find and load the callable
2917
+ critical_env_vars = [
2918
+ "PYTHONPATH",
2919
+ "KT_FILE_PATH",
2920
+ "KT_MODULE_NAME",
2921
+ "KT_CLS_OR_FN_NAME",
2922
+ ]
2923
+ for env_var in critical_env_vars:
2924
+ if env_var in os.environ:
2925
+ self.distributed_env_vars[env_var] = os.environ[env_var]
2926
+
2927
+ # cleanup() clears the pools, so record whether this supervisor was already initialized before cleaning up
2928
+ previously_initialized = self.remote_worker_pool is not None
2929
+
2930
+ if self.restart_procs:
2931
+ logger.debug("restart_procs is True, restarting Ray distributed process")
2932
+ self.cleanup()
2933
+
2934
+ if previously_initialized:
2935
+ pod_ips = self.pod_ips()
2936
+ this_pod_ip = os.environ["POD_IP"]
2937
+
2938
+ # Send reload requests to other pods if needed
2939
+ self._reload_image_on_other_pods(pod_ips, this_pod_ip, deployed_as_of)
2940
+
2941
+ if self.process_pool is None:
2942
+ logger.debug("Setting up Ray distributed process")
2943
+ self.process_pool = DistributedProcessPool(
2944
+ process_class=RayProcess,
2945
+ num_processes=1, # Ray only needs one process
2946
+ max_threads_per_proc=self.max_threads,
2947
+ )
2948
+ self.process_pool.start()
2949
+
2950
+ # Start remote worker pool for async HTTP calls if needed
2951
+ # Use a reasonable default max_workers since we don't know cluster size yet
2952
+ self.remote_worker_pool = RemoteWorkerPool(
2953
+ quorum_timeout=self.quorum_timeout
2954
+ )
2955
+ self.remote_worker_pool.start(max_workers=100) # Default size
2956
+
2957
+ logger.debug(
2958
+ "Finished setting up Ray distributed process and remote worker pool"
2959
+ )
2960
+
2961
+ def cleanup(self):
2962
+ """Clean up Ray distributed process."""
2963
+ logger.debug("Cleaning up Ray distributed process")
2964
+
2965
+ # Stop DNS monitoring first
2966
+ super().cleanup()
2967
+
2968
+ if self.process_pool:
2969
+ self.process_pool.stop()
2970
+ self.process_pool = None
2971
+
2972
+ if self.remote_worker_pool:
2973
+ self.remote_worker_pool.stop()
2974
+ self.remote_worker_pool = None
2975
+
2976
+ logger.debug("Finished cleaning up Ray distributed process")
2977
+
2978
+ @staticmethod
2979
+ def intercept_call():
2980
+ return True
2981
+
2982
+ def call_distributed(
2983
+ self,
2984
+ request,
2985
+ cls_or_fn_name: str,
2986
+ method_name: Optional[str] = None,
2987
+ params: Optional[Dict] = None,
2988
+ distributed_subcall: bool = False,
2989
+ debug_port: int = False,
2990
+ deployed_as_of: Optional[str] = None,
2991
+ ):
2992
+ """Ray distributed call - only executes on head node."""
2993
+ request_id = request.headers.get("X-Request-ID", "-")
2994
+ serialization = request.headers.get("X-Serialization", "json")
2995
+
2996
+ # If deployed_as_of is None, generate a consistent timestamp
2997
+ # to use across all workers to prevent reload inconsistencies
2998
+ if deployed_as_of is None:
2999
+ from datetime import datetime, timezone
3000
+
3001
+ deployed_as_of = datetime.now(timezone.utc).isoformat()
3002
+
3003
+ if not os.environ["POD_NAME"].endswith("-head"):
3004
+ # This should never happen because the service only points to the head node; raise an error if it does.
3005
+ raise RuntimeError(
3006
+ f"Ray distributed call attempted on non-head node {os.environ['POD_NAME']}. "
3007
+ "This should only be called on the head node."
3008
+ )
3009
+
3010
+ # Start DNS monitoring for the head node
3011
+ self.start_dns_monitoring()
3012
+
3013
+ # Check for any pending changes before we start
3014
+ self.check_for_membership_changes()
3015
+
3016
+ # The pod_ips() method now handles waiting for quorum
3017
+ pod_ips = self.pod_ips()
3018
+
3019
+ # Handle case where no pods are found
3020
+ if not pod_ips:
3021
+ logger.error(
3022
+ f"No pods found for service {os.environ.get('KT_SERVICE')}. "
3023
+ "This may indicate the pods aren't ready yet. Consider increasing quorum_timeout in .distribute() call."
3024
+ )
3025
+ raise RuntimeError(
3026
+ "No pods found for Ray distributed setup. "
3027
+ "Consider increasing quorum_timeout parameter."
3028
+ )
3029
+
3030
+ logger.info(f"Found {len(pod_ips)} pod(s) for distributed setup: {pod_ips}")
3031
+
3032
+ # Update distributed env vars with current cluster IPs
3033
+ self.distributed_env_vars["POD_IPS"] = ",".join(pod_ips)
3034
+
3035
+ logger.debug("Sending call to Ray subprocess")
3036
+ # Ray uses only one process, so always call index 0
3037
+ result = self.process_pool.call(
3038
+ idx=0,
3039
+ method_name=method_name,
3040
+ params=params,
3041
+ deployed_as_of=deployed_as_of,
3042
+ request_id=request_id,
3043
+ distributed_env_vars=self.distributed_env_vars,
3044
+ debug_port=debug_port,
3045
+ serialization=serialization,
3046
+ )
3047
+
3048
+ # Handle exceptions from subprocess
3049
+ if isinstance(result, JSONResponse):
3050
+ return result
3051
+ if isinstance(result, Exception):
3052
+ raise result
3053
+
3054
+ return result
3055
+
3056
+    def _reload_image_on_other_pods(self, pod_ips, this_pod_ip, deployed_as_of):
+        """Send /_reload_image requests to all other pods in parallel, with retries for pods that aren't ready."""
+        other_pod_ips = [ip for ip in pod_ips if ip != this_pod_ip]
+
+        if not other_pod_ips:
+            logger.debug("No other pods to reload")
+            return
+
+        logger.info(
+            f"Sending reload requests to {len(other_pod_ips)} other pods: {other_pod_ips}"
+        )
+
+        server_port = os.environ.get("KT_SERVER_PORT", "32300")
+        total_timeout = self.quorum_timeout  # Use configurable quorum timeout
+        retry_interval = 2  # Wait 2 seconds between retry attempts
+        start_time = time.time()
+
+        successful_pods = set()
+        remaining_pods = set(other_pod_ips)
+
+        while remaining_pods and (time.time() - start_time) < total_timeout:
+            logger.debug(
+                f"Attempting to reload {len(remaining_pods)} remaining pods: {list(remaining_pods)}"
+            )
+
+            def reload_pod(pod_ip):
+                """Send reload request to a single pod."""
+                try:
+                    # Use a proper HTTP client session to avoid Content-Length issues
+                    with httpx.Client(timeout=None) as client:
+                        url = f"http://{pod_ip}:{server_port}/_reload_image"
+                        # First try a quick health check to see if pod is ready
+                        health_url = f"http://{pod_ip}:{server_port}/health"
+                        health_response = client.get(health_url, timeout=5)
+
+                        if health_response.status_code != 200:
+                            logger.debug(
+                                f"Pod {pod_ip} health check failed, will retry later"
+                            )
+                            return False
+
+                        # Pod is healthy, send reload request (no timeout, installs can be long-running)
+                        response = client.post(
+                            url, headers={"X-Deployed-As-Of": deployed_as_of}
+                        )
+                        if response.status_code == 200:
+                            logger.debug(f"Successfully reloaded image on pod {pod_ip}")
+                            return True
+                        else:
+                            logger.warning(
+                                f"Pod {pod_ip} reload returned status {response.status_code}"
+                            )
+                            return False
+
+                except Exception as e:
+                    logger.debug(f"Failed to reload image on pod {pod_ip}: {e}")
+                    raise
+
+            # Try to reload all remaining pods in parallel
+            current_attempt_pods = list(remaining_pods)
+
+            with ThreadPoolExecutor(
+                max_workers=min(len(current_attempt_pods), 10)
+            ) as executor:
+                # Submit reload tasks for remaining pods
+                future_to_pod = {
+                    executor.submit(reload_pod, pod_ip): pod_ip
+                    for pod_ip in current_attempt_pods
+                }
+
+                # Process completed futures
+                for future in as_completed(future_to_pod, timeout=None):
+                    pod_ip = future_to_pod[future]
+                    try:
+                        success = future.result()
+                        if success:
+                            successful_pods.add(pod_ip)
+                            remaining_pods.discard(pod_ip)
+                    except Exception as e:
+                        logger.debug(f"Reload task for pod {pod_ip} failed: {e}")
+
+            if remaining_pods:
+                elapsed = time.time() - start_time
+                remaining_time = total_timeout - elapsed
+                if remaining_time > retry_interval:
+                    logger.info(
+                        f"Waiting {retry_interval}s before retrying {len(remaining_pods)} pods..."
+                    )
+                    time.sleep(retry_interval)
+                else:
+                    logger.warning("Timeout approaching, stopping retry attempts")
+                    break
+
+        # Log final results
+        if successful_pods:
+            logger.info(
+                f"Successfully reloaded {len(successful_pods)} pod images: {list(successful_pods)}"
+            )
+
+        if remaining_pods:
+            logger.warning(
+                f"Failed to reload {len(remaining_pods)} pod images after timeout: {list(remaining_pods)}"
+            )
+
+    def _is_ray_running(self):
+        """Check if Ray is actually running by trying to connect to the Ray GCS port."""
+        try:
+            import socket
+
+            # Ray GCS runs on port 6379 by default
+            ray_port = 6379
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.settimeout(1)  # 1 second timeout
+            result = sock.connect_ex(("127.0.0.1", ray_port))
+            sock.close()
+
+            if result == 0:
+                logger.debug(
+                    "Ray GCS port 6379 is accessible, Ray appears to be running"
+                )
+                return True
+            else:
+                logger.debug("Ray GCS port 6379 is not accessible, Ray is not running")
+                return False
+
+        except Exception as e:
+            logger.debug(f"Error checking if Ray is running: {e}")
+            return False
+
+
+def distributed_supervisor_factory(distribution_type, *args, **kwargs):
+    """
+    Factory function to create a distributed supervisor based on the specified type.
+
+    Args:
+        distribution_type (str): The type of distributed supervisor to create.
+            Options include 'ray', 'monarch', 'pytorch', 'jax', 'tensorflow' (or 'tf'),
+            and None (or 'spmd') for generic SPMD.
+        *args: Positional arguments to pass to the supervisor constructor.
+        **kwargs: Keyword arguments to pass to the supervisor constructor.
+            Common kwargs include:
+            - quorum_timeout: Timeout in seconds for workers to become ready (default 30 for SPMD, 300 for Ray/Monarch)
+
+    Returns:
+        DistributedSupervisor: An instance of the specified distributed supervisor.
+    """
+    if distribution_type == "ray":
+        # Ray uses its own supervisor, not SPMD
+        return RayDistributed(*args, **kwargs)
+    elif distribution_type == "monarch":
+        # Monarch is similar to Ray - a single-controller framework
+        return MonarchDistributed(*args, **kwargs)
+
+    # All other types use SPMDDistributedSupervisor with different process classes
+    if distribution_type == "pytorch":
+        return SPMDDistributedSupervisor(process_class=PyTorchProcess, *args, **kwargs)
+    elif distribution_type == "jax":
+        return SPMDDistributedSupervisor(process_class=JaxProcess, *args, **kwargs)
+    elif distribution_type == "tensorflow" or distribution_type == "tf":
+        return SPMDDistributedSupervisor(
+            process_class=TensorflowProcess, *args, **kwargs
+        )
+    elif distribution_type is None or distribution_type == "spmd":
+        # Default to the base DistributedProcess - no framework-specific dependencies
+        return SPMDDistributedSupervisor(
+            process_class=DistributedProcess, *args, **kwargs
+        )
+    else:
+        raise ValueError(f"Unsupported distributed type: {distribution_type}")
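
For orientation, a minimal sketch of how this factory could be driven; it is illustrative only and not part of the package diff. The distribution_type strings and the quorum_timeout keyword come from the docstring above, while the surrounding variables are assumed for the example.

    # Illustrative sketch (not from the package source): pick a supervisor
    # for a requested distribution type using the factory defined above.
    requested_type = "pytorch"  # also accepts "ray", "monarch", "jax", "tensorflow"/"tf", "spmd", or None
    supervisor = distributed_supervisor_factory(
        requested_type,
        quorum_timeout=60,  # seconds to wait for workers; the docstring notes defaults of 30 (SPMD) and 300 (Ray/Monarch)
    )

The design keeps Ray and Monarch on their own single-controller supervisors, while every SPMD-style framework shares SPMDDistributedSupervisor and differs only in the process_class it is handed.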