kubetorch-0.2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kubetorch/__init__.py +59 -0
- kubetorch/cli.py +1939 -0
- kubetorch/cli_utils.py +967 -0
- kubetorch/config.py +453 -0
- kubetorch/constants.py +18 -0
- kubetorch/docs/Makefile +18 -0
- kubetorch/docs/__init__.py +0 -0
- kubetorch/docs/_ext/json_globaltoc.py +42 -0
- kubetorch/docs/api/cli.rst +10 -0
- kubetorch/docs/api/python/app.rst +21 -0
- kubetorch/docs/api/python/cls.rst +19 -0
- kubetorch/docs/api/python/compute.rst +25 -0
- kubetorch/docs/api/python/config.rst +11 -0
- kubetorch/docs/api/python/fn.rst +19 -0
- kubetorch/docs/api/python/image.rst +14 -0
- kubetorch/docs/api/python/secret.rst +18 -0
- kubetorch/docs/api/python/volumes.rst +13 -0
- kubetorch/docs/api/python.rst +101 -0
- kubetorch/docs/conf.py +69 -0
- kubetorch/docs/index.rst +20 -0
- kubetorch/docs/requirements.txt +5 -0
- kubetorch/globals.py +269 -0
- kubetorch/logger.py +59 -0
- kubetorch/resources/__init__.py +0 -0
- kubetorch/resources/callables/__init__.py +0 -0
- kubetorch/resources/callables/cls/__init__.py +0 -0
- kubetorch/resources/callables/cls/cls.py +159 -0
- kubetorch/resources/callables/fn/__init__.py +0 -0
- kubetorch/resources/callables/fn/fn.py +140 -0
- kubetorch/resources/callables/module.py +1315 -0
- kubetorch/resources/callables/utils.py +203 -0
- kubetorch/resources/compute/__init__.py +0 -0
- kubetorch/resources/compute/app.py +253 -0
- kubetorch/resources/compute/compute.py +2414 -0
- kubetorch/resources/compute/decorators.py +137 -0
- kubetorch/resources/compute/utils.py +1026 -0
- kubetorch/resources/compute/websocket.py +135 -0
- kubetorch/resources/images/__init__.py +1 -0
- kubetorch/resources/images/image.py +412 -0
- kubetorch/resources/images/images.py +64 -0
- kubetorch/resources/secrets/__init__.py +2 -0
- kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
- kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
- kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
- kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
- kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
- kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
- kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
- kubetorch/resources/secrets/secret.py +224 -0
- kubetorch/resources/secrets/secret_factory.py +64 -0
- kubetorch/resources/secrets/utils.py +222 -0
- kubetorch/resources/volumes/__init__.py +0 -0
- kubetorch/resources/volumes/volume.py +340 -0
- kubetorch/servers/__init__.py +0 -0
- kubetorch/servers/http/__init__.py +0 -0
- kubetorch/servers/http/distributed_utils.py +2968 -0
- kubetorch/servers/http/http_client.py +802 -0
- kubetorch/servers/http/http_server.py +1622 -0
- kubetorch/servers/http/server_metrics.py +255 -0
- kubetorch/servers/http/utils.py +722 -0
- kubetorch/serving/__init__.py +0 -0
- kubetorch/serving/autoscaling.py +153 -0
- kubetorch/serving/base_service_manager.py +344 -0
- kubetorch/serving/constants.py +77 -0
- kubetorch/serving/deployment_service_manager.py +431 -0
- kubetorch/serving/knative_service_manager.py +487 -0
- kubetorch/serving/raycluster_service_manager.py +526 -0
- kubetorch/serving/service_manager.py +18 -0
- kubetorch/serving/templates/deployment_template.yaml +17 -0
- kubetorch/serving/templates/knative_service_template.yaml +19 -0
- kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
- kubetorch/serving/templates/pod_template.yaml +198 -0
- kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
- kubetorch/serving/templates/raycluster_template.yaml +35 -0
- kubetorch/serving/templates/service_template.yaml +21 -0
- kubetorch/serving/templates/workerset_template.yaml +36 -0
- kubetorch/serving/utils.py +344 -0
- kubetorch/utils.py +263 -0
- kubetorch-0.2.5.dist-info/METADATA +75 -0
- kubetorch-0.2.5.dist-info/RECORD +92 -0
- kubetorch-0.2.5.dist-info/WHEEL +4 -0
- kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/distributed_utils.py
@@ -0,0 +1,2968 @@
import copy
import multiprocessing
import os
import queue
import subprocess
import threading
import time
import uuid
from bdb import BdbQuit
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Dict, Optional

import httpx
from starlette.responses import JSONResponse

from kubetorch.servers.http.http_server import (
    load_callable,
    logger,
    package_exception,
    patch_sys_path,
    request_id_ctx_var,
    run_callable_internal_sync,
)

from .utils import clear_debugging_sessions, is_running_in_kubernetes

# Try to import Monarch components at module level if available
# This helps avoid threading issues with Monarch's Rust bindings
try:
    from monarch._src.actor.allocator import RemoteAllocator, StaticRemoteAllocInitializer

    MONARCH_AVAILABLE = True
except ImportError:
    MONARCH_AVAILABLE = False
    RemoteAllocator = None
    StaticRemoteAllocInitializer = None
except Exception:
    # Catch any other exceptions during import
    MONARCH_AVAILABLE = False
    RemoteAllocator = None
    StaticRemoteAllocInitializer = None


class RemoteWorkerPool:
    """Manages async HTTP calls to remote workers in a separate process."""

    def __init__(self, quorum_timeout=300):
        self.quorum_timeout = quorum_timeout
        self.request_queue = multiprocessing.Queue()
        self.response_queue = multiprocessing.Queue()
        self.process = None
        self._running = False
        # Thread-safe response routing for concurrent calls
        self._response_events = {}
        self._response_lock = threading.Lock()
        self._router_thread = None

    def start(self, max_workers=2000):
        """Start the worker process and router thread."""
        if self.process:
            raise RuntimeError("WorkerPool already started")

        self._running = True
        # Pass necessary data as arguments to avoid pickling issues
        self.process = multiprocessing.Process(
            target=self._run_async_worker,
            args=(
                self.request_queue,
                self.response_queue,
                self.quorum_timeout,
                max_workers,
            ),
            daemon=True,
        )
        self.process.start()

        # Start router thread for handling responses
        self._router_thread = threading.Thread(target=self._response_router, daemon=True)
        self._router_thread.start()

        logger.debug("Started RemoteWorkerPool process and router thread")

    def stop(self):
        """Stop the worker process and router thread."""
        self._running = False

        # Stop router thread
        if self._router_thread:
            self.response_queue.put({"request_id": "STOP_ROUTER"})
            self._router_thread.join(timeout=1)

        # Stop worker process
        if self.process:
            self.request_queue.put(("SHUTDOWN", None))
            self.process.join(timeout=5)
            if self.process and self.process.is_alive():
                self.process.terminate()
                self.process.join(timeout=1)
                if self.process and self.process.is_alive():
                    self.process.kill()
            self.process = None

        # Clear response events
        with self._response_lock:
            self._response_events.clear()

        logger.debug("Stopped RemoteWorkerPool process and router thread")

    def call_workers(
        self,
        worker_ips,
        cls_or_fn_name,
        method_name,
        params,
        request_headers,
        workers_arg="all",
    ):
        """Call remote workers and return responses."""
        if not self.process or not self.process.is_alive():
            raise RuntimeError("RemoteWorkerPool not running")

        # Generate unique request ID
        request_id = str(uuid.uuid4())

        # Submit request
        request_data = {
            "request_id": request_id,
            "worker_ips": worker_ips,
            "cls_or_fn_name": cls_or_fn_name,
            "method_name": method_name,
            "params": params,
            "request_headers": request_headers,
            "workers_arg": workers_arg,
        }

        # Register event for this request
        event = threading.Event()
        with self._response_lock:
            self._response_events[request_id] = (event, None)

        logger.debug(f"RemoteWorkerPool: Submitting request {request_id} to queue for {len(worker_ips)} workers")
        self.request_queue.put(("CALL", request_data))

        # Wait for response with a timeout to prevent hanging forever
        logger.debug(f"RemoteWorkerPool: Waiting for response for request {request_id} from {len(worker_ips)} workers")
        # Wait indefinitely for the response (no timeout for long-running jobs)
        event.wait()

        # Get and cleanup response
        with self._response_lock:
            _, result = self._response_events.pop(request_id)

        logger.debug(f"RemoteWorkerPool: Got response for request {request_id}, type: {type(result).__name__}")

        if isinstance(result, Exception):
            raise result
        return result

    def _response_router(self):
        """Router thread that distributes responses to waiting threads."""
        logger.debug("RemoteWorkerPool response router thread started")
        while self._running:
            try:
                response = self.response_queue.get(timeout=0.1)
                if response.get("request_id") == "STOP_ROUTER":
                    break

                request_id = response.get("request_id")
                logger.debug(f"Response router received response for request {request_id}")
                with self._response_lock:
                    if request_id in self._response_events:
                        event, _ = self._response_events[request_id]
                        # Store the result (either results list or exception)
                        if "error" in response:
                            logger.debug(f"Response router: Setting error for request {request_id}")
                            self._response_events[request_id] = (
                                event,
                                response["error"],
                            )
                        else:
                            logger.debug(
                                f"Response router: Setting {len(response.get('results', []))} results for request {request_id}"
                            )
                            self._response_events[request_id] = (
                                event,
                                response["results"],
                            )
                        event.set()
                        logger.debug(f"Response router: Event set for request {request_id}")
                    else:
                        logger.warning(
                            f"Response router: No event found for request {request_id}, registered events: {list(self._response_events.keys())}"
                        )
            except queue.Empty:
                continue  # Queue timeout, continue checking
            except Exception as e:
                logger.error(f"Error in response router: {e}")
                continue

    @staticmethod
    def _run_async_worker(request_queue, response_queue, quorum_timeout, max_workers=2000):
        """Worker process that handles async HTTP calls.

        Architecture:
        - Runs in a separate process with its own event loop
        - Maintains a single shared httpx.AsyncClient for connection pooling
        - Processes requests from queue concurrently (multiple requests in flight)
        - Each request involves parallel async HTTP calls to multiple workers

        Nested functions (share queue/timeout state):
        - wait_for_worker_health: Health check with retries
        - call_single_worker: Make HTTP call to one worker
        - call_workers_async: Orchestrate health checks + calls for all workers
        - process_requests: Main loop pulling from queue and dispatching tasks
        """
        import asyncio
        import queue

        import httpx

        async def wait_for_worker_health(client, worker_ip, workers_arg, quorum_timeout):
            """Wait for a worker to become healthy within timeout."""
            port = os.environ["KT_SERVER_PORT"]
            worker_url = f"http://{worker_ip}:{port}"

            start_time = asyncio.get_event_loop().time()

            # Keep trying until timeout
            while (asyncio.get_event_loop().time() - start_time) < quorum_timeout:
                try:
                    resp = await client.get(f"{worker_url}/health", timeout=10.0)
                    if resp.status_code == 200:
                        return (worker_ip, True)  # Return IP and success
                except (httpx.RequestError, httpx.TimeoutException):
                    pass

                # No waiting for quorum if user just wants to call ready workers
                if workers_arg == "ready":
                    return worker_ip, False

                # Wait before retry (1 second, same as original)
                await asyncio.sleep(1.0)

            # Timeout reached
            return worker_ip, False

        async def call_single_worker(
            client,
            worker_ip,
            cls_or_fn_name,
            method_name,
            params,
            request_headers,
            workers_arg,
        ):
            """Call a single worker (assumes health already checked)."""
            port = os.environ["KT_SERVER_PORT"]
            worker_url = f"http://{worker_ip}:{port}"

            # Make the actual call
            call_url = f"{worker_url}/{cls_or_fn_name}"
            if method_name:
                call_url += f"/{method_name}"
            call_url += "?distributed_subcall=true"

            logger.debug(f"Async worker: Making POST to {call_url}")

            # Retry logic for transient failures
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    resp = await client.post(call_url, json=params, headers=request_headers)
                    result = resp.json()
                    break  # Success, exit retry loop
                except httpx.ReadError as e:
                    # Check if this is due to server shutdown (connection reset)
                    if "Connection reset" in str(e) or "Connection closed" in str(e):
                        logger.warning(f"Worker {worker_ip} appears to be shutting down: {e}")
                        raise  # Don't retry on shutdown
                    if attempt < max_retries - 1:
                        wait_time = (attempt + 1) * 2  # Exponential backoff: 2s, 4s
                        logger.warning(
                            f"ReadError calling {worker_ip} (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
                        )
                        await asyncio.sleep(wait_time)
                    else:
                        logger.error(
                            f"ReadError calling {worker_ip} after {max_retries} attempts: {e}. Worker may be crashed/overloaded."
                        )
                        raise
                except httpx.TimeoutException as e:
                    if attempt < max_retries - 1:
                        logger.warning(
                            f"Timeout calling {worker_ip} (attempt {attempt + 1}/{max_retries}): {e}. Retrying..."
                        )
                        await asyncio.sleep(1)
                    else:
                        logger.error(f"Timeout calling {worker_ip} after {max_retries} attempts: {e}")
                        raise
                except Exception as e:
                    logger.error(f"Unexpected error calling {worker_ip}: {e}")
                    raise
            logger.debug(
                f"Async worker: Got response from {worker_ip}, type: {type(result).__name__}, "
                f"length: {len(result) if hasattr(result, '__len__') else 'N/A'}"
            )
            # In tree topology, intermediate nodes return aggregated results from their subtree
            # We should preserve the flat list structure
            return result

        async def call_workers_async(client, data):
            """Make async calls to all workers using shared client."""
            worker_ips = data["worker_ips"]
            cls_or_fn_name = data["cls_or_fn_name"]
            method_name = data["method_name"]
            params = data["params"]
            request_headers = data["request_headers"]
            workers_arg = data["workers_arg"]

            # With tree topology limiting fanout, we don't need to batch health checks
            # Each node only calls its direct children in the tree
            logger.info(f"Waiting for {len(worker_ips)} workers to become ready (timeout={quorum_timeout}s)")
            health_tasks = [wait_for_worker_health(client, ip, workers_arg, quorum_timeout) for ip in worker_ips]
            health_results = await asyncio.gather(*health_tasks)

            # Process results
            healthy_workers = []
            unhealthy_workers = []
            for worker_ip, is_healthy in health_results:
                if is_healthy:
                    healthy_workers.append(worker_ip)
                else:
                    unhealthy_workers.append(worker_ip)

            if unhealthy_workers:
                if workers_arg == "ready":
                    # For "ready" mode, just skip unhealthy workers
                    logger.info(f"Skipping {len(unhealthy_workers)} workers that didn't respond (ready mode)")
                else:
                    # For normal mode, fail if any worker didn't become ready
                    logger.error(
                        f"{len(unhealthy_workers)} workers failed to become ready after {quorum_timeout}s: {unhealthy_workers[:5]}..."
                    )
                    raise TimeoutError(
                        f"{len(unhealthy_workers)} of {len(worker_ips)} workers did not become ready within {quorum_timeout} seconds. "
                        f"This may indicate the pods are still starting or there's a resource constraint. "
                        f"Consider increasing quorum_timeout in .distribute() call."
                    )

            logger.info(f"All {len(healthy_workers)} workers are ready, making distributed calls")

            # Now make the actual calls to ready workers
            # Create tasks (not just coroutines)
            tasks = []
            for worker_ip in healthy_workers:
                coro = call_single_worker(
                    client,
                    worker_ip,
                    cls_or_fn_name,
                    method_name,
                    params,
                    request_headers,
                    workers_arg,
                )
                # Create actual task from coroutine
                task = asyncio.create_task(coro)
                tasks.append(task)

            # Use as_completed for fast failure propagation
            responses = []
            pending_tasks = set(tasks)

            for future in asyncio.as_completed(tasks):
                try:
                    result = await future
                    if result is not None:  # Skip None results
                        responses.append(result)
                    # Remove completed task from pending set
                    pending_tasks.discard(future)
                except Exception as e:
                    # Fast failure - immediately propagate the exception
                    # Cancel remaining tasks to avoid unnecessary work
                    for task in pending_tasks:
                        if not task.done():
                            task.cancel()
                    raise e

            return responses

        # Create and run event loop with shared AsyncClient
        async def main():
            # Set up signal handler for graceful shutdown
            import signal

            shutdown_event = asyncio.Event()

            # Save existing handlers to chain to them
            original_sigterm_handler = signal.getsignal(signal.SIGTERM)
            original_sigint_handler = signal.getsignal(signal.SIGINT)

            def signal_handler(signum, frame):
                logger.info(f"Received signal {signum}, initiating graceful shutdown")
                shutdown_event.set()

                # Chain to original handler if it exists
                if signum == signal.SIGTERM:
                    if original_sigterm_handler and original_sigterm_handler not in (
                        signal.SIG_DFL,
                        signal.SIG_IGN,
                    ):
                        original_sigterm_handler(signum, frame)
                elif signum == signal.SIGINT:
                    if original_sigint_handler and original_sigint_handler not in (
                        signal.SIG_DFL,
                        signal.SIG_IGN,
                    ):
                        original_sigint_handler(signum, frame)

            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)

            # Create a single AsyncClient to be shared across all requests
            # Set limits based on max expected workers (passed from parent)
            # We need exactly max_workers connections since health checks and work calls
            # should reuse the same connection per worker (HTTP keep-alive)
            # Add small buffer for edge cases
            buffer = max(100, int(max_workers * 0.1))  # 10% buffer, min 100
            max_conn = max_workers + buffer

            limits = httpx.Limits(
                max_keepalive_connections=max_conn,  # Keep all connections alive
                max_connections=max_conn,  # One connection per worker + buffer
                keepalive_expiry=300.0,  # Keep connections alive for 5 minutes
            )
            timeout = httpx.Timeout(
                connect=10.0,
                read=None,  # No read timeout for long-running jobs
                write=10.0,
                pool=60.0,  # Time to wait for a connection from pool
            )

            logger.debug(
                f"AsyncClient configured with max_connections={max_conn} "
                f"(workers={max_workers} + buffer={buffer}) with 5min keepalive"
            )

            # Note: http2=True would enable HTTP/2 if server supports it
            async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
                logger.debug("Async worker started with shared httpx.AsyncClient")

                active_tasks = set()

                async def process_call(request_data):
                    """Process a CALL request asynchronously."""
                    request_id = request_data["request_id"]
                    worker_ips = request_data["worker_ips"]
                    logger.debug(
                        f"Async worker: Processing request {request_id} with {len(worker_ips)} workers: {worker_ips}"
                    )
                    try:
                        results = await call_workers_async(client, request_data)
                        logger.debug(
                            f"Async worker: Successfully got {len(results)} results for request {request_id}, sending response"
                        )
                        response_queue.put({"request_id": request_id, "results": results})
                    except Exception as e:
                        logger.error(f"Error processing request {request_id}: {e}")
                        response_queue.put({"request_id": request_id, "error": e})

                # Main request processing loop
                while not shutdown_event.is_set():
                    try:
                        # Check queue without blocking event loop
                        cmd, request_data = await asyncio.to_thread(request_queue.get, block=True, timeout=0.1)

                        if cmd == "SHUTDOWN" or shutdown_event.is_set():
                            # Cancel all active tasks immediately for quick cleanup
                            if active_tasks:
                                logger.debug(f"Cancelling {len(active_tasks)} active tasks for shutdown")
                                for task in active_tasks:
                                    task.cancel()
                            break

                        elif cmd == "CALL":
                            # Create a task to handle this request concurrently
                            task = asyncio.create_task(process_call(request_data))
                            active_tasks.add(task)

                            # Clean up completed tasks periodically
                            active_tasks = {t for t in active_tasks if not t.done()}

                    except queue.Empty:
                        # Clean up completed tasks while waiting
                        active_tasks = {t for t in active_tasks if not t.done()}
                    except Exception as e:
                        logger.error(f"Error in async worker loop: {e}")

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(main())
        finally:
            loop.close()

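
# --- Illustrative sketch (editor's addition, not part of the packaged file) ---
# A minimal example of how the RemoteWorkerPool above might be driven from a
# coordinator: start the pool, fan a call out to a set of worker IPs, and stop it.
# The worker IPs, callable name, and params payload below are hypothetical; the
# real payload shape is whatever the kubetorch HTTP server expects, and
# KT_SERVER_PORT must be set in the environment.
def _example_remote_worker_pool_usage():
    pool = RemoteWorkerPool(quorum_timeout=120)
    pool.start(max_workers=4)
    try:
        # Blocks until every healthy worker has responded (or raises on failure)
        return pool.call_workers(
            worker_ips=["10.0.0.1", "10.0.0.2"],  # hypothetical worker pod IPs
            cls_or_fn_name="my_fn",  # hypothetical deployed callable name
            method_name=None,
            params={},  # placeholder request payload
            request_headers={},
        )
    finally:
        pool.stop()

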
class DistributedProcessPool:
    """Unified pool managing distributed processes with single router thread."""

    def __init__(self, process_class, num_processes, max_threads_per_proc=10, **process_kwargs):
        self.process_class = process_class
        self.num_processes = num_processes
        self.max_threads_per_proc = max_threads_per_proc
        self.process_kwargs = process_kwargs  # Additional kwargs to pass to process constructor

        # Processes and queues
        self.processes = []
        self.request_queues = []  # One request queue per process
        self.response_queue = multiprocessing.Queue()  # Single shared response queue

        # Response routing with single router thread
        self._router_thread = None
        self._response_events = {}  # Maps request_id to (threading.Event, response)
        self._response_lock = threading.Lock()
        self._running = False

    def start(self):
        """Start all processes in the pool and the single router thread."""
        if self.processes:
            raise RuntimeError("Pool already started")

        # Create and start all processes in parallel
        with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
            futures = []
            for i in range(self.num_processes):
                future = executor.submit(self._create_and_start_process, i)
                futures.append(future)

            # Wait for all processes to start
            for future in futures:
                future.result()

        # Start single response router thread for entire pool
        self._running = True
        self._router_thread = threading.Thread(target=self._response_router, daemon=True, name="PoolResponseRouter")
        self._router_thread.start()
        logger.debug(f"Started {self.num_processes} processes with single router thread")

    def _create_and_start_process(self, local_rank):
        """Helper to create and start a single process."""
        request_queue = multiprocessing.Queue()
        self.request_queues.append(request_queue)

        process = self.process_class(
            local_rank=local_rank,
            request_queue=request_queue,
            response_queue=self.response_queue,  # Shared response queue
            max_threads=self.max_threads_per_proc,
            **self.process_kwargs,  # Pass additional framework-specific settings
        )
        process.start()
        self.processes.append(process)

    def stop(self):
        """Stop all processes and the router thread."""
        self._running = False

        # Send shutdown signal to all processes (use put_nowait to avoid blocking)
        for q in self.request_queues:
            try:
                q.put_nowait("SHUTDOWN")
            except Exception:
                pass

        # Stop router thread
        try:
            self.response_queue.put_nowait("STOP_ROUTER")
        except Exception:
            pass

        # Wait briefly for router to stop
        if self._router_thread:
            self._router_thread.join(timeout=0.5)

        # Terminate all processes immediately without waiting
        for process in self.processes:
            if process.is_alive():
                process.terminate()

        # Give processes a brief chance to terminate gracefully (reduced timeout)
        for process in self.processes:
            process.join(timeout=0.5)

        # Force kill any remaining processes
        for process in self.processes:
            if process.is_alive():
                logger.warning(f"Force killing process {process.pid}")
                process.kill()
                process.join(timeout=0.1)  # Brief wait to confirm kill

        # Clear all queues
        self._clear_queues()

        # Reset state
        self.processes.clear()
        self.request_queues.clear()
        with self._response_lock:
            self._response_events.clear()

    def call(
        self,
        idx,
        method_name,
        params,
        deployed_as_of,
        request_id,
        distributed_env_vars,
        debug_port,
        serialization,
    ):
        """Call a specific process by index."""
        if idx >= len(self.processes):
            raise ValueError(f"Process index {idx} out of range (have {len(self.processes)} processes)")

        request_unique_id = str(uuid.uuid4())

        # Register this request for response routing
        event = threading.Event()
        with self._response_lock:
            self._response_events[request_unique_id] = (event, None)

        try:
            # Send request to specific process
            self.request_queues[idx].put(
                {
                    "request_unique_id": request_unique_id,
                    "method_name": method_name,
                    "params": params,
                    "deployed_as_of": deployed_as_of,
                    "request_id": request_id,
                    "distributed_env_vars": distributed_env_vars,
                    "debug_port": debug_port,
                    "serialization": serialization,
                    "process_idx": idx,  # Include process index for debugging
                }
            )

            # Wait for response
            event.wait()

            # Get and return the response
            with self._response_lock:
                _, result = self._response_events.pop(request_unique_id)

            return result

        except Exception:
            # Clean up on error
            with self._response_lock:
                self._response_events.pop(request_unique_id, None)
            raise

    def call_all(
        self,
        method_name,
        params_list,
        deployed_as_of,
        request_id,
        distributed_env_vars_list,
        debug_ports,
        serialization,
    ):
        """Call all processes in parallel and return results."""
        if len(params_list) != self.num_processes:
            raise ValueError(f"Expected {self.num_processes} param sets, got {len(params_list)}")

        with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
            futures = []
            for idx in range(self.num_processes):
                future = executor.submit(
                    self.call,
                    idx=idx,
                    method_name=method_name,
                    params=params_list[idx],
                    deployed_as_of=deployed_as_of,
                    request_id=request_id,
                    distributed_env_vars=distributed_env_vars_list[idx],
                    debug_port=debug_ports[idx] if debug_ports else None,
                    serialization=serialization,
                )
                futures.append(future)

            results = []
            for future in futures:
                results.append(future.result())

            return results

    def _response_router(self):
        """Single router thread handling responses from all processes."""
        while self._running:
            try:
                response = self.response_queue.get(timeout=1)
                if response == "STOP_ROUTER":
                    break

                request_id = response.get("request_unique_id")
                with self._response_lock:
                    if request_id in self._response_events:
                        event, _ = self._response_events[request_id]
                        self._response_events[request_id] = (event, response["result"])
                        event.set()
                    else:
                        logger.warning(f"Received response for unknown request: {request_id}")

            except Exception as e:
                if "Empty" not in str(e.__class__.__name__):
                    logger.debug(f"Response router error: {e}")
                continue

    def _clear_queues(self):
        """Clear all pending items from queues."""
        for q in self.request_queues:
            try:
                while not q.empty():
                    q.get_nowait()
            except Exception:
                pass

        try:
            while not self.response_queue.empty():
                self.response_queue.get_nowait()
        except Exception:
            pass

    def __len__(self):
        return len(self.processes)

    def __getitem__(self, idx):
        return self.processes[idx]

class DistributedSupervisor:
    def __init__(self, quorum_workers=None, quorum_timeout=300, monitor_members=True):
        """
        Base class for distributed supervisors. This class should be subclassed for specific distributed
        environments like PyTorch or Ray.

        Args:
            config: Optional configuration object for the distributed environment.
        """
        # Set after creation by the factory function
        self.quorum_workers = quorum_workers
        self.quorum_timeout = quorum_timeout
        self.monitor_members = monitor_members

        self.config_hash = None

        # DNS monitoring state
        self._dns_monitor_thread = None
        self._dns_monitor_running = False
        self._current_workers = set()
        self._workers_lock = threading.Lock()
        self._membership_changes = queue.Queue()
        self._change_subscribers = []
        self._last_dns_check = 0
        self._dns_check_interval = 5  # seconds

    def pod_ips(self):
        """Get pod IPs from DNS, waiting for quorum if specified.

        Will wait up to quorum_timeout seconds for quorum_workers to appear in DNS.
        If quorum_workers is not specified, returns immediately after first DNS query.
        """
        # Primarily for testing
        if not is_running_in_kubernetes():
            return os.environ["LOCAL_IPS"].split(",")

        # Use DNS-based service discovery instead of Kubernetes API
        # Check if pre-computed DNS name is available (should point to headless service for distributed)
        service_dns = os.environ.get("KT_SERVICE_DNS")

        if not service_dns:
            # Fall back to computing DNS name from service and namespace
            service_name = os.environ.get("KT_SERVICE_NAME")
            namespace = os.environ.get("POD_NAMESPACE")

            if not service_name:
                raise RuntimeError("KT_SERVICE environment variable not found")
            if not namespace:
                raise RuntimeError("POD_NAMESPACE environment variable not found")

            # Kubernetes headless service DNS name for distributed pod discovery
            # Format: <service-name>-headless.<namespace>.svc.cluster.local
            service_dns = f"{service_name}-headless.{namespace}.svc.cluster.local"

        import socket
        import time

        start_time = time.time()
        max_wait = self.quorum_timeout if self.quorum_timeout else 0
        expected_workers = self.quorum_workers

        pod_ips = []
        last_count = 0

        while True:
            try:
                # DNS lookup returns all pod IPs for the headless service
                # getaddrinfo returns list of (family, type, proto, canonname, sockaddr)
                addr_info = socket.getaddrinfo(service_dns, None, socket.AF_INET)

                # Extract unique IP addresses from the results
                pod_ips = sorted(list(set([addr[4][0] for addr in addr_info])))

                if not pod_ips:
                    logger.debug(f"No pod IPs found for service {service_dns}")
                else:
                    logger.debug(f"Found {len(pod_ips)} pod IPs via DNS for {service_dns}: {pod_ips}")

            except socket.gaierror as e:
                logger.debug(f"DNS lookup failed for {service_dns}: {e}")
                pod_ips = []

            # Check if we should wait for more workers
            elapsed = time.time() - start_time

            # If we have the expected count, we're done
            if expected_workers and len(pod_ips) >= expected_workers:
                logger.info(f"Found {len(pod_ips)}/{expected_workers} workers after {elapsed:.1f}s")
                return pod_ips

            # If we don't have expected count or timeout is reached, decide what to do
            if elapsed >= max_wait:
                if expected_workers:
                    logger.warning(f"Only found {len(pod_ips)}/{expected_workers} workers after {elapsed:.1f}s timeout")
                else:
                    logger.info(f"Found {len(pod_ips)} workers after {elapsed:.1f}s")
                return pod_ips

            # Log progress if count changed
            if len(pod_ips) != last_count:
                if expected_workers:
                    logger.info(f"{len(pod_ips)}/{expected_workers} workers found, waiting for quorum...")
                else:
                    logger.debug(f"{len(pod_ips)} workers found, no quorum set")
                last_count = len(pod_ips)

            # Wait before retrying
            time.sleep(2)

    def _get_pod_ips_fast(self):
        """Get pod IPs from DNS without waiting for quorum - for monitoring only."""
        # Primarily for testing
        if not is_running_in_kubernetes():
            return os.environ["LOCAL_IPS"].split(",")

        # Use DNS-based service discovery
        service_dns = os.environ.get("KT_SERVICE_DNS")

        if not service_dns:
            service_name = os.environ.get("KT_SERVICE")
            namespace = os.environ.get("POD_NAMESPACE")

            if not service_name or not namespace:
                return []

            service_dns = f"{service_name}-headless.{namespace}.svc.cluster.local"

        import socket

        try:
            # Single DNS lookup, no retries, no waiting
            addr_info = socket.getaddrinfo(service_dns, None, socket.AF_INET)
            # Extract unique IP addresses
            pod_ips = sorted(list(set([addr[4][0] for addr in addr_info])))
            return pod_ips
        except socket.gaierror:
            # DNS lookup failed, return current known workers
            with self._workers_lock:
                return list(self._current_workers)

    def start_dns_monitoring(self):
        """Start DNS monitoring if not already running.
        Should be called by coordinator nodes only."""
        # Skip if monitoring is disabled (e.g., for Ray)
        if not self.monitor_members:
            logger.debug("DNS monitoring disabled for this supervisor")
            return

        with self._workers_lock:
            if self._dns_monitor_thread and self._dns_monitor_thread.is_alive():
                return  # Already running

            # Initialize with current workers
            self._current_workers = set(self.pod_ips())
            logger.debug(f"Starting DNS monitor with {len(self._current_workers)} workers")

        self._dns_monitor_running = True
        self._dns_monitor_thread = threading.Thread(
            target=self._monitor_worker_membership, daemon=True, name="DNSMonitor"
        )
        self._dns_monitor_thread.start()

    def stop_dns_monitoring(self):
        """Stop DNS monitoring thread."""
        self._dns_monitor_running = False
        if self._dns_monitor_thread:
            self._dns_monitor_thread.join(timeout=2)
            self._dns_monitor_thread = None

    def _monitor_worker_membership(self):
        """Monitor DNS for worker membership changes."""
        check_interval = 3  # Start with 3 second checks (faster initial detection)

        while self._dns_monitor_running:
            try:
                # Note that we start this after the delay, because we're doing a DNS check at
                # the start of call_distributed anyway. This thread is only for the recurring checks
                # as the call runs.
                time.sleep(check_interval)

                # Query DNS for current workers - use a faster version
                current_ips = set(self._get_pod_ips_fast())

                with self._workers_lock:
                    if current_ips != self._current_workers:
                        added = current_ips - self._current_workers
                        removed = self._current_workers - current_ips

                        change = {
                            "timestamp": time.time(),
                            "added": added,
                            "removed": removed,
                            "previous": self._current_workers.copy(),
                            "current": current_ips.copy(),
                        }

                        if removed:
                            logger.error(f"Workers REMOVED from cluster: {removed}")
                        if added:
                            logger.warning(f"Workers ADDED to cluster: {added}")

                        # Queue change and notify subscribers
                        self._membership_changes.put(change)
                        for event in self._change_subscribers:
                            event.set()

                        self._current_workers = current_ips

                time.sleep(check_interval)

            except Exception as e:
                logger.error(f"DNS monitor error: {e}")
                time.sleep(3)

    def subscribe_to_membership_changes(self):
        """Subscribe to worker membership changes.
        Returns an event that will be set when changes occur."""
        event = threading.Event()
        with self._workers_lock:
            self._change_subscribers.append(event)
        return event

    def unsubscribe_from_membership_changes(self, event):
        """Unsubscribe from worker membership changes."""
        with self._workers_lock:
            if event in self._change_subscribers:
                self._change_subscribers.remove(event)

    def check_for_membership_changes(self, force_dns_check=False):
        """Check for membership changes and raise exception if any occurred.

        Args:
            force_dns_check: If True, immediately query DNS to check for changes
                instead of relying on the monitoring thread
        """
        # Skip if monitoring is disabled (e.g., for Ray)
        if not self.monitor_members:
            return
        # Force an immediate DNS check if requested
        if force_dns_check:
            # Use fast DNS query for immediate check
            current_ips = set(self._get_pod_ips_fast())
            with self._workers_lock:
                if current_ips != self._current_workers:
                    added = current_ips - self._current_workers
                    removed = self._current_workers - current_ips

                    # Import here to avoid circular dependency
                    from kubetorch.servers.http.utils import WorkerMembershipChanged

                    # Update current workers
                    self._current_workers = current_ips

                    # Log the change
                    if removed:
                        logger.error(f"Workers REMOVED from cluster (forced check): {removed}")
                    if added:
                        logger.warning(f"Workers ADDED to cluster (forced check): {added}")

                    raise WorkerMembershipChanged(
                        added_ips=added,
                        removed_ips=removed,
                        previous_ips=self._current_workers.copy(),
                        current_ips=current_ips,
                    )

        # Check queued changes from monitoring thread
        try:
            change = self._membership_changes.get_nowait()

            # Import here to avoid circular dependency
            from kubetorch.servers.http.utils import WorkerMembershipChanged

            raise WorkerMembershipChanged(
                added_ips=change["added"],
                removed_ips=change["removed"],
                previous_ips=change["previous"],
                current_ips=change["current"],
            )
        except queue.Empty:
            pass  # No changes

    def setup(self, deployed_as_of: Optional[str] = None):
        # This method should be overridden by subclasses to set up the distributed environment
        raise NotImplementedError("setup() must be implemented by subclasses")

    def cleanup(self):
        """Base cleanup - stop DNS monitoring. Subclasses should call super().cleanup()"""
        self.stop_dns_monitoring()
        # Subclasses should override and call super().cleanup() to add their own cleanup

    def intercept_call(self):
        # This method should be overridden by subclasses to indicate whether to intercept calls
        raise NotImplementedError("intercept_call() must be implemented by subclasses")

    def call_distributed(
        self,
        request,
        cls_or_fn_name: str,
        method_name: Optional[str] = None,
        params: Optional[Dict] = None,
        distributed_subcall: bool = False,
        debug_port: int = False,
        deployed_as_of: Optional[str] = None,
    ):
        # if intercept_call is True, this method should be overridden by subclasses to handle distributing and/or
        # supervising the distributed execution
        raise NotImplementedError("call_distributed() must be implemented by subclasses")

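
# --- Illustrative sketch (editor's addition, not part of the packaged file) ---
# The DNS-based membership tracking in DistributedSupervisor can be exercised on
# its own, independent of the abstract setup()/call_distributed() hooks. This
# sketch assumes it runs inside a kubetorch pod (or with LOCAL_IPS set for local
# tests); the wait-then-check flow shown here is one possible usage pattern.
def _example_membership_monitoring(supervisor: DistributedSupervisor):
    from kubetorch.servers.http.utils import WorkerMembershipChanged

    supervisor.start_dns_monitoring()
    changed = supervisor.subscribe_to_membership_changes()
    try:
        if changed.wait(timeout=30):  # woken when the monitor sees a DNS change
            try:
                supervisor.check_for_membership_changes()
            except WorkerMembershipChanged as e:
                logger.info(f"Cluster membership changed: {e}")
    finally:
        supervisor.unsubscribe_from_membership_changes(changed)
        supervisor.stop_dns_monitoring()

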
class DistributedProcess(multiprocessing.Process):
    """Base class for distributed processes that run callables in subprocesses."""

    def __init__(self, local_rank, request_queue, response_queue, max_threads=4, **kwargs):
        super().__init__()
        # We don't need the cache miss / reload here because these processes are destroyed and recreated
        # with each .to call.
        os.environ["LOCAL_RANK"] = str(local_rank)
        self._request_queue = request_queue
        self._response_queue = response_queue
        self._max_threads = max_threads
        self._executor = None
        # Store any additional framework-specific settings
        self._settings = kwargs

    def proc_cleanup(self):
        """Override this method to provide framework-specific cleanup."""
        logger.info("Cleaning up debugging sessions...")
        clear_debugging_sessions()
        logger.info("Debugging sessions cleaned up.")

        # Cleanup thread pool
        if self._executor:
            self._executor.shutdown(wait=False)
            self._executor = None

    @classmethod
    def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
        """Get framework-specific distributed environment variables.

        Args:
            worker_ips: List of all worker IPs
            node_rank: Rank of this node (0-indexed)
            local_rank: Local rank on this node (0-indexed)
            num_local_procs: Number of processes on this node
            **settings: Additional framework-specific settings (e.g., port)

        Returns:
            Dict of environment variables to set
        """
        # Base implementation - no special env vars needed
        return {
            "WORLD_SIZE": str(len(worker_ips) * num_local_procs),
            "RANK": str(node_rank * num_local_procs + local_rank),
            "LOCAL_RANK": str(local_rank),
            "NODE_RANK": str(node_rank),
            "POD_IPS": ",".join(worker_ips),
        }

    @classmethod
    def get_auto_num_processes(cls):
        """Auto-detect the number of processes to use."""
        return 1

    def handle_request(self, request):
        """Handle a single request in a thread."""
        try:
            request_unique_id = request["request_unique_id"]
            method_name = request["method_name"]
            params = request["params"]
            deployed_as_of = request["deployed_as_of"]
            request_id = request["request_id"]
            distributed_env_vars = request["distributed_env_vars"]
            debug_port = request["debug_port"]
            serialization = request["serialization"]

            # Set the request ID in the context for this thread
            token = request_id_ctx_var.set(request_id)

            # Set the environment variables for this thread (note: os.environ is process-wide, might need thread-local storage)
            # For distributed PyTorch calls, these should already be set at process level
            for key, value in distributed_env_vars.items():
                os.environ[key] = value

            try:
                # Load callable if not already loaded or if deployed_as_of changed
                callable_obj = load_callable(
                    deployed_as_of=deployed_as_of,
                    distributed_subprocess=True,
                    reload_cleanup_fn=self.proc_cleanup,
                )

                result = run_callable_internal_sync(
                    callable_obj=callable_obj,
                    cls_or_fn_name=os.environ["KT_CLS_OR_FN_NAME"],
                    method_name=method_name,
                    params=params,
                    serialization=serialization,
                    debug_port=debug_port,
                )

                # Reset the request ID after the call is complete
                request_id_ctx_var.reset(token)

                # Send response back with the unique ID
                self._response_queue.put({"request_unique_id": request_unique_id, "result": result})

            except Exception as e:
                # Reset the request ID even if there was an error
                request_id_ctx_var.reset(token)

                # Package the exception
                try:
                    packaged_exception = package_exception(e)
                except Exception as f:
                    packaged_exception = f

                self._response_queue.put(
                    {
                        "request_unique_id": request_unique_id,
                        "result": packaged_exception,
                    }
                )

        except Exception as e:
            # Last resort error handling
            logger.error(f"Fatal error handling request: {e}")
            self._response_queue.put(
                {
                    "request_unique_id": request.get("request_unique_id", "unknown"),
                    "result": Exception(f"Fatal error in thread: {e}"),
                }
            )

    def run(self):
        """Main process loop with thread pool for concurrent request handling."""
        # Create thread pool for handling requests
        self._executor = ThreadPoolExecutor(max_workers=self._max_threads)

        try:
            while True:
                try:
                    # Block waiting for next request
                    request = self._request_queue.get(timeout=1)

                    # Special sentinel value to signal shutdown
                    if request == "SHUTDOWN":
                        break

                    # Submit request to thread pool for concurrent handling
                    # Check executor exists in case we're shutting down
                    if self._executor:
                        self._executor.submit(self.handle_request, request)
                    else:
                        logger.warning("Executor is None, skipping request (likely shutting down)")

                except Exception as e:
                    # Timeout is normal, continue loop
                    if "Empty" not in str(e.__class__.__name__):
                        logger.error(f"Error getting request from queue: {e}")
                    continue

        except (KeyboardInterrupt, BdbQuit):
            logger.info("Process interrupted, shutting down...")

        finally:
            # Cleanup
            logger.info("Received shutdown signal, cleaning up distributed environment...")
            self.proc_cleanup()
            logger.info("Exiting gracefully.")

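
# --- Illustrative sketch (editor's addition, not part of the packaged file) ---
# How DistributedProcessPool and DistributedProcess fit together: the pool spawns
# one DistributedProcess per local rank, and call_all() fans a request out to
# every process, collecting one result per rank. The method name, request id,
# params, and serialization below are hypothetical placeholders; a real call also
# needs the kubetorch server context (e.g. KT_CLS_OR_FN_NAME) to be in place.
def _example_process_pool_usage(worker_ips, num_procs=2):
    pool = DistributedProcessPool(process_class=DistributedProcess, num_processes=num_procs)
    pool.start()
    try:
        # One env-var dict per local rank, derived from the cluster layout
        env_vars_per_proc = [
            DistributedProcess.get_distributed_env_vars(
                worker_ips=worker_ips, node_rank=0, local_rank=rank, num_local_procs=num_procs
            )
            for rank in range(num_procs)
        ]
        return pool.call_all(
            method_name=None,  # hypothetical: call the deployed function itself
            params_list=[{} for _ in range(num_procs)],  # placeholder payloads
            deployed_as_of=None,
            request_id="example-request",  # hypothetical request id
            distributed_env_vars_list=env_vars_per_proc,
            debug_ports=None,
            serialization=None,
        )
    finally:
        pool.stop()

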
class PyTorchProcess(DistributedProcess):
    """PyTorch-specific distributed process."""

    def proc_cleanup(self):
        import torch.distributed as dist

        try:
            dist.destroy_process_group()
            logger.info("Destroyed PyTorch process group.")
        except Exception as e:
            logger.info(f"Failed to destroy PyTorch process group, it may not have been initialized: {e}")
        # Call parent cleanup for debugging sessions
        super().proc_cleanup()

    @classmethod
    def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
        """Get PyTorch-specific distributed environment variables."""
        port = settings.get("port") or 12345
        env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
        env_vars.update(
            {
                "MASTER_ADDR": worker_ips[0],
                "MASTER_PORT": str(port),
            }
        )
        return env_vars

    @classmethod
    def get_auto_num_processes(cls):
        """Auto-detect based on GPU availability for PyTorch."""
        try:
            import torch

            if torch.cuda.is_available():
                return torch.cuda.device_count()
        except ImportError:
            pass
        return 1  # Could use os.cpu_count() for CPU-only training

|
1255
|
+
class RayProcess(DistributedProcess):
|
|
1256
|
+
"""Ray-specific distributed process."""
|
|
1257
|
+
|
|
1258
|
+
def proc_cleanup(self):
|
|
1259
|
+
try:
|
|
1260
|
+
import ray
|
|
1261
|
+
|
|
1262
|
+
if ray.is_initialized():
|
|
1263
|
+
ray.shutdown()
|
|
1264
|
+
logger.info("Ray shutdown completed.")
|
|
1265
|
+
except ImportError:
|
|
1266
|
+
logger.info("Ray not available for cleanup")
|
|
1267
|
+
except Exception as e:
|
|
1268
|
+
logger.info(f"Failed to shutdown Ray: {e}")
|
|
1269
|
+
# Call parent cleanup for debugging sessions
|
|
1270
|
+
super().proc_cleanup()
|
|
1271
|
+
|
|
1272
|
+
|
|
1273
|
+
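Note: MASTER_ADDR and MASTER_PORT are the standard rendezvous variables read by torch.distributed's env:// initialization. A minimal worker-side sketch of how user code might consume them (illustrative only, not part of this package; it assumes the base DistributedProcess also exports RANK and WORLD_SIZE, and the gloo backend is just an example):

# Illustrative sketch, not part of kubetorch: consume the env vars set by
# PyTorchProcess.get_distributed_env_vars (assumes RANK/WORLD_SIZE are also set).
import os

import torch.distributed as dist


def init_from_env():
    # init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
    dist.init_process_group(backend="gloo", init_method="env://")
    print(
        f"rank {dist.get_rank()}/{dist.get_world_size()} rendezvoused at "
        f"{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    )
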
+class SPMDDistributedSupervisor(DistributedSupervisor):
+    """Base class for SPMD (Single Program Multiple Data) distributed supervisors.
+
+    This class provides common functionality for frameworks that follow the SPMD pattern
+    where the same program runs on multiple processes with different data partitions.
+    """
+
+    def __init__(
+        self,
+        process_class=None,
+        num_proc=None,
+        port=None,
+        restart_procs=True,
+        max_threads_per_proc=10,
+        quorum_timeout=300,
+        quorum_workers=None,
+        monitor_members=True,
+        tree_fanout=50,
+        tree_minimum=100,
+        **process_kwargs,
+    ):
+        super().__init__(
+            quorum_workers=quorum_workers,
+            quorum_timeout=quorum_timeout,
+            monitor_members=monitor_members,
+        )
+        self.process_class = process_class or DistributedProcess
+        self.num_proc = num_proc or "auto"
+        self.port = port
+        self.restart_procs = restart_procs
+        self.max_threads_per_proc = max_threads_per_proc
+        self.process_pool = None
+        self.remote_worker_pool = None  # Pool for async HTTP calls to remote workers
+        self.process_kwargs = process_kwargs  # Additional settings to pass to process class
+        self.tree_fanout = tree_fanout
+        self.tree_minimum = tree_minimum
+
+    def setup(self, deployed_as_of: Optional[str] = None):
+        # Set multiprocessing to spawn if not already
+        if multiprocessing.get_start_method() != "spawn":
+            multiprocessing.set_start_method("spawn", force=True)
+
+        # Get number of processes
+        if self.num_proc == "auto":
+            num_proc = self.process_class.get_auto_num_processes()
+        else:
+            num_proc = self.num_proc
+
+        if self.restart_procs:
+            logger.debug("restart_procs is True, restarting distributed processes")
+            self.cleanup()
+
+        # If the number of processes has changed, we need to clean up the old ones and recreate them
+        if self.process_pool is None or len(self.process_pool) != num_proc:
+            if self.process_pool:
+                logger.debug(
+                    f"Number of processes changed from {len(self.process_pool)} to {num_proc}, restarting processes."
+                )
+                self.cleanup()
+
+            logger.debug("Setting up distributed environment")
+            self.process_pool = DistributedProcessPool(
+                process_class=self.process_class,
+                num_processes=num_proc,
+                max_threads_per_proc=self.max_threads_per_proc,
+                **self.process_kwargs,  # Pass any additional settings
+            )
+
+            # Start all processes (now handled internally by the pool)
+            self.process_pool.start()
+
+            self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
+            self.remote_worker_pool.start()
+        logger.debug("Finished setting up distributed processes")
+
+    def cleanup(self):
+        # Cleanup the processes
+        logger.debug(f"Cleaning up {self.__class__.__name__} distributed processes")
+
+        # Stop DNS monitoring first
+        super().cleanup()
+
+        if self.process_pool:
+            self.process_pool.stop()
+            self.process_pool = None
+
+        if self.remote_worker_pool:
+            self.remote_worker_pool.stop()
+            self.remote_worker_pool = None
+
+        logger.debug(f"Finished cleaning up {self.__class__.__name__} distributed processes")
+
+    @staticmethod
+    def intercept_call():
+        return True
+
+    def get_tree_children(self, sorted_ips: list, my_ip: str, fanout: int = 100):
+        """Calculate children nodes in a self-organizing tree based on IP indexing.
+
+        Args:
+            sorted_ips: List of all worker IPs sorted deterministically
+            my_ip: This node's IP address
+            fanout: Maximum number of children per node (default 100)
+
+        Returns:
+            List of IP addresses that are children of this node
+        """
+        try:
+            my_index = sorted_ips.index(my_ip)
+        except ValueError:
+            # If not found in list, this node has no children
+            return []
+
+        # Calculate the range of children indices
+        # In a tree with fanout F, node at index i has children at indices:
+        # [i*F + 1, i*F + 2, ..., i*F + F]
+        first_child_idx = my_index * fanout + 1
+        last_child_idx = min(first_child_idx + fanout, len(sorted_ips))
+
+        if first_child_idx >= len(sorted_ips):
+            # No children (leaf node)
+            return []
+
+        children = sorted_ips[first_child_idx:last_child_idx]
+        if len(children) > 0:
+            logger.debug(
+                f"Tree topology: Node {my_ip} (index {my_index}) has {len(children)} children "
+                f"(indices {first_child_idx}-{last_child_idx-1})"
+            )
+        return children
+
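The indexing above defines an implicit N-ary tree over the sorted IP list: with fanout F, the node at index i fans out to indices i*F + 1 through i*F + F, so the coordinator (index 0) reaches every worker in a number of hops logarithmic in the cluster size. A standalone sketch of the same arithmetic on hypothetical IPs (illustration only, not part of this package):

# Illustrative reimplementation of the get_tree_children indexing (hypothetical IPs).
def tree_children(sorted_ips, my_ip, fanout=2):
    try:
        i = sorted_ips.index(my_ip)
    except ValueError:
        return []  # unknown node has no children
    first = i * fanout + 1
    return sorted_ips[first:first + fanout]


ips = [f"10.0.0.{n}" for n in range(7)]
assert tree_children(ips, "10.0.0.0") == ["10.0.0.1", "10.0.0.2"]  # root (index 0)
assert tree_children(ips, "10.0.0.1") == ["10.0.0.3", "10.0.0.4"]  # index 1
assert tree_children(ips, "10.0.0.3") == []  # leaf
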
|
|
1404
|
+
def call_distributed(
|
|
1405
|
+
self,
|
|
1406
|
+
request,
|
|
1407
|
+
cls_or_fn_name: str,
|
|
1408
|
+
method_name: Optional[str] = None,
|
|
1409
|
+
params: Optional[Dict] = None,
|
|
1410
|
+
distributed_subcall: bool = False,
|
|
1411
|
+
debug_port: int = False,
|
|
1412
|
+
deployed_as_of: Optional[str] = None,
|
|
1413
|
+
):
|
|
1414
|
+
# Get the request ID from the headers
|
|
1415
|
+
request_id = request.headers.get("X-Request-ID", "-")
|
|
1416
|
+
serialization = request.headers.get("X-Serialization", "json")
|
|
1417
|
+
params = params or {}
|
|
1418
|
+
|
|
1419
|
+
# If deployed_as_of is None and we're the coordinator, generate a consistent timestamp
|
|
1420
|
+
# to use across all workers to prevent reload inconsistencies
|
|
1421
|
+
if not distributed_subcall and deployed_as_of is None:
|
|
1422
|
+
from datetime import datetime, timezone
|
|
1423
|
+
|
|
1424
|
+
deployed_as_of = datetime.now(timezone.utc).isoformat()
|
|
1425
|
+
|
|
1426
|
+
# Get all the pods in the service, and use the first one as the master.
|
|
1427
|
+
# Set the env vars based on whether this is a master or worker
|
|
1428
|
+
logger.debug(f"Configuring distributed environment, distributed_subcall={distributed_subcall}")
|
|
1429
|
+
this_pod_ip = os.environ["POD_IP"]
|
|
1430
|
+
logger.debug(f"This pod IP: {this_pod_ip}")
|
|
1431
|
+
|
|
1432
|
+
# Start DNS monitoring for coordinator nodes
|
|
1433
|
+
change_event = None
|
|
1434
|
+
if not distributed_subcall:
|
|
1435
|
+
# First wait for quorum before starting monitoring
|
|
1436
|
+
worker_ips = self.pod_ips()
|
|
1437
|
+
# sort the worker IPs to ensure generally consistent tree ordering
|
|
1438
|
+
# (avoids thrashing connections due to reordering)
|
|
1439
|
+
worker_ips.sort()
|
|
1440
|
+
# For tree topology, coordinator is always root (index 0)
|
|
1441
|
+
# Move coordinator to front if not already there
|
|
1442
|
+
if this_pod_ip in worker_ips:
|
|
1443
|
+
worker_ips.remove(this_pod_ip)
|
|
1444
|
+
worker_ips.insert(0, this_pod_ip) # Move coordinator to root of tree
|
|
1445
|
+
logger.debug(f"Acting as COORDINATOR - discovered worker IPs: {worker_ips}")
|
|
1446
|
+
logger.debug(f"Pod IPs: {worker_ips}")
|
|
1447
|
+
|
|
1448
|
+
# Check if this call uses flexible worker selection (workers="ready")
|
|
1449
|
+
# If so, don't start DNS monitoring as the worker set is expected to be flexible
|
|
1450
|
+
workers_arg = params.get("workers") if params else None
|
|
1451
|
+
should_monitor = workers_arg not in ["ready", "any"]
|
|
1452
|
+
|
|
1453
|
+
if should_monitor:
|
|
1454
|
+
# Now that we have quorum, start DNS monitoring
|
|
1455
|
+
# Start monitoring (idempotent - won't start if already running)
|
|
1456
|
+
self.start_dns_monitoring()
|
|
1457
|
+
|
|
1458
|
+
# Subscribe to membership changes
|
|
1459
|
+
change_event = self.subscribe_to_membership_changes()
|
|
1460
|
+
|
|
1461
|
+
# Check for any pending changes after starting monitor
|
|
1462
|
+
self.check_for_membership_changes(force_dns_check=True)
|
|
1463
|
+
else:
|
|
1464
|
+
logger.debug("Skipping DNS monitoring for workers='ready' call")
|
|
1465
|
+
|
|
1466
|
+
# Update distributed env vars to use the tree-ordered IPs
|
|
1467
|
+
distributed_env_vars = {
|
|
1468
|
+
"POD_IPS": ",".join(worker_ips),
|
|
1469
|
+
}
|
|
1470
|
+
else:
|
|
1471
|
+
logger.debug(f"Acting as WORKER (distributed_subcall=True) at {this_pod_ip}")
|
|
1472
|
+
logger.debug(f"Worker received params keys: {list(params.keys()) if params else 'None'}")
|
|
1473
|
+
distributed_env_vars = params.pop("distributed_env_vars", None) if params else None
|
|
1474
|
+
logger.debug(f"Using distributed_env_vars: {distributed_env_vars}")
|
|
1475
|
+
if not distributed_env_vars:
|
|
1476
|
+
logger.error(f"No distributed_env_vars found in params: {params}")
|
|
1477
|
+
raise RuntimeError("distributed_env_vars must be provided for distributed subcalls")
|
|
1478
|
+
worker_ips = distributed_env_vars["POD_IPS"].split(",")
|
|
1479
|
+
|
|
1480
|
+
# Don't debug for subcalls, we only want to debug one process
|
|
1481
|
+
debug_port = None
|
|
1482
|
+
|
|
1483
|
+
# Decide topology based on cluster size
|
|
1484
|
+
subcall_ips = []
|
|
1485
|
+
num_workers = len(worker_ips)
|
|
1486
|
+
tree_mode = num_workers >= self.tree_minimum
|
|
1487
|
+
if tree_mode:
|
|
1488
|
+
# Use tree topology for large clusters
|
|
1489
|
+
if distributed_subcall:
|
|
1490
|
+
logger.debug(
|
|
1491
|
+
f"Using tree topology for {num_workers} workers (> {self.tree_minimum} threshold) "
|
|
1492
|
+
f"with fanout {self.tree_fanout}"
|
|
1493
|
+
)
|
|
1494
|
+
|
|
1495
|
+
# Calculate direct children in the tree
|
|
1496
|
+
subcall_ips = self.get_tree_children(
|
|
1497
|
+
sorted_ips=worker_ips,
|
|
1498
|
+
my_ip=this_pod_ip,
|
|
1499
|
+
fanout=self.tree_fanout,  # Each node can have up to tree_fanout children
|
|
1500
|
+
)
|
|
1501
|
+
logger.debug(f"Tree node {this_pod_ip} will call {len(subcall_ips)} direct children")
|
|
1502
|
+
elif not distributed_subcall:
|
|
1503
|
+
# Use the worker IP list as-is for the coordinator node in flat topology
|
|
1504
|
+
# Leave subcall_ips = [] for workers in flat topology
|
|
1505
|
+
subcall_ips = copy.deepcopy(worker_ips)
|
|
1506
|
+
if this_pod_ip in subcall_ips:
|
|
1507
|
+
subcall_ips.remove(this_pod_ip)
|
|
1508
|
+
logger.debug(f"Removed self ({this_pod_ip}) from subcall list, will call: {subcall_ips}")
|
|
1509
|
+
else:
|
|
1510
|
+
# This can happen with headless services where POD_IP might not exactly match DNS results
|
|
1511
|
+
# Try to match by partial IP or hostname
|
|
1512
|
+
logger.warning(
|
|
1513
|
+
f"This pod IP {this_pod_ip} not found in DNS-discovered IPs {worker_ips}. "
|
|
1514
|
+
f"This may indicate DNS propagation delay or hostname/IP mismatch."
|
|
1515
|
+
)
|
|
1516
|
+
# Still proceed as coordinator but will call all discovered pods
|
|
1517
|
+
logger.debug(f"Will call all discovered workers: {subcall_ips}")
|
|
1518
|
+
|
|
1519
|
+
# Always pass the distributed environment variables to workers
|
|
1520
|
+
params["distributed_env_vars"] = distributed_env_vars
|
|
1521
|
+
|
|
1522
|
+
# "workers" is passed through as a regular kwarg, not a special extracted one like pdb or serialization
|
|
1523
|
+
# because it only applies to distributed calls, not regular ones.
|
|
1524
|
+
call_local_procs = True
|
|
1525
|
+
workers_arg = params.get("workers", None)
|
|
1526
|
+
if workers_arg:
|
|
1527
|
+
logger.debug(f"Filtering workers by argument: {workers_arg}")
|
|
1528
|
+
if isinstance(workers_arg, list) and workers_arg:
|
|
1529
|
+
# Build a set of IPs to include based on the list items
|
|
1530
|
+
target_ips = set()
|
|
1531
|
+
|
|
1532
|
+
for item in workers_arg:
|
|
1533
|
+
if isinstance(item, str) and "." in item:
|
|
1534
|
+
# It's an IP address
|
|
1535
|
+
if item not in worker_ips:
|
|
1536
|
+
raise ValueError(f"Worker IP '{item}' not found in available workers: {worker_ips}")
|
|
1537
|
+
target_ips.add(item)
|
|
1538
|
+
elif isinstance(item, int) or (isinstance(item, str) and item.isdigit()):
|
|
1539
|
+
# It's an index
|
|
1540
|
+
idx = int(item) if isinstance(item, str) else item
|
|
1541
|
+
if idx < 0 or idx >= len(worker_ips):
|
|
1542
|
+
raise ValueError(f"Worker index {idx} out of range. Valid range: 0-{len(worker_ips)-1}")
|
|
1543
|
+
target_ips.add(worker_ips[idx])
|
|
1544
|
+
else:
|
|
1545
|
+
raise ValueError(
|
|
1546
|
+
f"Invalid worker specification: {item}. Must be an IP address, "
|
|
1547
|
+
f"integer index, or numeric string."
|
|
1548
|
+
)
|
|
1549
|
+
|
|
1550
|
+
# Filter subcall_ips to only those in the target set
|
|
1551
|
+
subcall_ips = [ip for ip in subcall_ips if ip in target_ips]
|
|
1552
|
+
|
|
1553
|
+
# Check if current pod should participate in local processing
|
|
1554
|
+
if this_pod_ip not in target_ips:
|
|
1555
|
+
call_local_procs = False
|
|
1556
|
+
|
|
1557
|
+
elif workers_arg == "any":
|
|
1558
|
+
# Only call one worker (this one)
|
|
1559
|
+
subcall_ips = []
|
|
1560
|
+
elif workers_arg == "ready":
|
|
1561
|
+
# Filter below in call_worker to only those that respond to /health
|
|
1562
|
+
pass
|
|
1563
|
+
elif isinstance(workers_arg, str) and workers_arg:
|
|
1564
|
+
# Filter the subcall_ips to only those matching the workers_arg string
|
|
1565
|
+
subcall_ips = [ip for ip in subcall_ips if workers_arg in ip]
|
|
1566
|
+
logger.debug(f"Subcall IPs after filtering: {subcall_ips}")
|
|
1567
|
+
|
|
1568
|
+
# "restart_procs" is passed through as a regular kwarg, not a special extracted one like pdb or serialization
|
|
1569
|
+
# because it only applies to distributed calls, not regular ones.
|
|
1570
|
+
if params.get("restart_procs", False):
|
|
1571
|
+
logger.info("restart_procs parameter is True, restarting processes")
|
|
1572
|
+
self.cleanup()
|
|
1573
|
+
self.setup(deployed_as_of)
|
|
1574
|
+
|
|
1575
|
+
try:
|
|
1576
|
+
node_rank = worker_ips.index(this_pod_ip)
|
|
1577
|
+
except ValueError:
|
|
1578
|
+
# This pod IP not found in DNS results - may be external service IP
|
|
1579
|
+
# Fall back to using POD_IP directly and assume node_rank based on position
|
|
1580
|
+
logger.warning(f"Pod IP {this_pod_ip} not found in DNS results {worker_ips}. Using fallback logic.")
|
|
1581
|
+
# For now, assume we're the first worker if not found
|
|
1582
|
+
node_rank = 0
|
|
1583
|
+
|
|
1584
|
+
# Call the workers using RemoteWorkerPool for async operations
|
|
1585
|
+
def call_worker(worker_ip):
|
|
1586
|
+
# Keep this function for backward compatibility but it won't be used
|
|
1587
|
+
# when RemoteWorkerPool is available
|
|
1588
|
+
with httpx.Client(timeout=None) as client:
|
|
1589
|
+
port = os.environ["KT_SERVER_PORT"]
|
|
1590
|
+
worker_url = f"http://{worker_ip}:{port}"
|
|
1591
|
+
# First check that the worker is alive, replicas don't finish setup at exactly the same moment
|
|
1592
|
+
# Use quorum_timeout to control how long to wait for workers
|
|
1593
|
+
for i in range(int(self.quorum_timeout)):
|
|
1594
|
+
try:
|
|
1595
|
+
resp = client.get(f"{worker_url}/health")
|
|
1596
|
+
if resp.status_code == 200:
|
|
1597
|
+
break
|
|
1598
|
+
except httpx.RequestError:
|
|
1599
|
+
if workers_arg == "ready":
|
|
1600
|
+
logger.debug(f"Worker {worker_ip} not ready, skipping as per 'ready' workers argument")
|
|
1601
|
+
return None
|
|
1602
|
+
time.sleep(1)
|
|
1603
|
+
else:
|
|
1604
|
+
# Timeout reached without successful health check
|
|
1605
|
+
logger.warning(f"Worker {worker_ip} failed to respond after {self.quorum_timeout}s timeout")
|
|
1606
|
+
if workers_arg != "ready":
|
|
1607
|
+
raise TimeoutError(
|
|
1608
|
+
f"Worker {worker_ip} did not become ready within {self.quorum_timeout} seconds. "
|
|
1609
|
+
"This may indicate the pod is still starting or there's a resource constraint. "
|
|
1610
|
+
"Consider increasing quorum_timeout in .distribute() call."
|
|
1611
|
+
)
|
|
1612
|
+
|
|
1613
|
+
call_url = (
|
|
1614
|
+
f"{worker_url}/{cls_or_fn_name}/{method_name}?distributed_subcall=true"
|
|
1615
|
+
if method_name is not None
|
|
1616
|
+
else f"{worker_url}/{cls_or_fn_name}?distributed_subcall=true"
|
|
1617
|
+
)
|
|
1618
|
+
|
|
1619
|
+
# Clean headers to avoid potential Content-Length issues
|
|
1620
|
+
clean_headers = {}
|
|
1621
|
+
if request.headers:
|
|
1622
|
+
for key, value in request.headers.items():
|
|
1623
|
+
# Skip headers that could interfere with httpx's automatic handling
|
|
1624
|
+
if key.lower() not in [
|
|
1625
|
+
"content-length",
|
|
1626
|
+
"transfer-encoding",
|
|
1627
|
+
"connection",
|
|
1628
|
+
]:
|
|
1629
|
+
clean_headers[key] = value
|
|
1630
|
+
|
|
1631
|
+
try:
|
|
1632
|
+
logger.debug(f"Making distributed call to {worker_url}")
|
|
1633
|
+
resp = client.post(
|
|
1634
|
+
url=call_url,
|
|
1635
|
+
json=params,
|
|
1636
|
+
headers=clean_headers, # Includes deployed_as_of and request_id
|
|
1637
|
+
)
|
|
1638
|
+
return resp
|
|
1639
|
+
except (httpx.RequestError, httpx.HTTPError) as e:
|
|
1640
|
+
logger.error(f"Failed to call worker {worker_url}: {e}")
|
|
1641
|
+
raise
|
|
1642
|
+
|
|
1643
|
+
# Prepare per-process parameters
|
|
1644
|
+
num_procs = len(self.process_pool)
|
|
1645
|
+
params_list = [params] * num_procs
|
|
1646
|
+
distributed_env_vars_list = []
|
|
1647
|
+
debug_ports = []
|
|
1648
|
+
|
|
1649
|
+
for idx in range(num_procs):
|
|
1650
|
+
# Get framework-specific env vars from the process class
|
|
1651
|
+
env_vars = self.process_class.get_distributed_env_vars(
|
|
1652
|
+
worker_ips=worker_ips
|
|
1653
|
+
if "worker_ips" in locals()
|
|
1654
|
+
else distributed_env_vars.get("POD_IPS", "").split(","),
|
|
1655
|
+
node_rank=node_rank,
|
|
1656
|
+
local_rank=idx,
|
|
1657
|
+
num_local_procs=num_procs,
|
|
1658
|
+
port=self.port,
|
|
1659
|
+
)
|
|
1660
|
+
# Add any base env vars
|
|
1661
|
+
env_vars.update(distributed_env_vars)
|
|
1662
|
+
distributed_env_vars_list.append(env_vars)
|
|
1663
|
+
|
|
1664
|
+
# Only debug one process and if debug_port is set
|
|
1665
|
+
debug = debug_port and idx == num_procs - 1
|
|
1666
|
+
if debug:
|
|
1667
|
+
time.sleep(0.25)
|
|
1668
|
+
debug_ports.append(debug_port if debug else None)
|
|
1669
|
+
|
|
1670
|
+
# Execute distributed calls in parallel with local processes
|
|
1671
|
+
worker_responses = []
|
|
1672
|
+
worker_exception = None
|
|
1673
|
+
local_exception = None
|
|
1674
|
+
|
|
1675
|
+
# Start both remote and local calls in parallel
|
|
1676
|
+
executor = ThreadPoolExecutor(max_workers=2)
|
|
1677
|
+
|
|
1678
|
+
# Submit remote worker calls if needed
|
|
1679
|
+
worker_future = None
|
|
1680
|
+
if subcall_ips:
|
|
1681
|
+
logger.debug(f"Have {len(subcall_ips)} remote workers to call")
|
|
1682
|
+
if not self.remote_worker_pool:
|
|
1683
|
+
raise RuntimeError("RemoteWorkerPool not initialized. This is required for distributed execution.")
|
|
1684
|
+
logger.debug(f"Using existing RemoteWorkerPool to call {len(subcall_ips)} workers")
|
|
1685
|
+
|
|
1686
|
+
def call_remote_workers():
|
|
1687
|
+
nonlocal worker_exception
|
|
1688
|
+
try:
|
|
1689
|
+
# Prepare headers for remote workers
|
|
1690
|
+
clean_headers = {}
|
|
1691
|
+
if request.headers:
|
|
1692
|
+
for key, value in request.headers.items():
|
|
1693
|
+
if key.lower() not in [
|
|
1694
|
+
"content-length",
|
|
1695
|
+
"transfer-encoding",
|
|
1696
|
+
"connection",
|
|
1697
|
+
]:
|
|
1698
|
+
clean_headers[key] = value
|
|
1699
|
+
# Always include deployed_as_of in headers for consistency
|
|
1700
|
+
if deployed_as_of:
|
|
1701
|
+
clean_headers["X-Deployed-As-Of"] = deployed_as_of
|
|
1702
|
+
|
|
1703
|
+
# Call remote workers asynchronously through the pool
|
|
1704
|
+
logger.debug(f"Calling {len(subcall_ips)} remote workers via RemoteWorkerPool: {subcall_ips}")
|
|
1705
|
+
results = self.remote_worker_pool.call_workers(
|
|
1706
|
+
worker_ips=subcall_ips,
|
|
1707
|
+
cls_or_fn_name=cls_or_fn_name,
|
|
1708
|
+
method_name=method_name,
|
|
1709
|
+
params=params,
|
|
1710
|
+
request_headers=clean_headers,
|
|
1711
|
+
workers_arg=workers_arg,
|
|
1712
|
+
)
|
|
1713
|
+
logger.warning(
|
|
1714
|
+
f"RemoteWorkerPool returned {len(results) if results else 0} results from {len(subcall_ips)} workers"
|
|
1715
|
+
)
|
|
1716
|
+
return results
|
|
1717
|
+
except Exception as e:
|
|
1718
|
+
# Check if this is a connection error - might indicate worker removal
|
|
1719
|
+
if any(
|
|
1720
|
+
err_type in str(e)
|
|
1721
|
+
for err_type in [
|
|
1722
|
+
"ReadError",
|
|
1723
|
+
"TimeoutException",
|
|
1724
|
+
"RequestError",
|
|
1725
|
+
"HTTPError",
|
|
1726
|
+
"ConnectionError",
|
|
1727
|
+
"Connection reset",
|
|
1728
|
+
"Connection closed",
|
|
1729
|
+
]
|
|
1730
|
+
):
|
|
1731
|
+
logger.debug(f"Connection error detected: {e}, checking for membership changes")
|
|
1732
|
+
# Force DNS check to see if workers were removed
|
|
1733
|
+
self.check_for_membership_changes(force_dns_check=True)
|
|
1734
|
+
worker_exception = e
|
|
1735
|
+
raise
|
|
1736
|
+
|
|
1737
|
+
worker_future = executor.submit(call_remote_workers)
|
|
1738
|
+
logger.debug(f"Submitted worker_future for {len(subcall_ips)} remote workers")
|
|
1739
|
+
|
|
1740
|
+
else:
|
|
1741
|
+
logger.debug(f"No remote workers to call (subcall_ips is empty or None: {subcall_ips})")
|
|
1742
|
+
|
|
1743
|
+
# Check if we need to initialize RemoteWorkerPool for tree topology workers
|
|
1744
|
+
if subcall_ips:
|
|
1745
|
+
if not self.remote_worker_pool:
|
|
1746
|
+
# Initialize RemoteWorkerPool if not already done (needed for tree topology workers)
|
|
1747
|
+
logger.warning(
|
|
1748
|
+
f"INITIALIZING RemoteWorkerPool for tree worker at {this_pod_ip} to call {len(subcall_ips)} children"
|
|
1749
|
+
)
|
|
1750
|
+
self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
|
|
1751
|
+
self.remote_worker_pool.start(
|
|
1752
|
+
max_workers=min(len(subcall_ips) + 50, 200)
|
|
1753
|
+
) # Size for expected children plus buffer
|
|
1754
|
+
logger.warning(f"RemoteWorkerPool initialized successfully for {this_pod_ip}")
|
|
1755
|
+
elif (
|
|
1756
|
+
not hasattr(self.remote_worker_pool, "process")
|
|
1757
|
+
or not self.remote_worker_pool.process
|
|
1758
|
+
or not self.remote_worker_pool.process.is_alive()
|
|
1759
|
+
):
|
|
1760
|
+
# Pool exists but not started/alive
|
|
1761
|
+
logger.warning(f"RemoteWorkerPool exists but not running for {this_pod_ip}, starting it now")
|
|
1762
|
+
self.remote_worker_pool.start(max_workers=min(len(subcall_ips) + 50, 200))
|
|
1763
|
+
|
|
1764
|
+
if subcall_ips and not self.remote_worker_pool:
|
|
1765
|
+
# RemoteWorkerPool should always be initialized at this point
|
|
1766
|
+
raise RuntimeError(
|
|
1767
|
+
f"RemoteWorkerPool not available for worker at {this_pod_ip}. "
|
|
1768
|
+
"This is required for distributed execution with subcall_ips."
|
|
1769
|
+
)
|
|
1770
|
+
|
|
1771
|
+
# Submit local process calls
|
|
1772
|
+
def call_local_processes():
|
|
1773
|
+
logger.debug(f"Processing {num_procs} local process responses")
|
|
1774
|
+
return self.process_pool.call_all(
|
|
1775
|
+
method_name=method_name,
|
|
1776
|
+
params_list=params_list,
|
|
1777
|
+
deployed_as_of=deployed_as_of,
|
|
1778
|
+
request_id=request_id,
|
|
1779
|
+
distributed_env_vars_list=distributed_env_vars_list,
|
|
1780
|
+
debug_ports=debug_ports,
|
|
1781
|
+
serialization=serialization,
|
|
1782
|
+
)
|
|
1783
|
+
|
|
1784
|
+
# We may not call the local processes if the user specified workers and didn't include this node
|
|
1785
|
+
if call_local_procs:
|
|
1786
|
+
local_future = executor.submit(call_local_processes)
|
|
1787
|
+
else:
|
|
1788
|
+
local_future = None
|
|
1789
|
+
|
|
1790
|
+
# Wait for both to complete with fast failure propagation
|
|
1791
|
+
try:
|
|
1792
|
+
# Use as_completed to get results as they arrive
|
|
1793
|
+
futures = [f for f in [worker_future, local_future] if f is not None]
|
|
1794
|
+
local_responses = []
|
|
1795
|
+
|
|
1796
|
+
# If we only have local processes (no remote workers), handle that case
|
|
1797
|
+
if not futures:
|
|
1798
|
+
logger.error("No futures to wait for - this shouldn't happen")
|
|
1799
|
+
raise RuntimeError("No distributed work to execute")
|
|
1800
|
+
|
|
1801
|
+
logger.debug(
|
|
1802
|
+
f"Waiting for {len(futures)} futures to complete (worker_future={worker_future is not None}, local_future={local_future is not None})"
|
|
1803
|
+
)
|
|
1804
|
+
|
|
1805
|
+
# Process futures with periodic membership checks
|
|
1806
|
+
from concurrent.futures import FIRST_COMPLETED, wait
|
|
1807
|
+
|
|
1808
|
+
pending_futures = set(futures)
|
|
1809
|
+
while pending_futures:
|
|
1810
|
+
# Check for membership changes even while waiting
|
|
1811
|
+
if change_event and change_event.is_set():
|
|
1812
|
+
logger.debug("Membership change detected, checking...")
|
|
1813
|
+
try:
|
|
1814
|
+
self.check_for_membership_changes()
|
|
1815
|
+
except Exception as e:
|
|
1816
|
+
# Cancel all pending futures immediately
|
|
1817
|
+
logger.error(f"Membership change detected, cancelling futures: {e}")
|
|
1818
|
+
for f in pending_futures:
|
|
1819
|
+
if not f.done():
|
|
1820
|
+
f.cancel()
|
|
1821
|
+
raise e
|
|
1822
|
+
finally:
|
|
1823
|
+
change_event.clear() # Reset for next change
|
|
1824
|
+
|
|
1825
|
+
# Wait for next future with a short timeout to allow membership checks
|
|
1826
|
+
done, pending_futures = wait(pending_futures, timeout=1.0, return_when=FIRST_COMPLETED)
|
|
1827
|
+
|
|
1828
|
+
for future in done:
|
|
1829
|
+
logger.debug(
|
|
1830
|
+
f"Future completed: is_worker={worker_future and future == worker_future}, is_local={local_future and future == local_future}"
|
|
1831
|
+
)
|
|
1832
|
+
try:
|
|
1833
|
+
if worker_future and future == worker_future:
|
|
1834
|
+
logger.debug("Getting results from remote workers future")
|
|
1835
|
+
results = future.result()
|
|
1836
|
+
logger.debug(
|
|
1837
|
+
f"Remote worker future returned: {type(results).__name__} with {len(results) if hasattr(results, '__len__') else 'N/A'} items"
|
|
1838
|
+
)
|
|
1839
|
+
# Process results - they're already JSON-decoded
|
|
1840
|
+
logger.debug(f"Processing {len(results)} results from RemoteWorkerPool")
|
|
1841
|
+
for i, result in enumerate(results):
|
|
1842
|
+
logger.debug(
|
|
1843
|
+
f"Result {i}: type={type(result).__name__}, "
|
|
1844
|
+
f"length={len(result) if hasattr(result, '__len__') else 'N/A'}"
|
|
1845
|
+
)
|
|
1846
|
+
if isinstance(result, dict) and "error_type" in result:
|
|
1847
|
+
# Fast failure - return error immediately
|
|
1848
|
+
executor.shutdown(wait=False)
|
|
1849
|
+
return JSONResponse(status_code=500, content=result)
|
|
1850
|
+
# Results from RemoteWorkerPool are already lists (aggregated from subtree in tree topology)
|
|
1851
|
+
if isinstance(result, list):
|
|
1852
|
+
worker_responses.extend(result)
|
|
1853
|
+
else:
|
|
1854
|
+
worker_responses.append(result)
|
|
1855
|
+
logger.debug(f"Got {len(worker_responses)} total responses from remote workers")
|
|
1856
|
+
else: # local_future
|
|
1857
|
+
logger.debug("Getting results from local processes future")
|
|
1858
|
+
local_responses = future.result()
|
|
1859
|
+
logger.debug(f"Got {len(local_responses)} responses from local processes")
|
|
1860
|
+
# Check for errors in local responses
|
|
1861
|
+
for response in local_responses:
|
|
1862
|
+
if isinstance(response, JSONResponse):
|
|
1863
|
+
# Fast failure - return error immediately
|
|
1864
|
+
executor.shutdown(wait=False)
|
|
1865
|
+
return response
|
|
1866
|
+
except Exception as e:
|
|
1867
|
+
# Fast failure - propagate exception immediately
|
|
1868
|
+
executor.shutdown(wait=False)
|
|
1869
|
+
logger.error(f"Error in distributed execution: {e}")
|
|
1870
|
+
raise
|
|
1871
|
+
|
|
1872
|
+
finally:
|
|
1873
|
+
# Unsubscribe from membership changes
|
|
1874
|
+
if change_event:
|
|
1875
|
+
self.unsubscribe_from_membership_changes(change_event)
|
|
1876
|
+
|
|
1877
|
+
logger.debug("Shutting down executor")
|
|
1878
|
+
executor.shutdown(wait=False)
|
|
1879
|
+
|
|
1880
|
+
total = len(local_responses) + len(worker_responses)
|
|
1881
|
+
if tree_mode and total > 10: # Only log for tree mode with significant results
|
|
1882
|
+
logger.debug(
|
|
1883
|
+
f"TREE RESULT AGGREGATION at {this_pod_ip}: {len(local_responses)} local + {len(worker_responses)} remote = {total} total"
|
|
1884
|
+
)
|
|
1885
|
+
else:
|
|
1886
|
+
logger.debug(
|
|
1887
|
+
f"Combining {len(local_responses)} local + {len(worker_responses)} remote = {total} total responses"
|
|
1888
|
+
)
|
|
1889
|
+
logger.debug(f"Distributed_subcall={distributed_subcall}, tree topology={tree_mode}")
|
|
1890
|
+
# Log sample of what we're returning for debugging
|
|
1891
|
+
if worker_responses:
|
|
1892
|
+
logger.debug(f"Sample worker response type: {type(worker_responses[0]).__name__}")
|
|
1893
|
+
responses = local_responses + worker_responses
|
|
1894
|
+
for response in responses:
|
|
1895
|
+
# If the response is a JSONResponse, we need to check if it contains an exception,
|
|
1896
|
+
# and "raise" it if so - essentially just returning it immediately rather than the full result list.
|
|
1897
|
+
if isinstance(response, JSONResponse):
|
|
1898
|
+
return response
|
|
1899
|
+
# This is primarily to handle exceptions raised while packaging an exception, which would otherwise cause the server to hang.
|
|
1900
|
+
if isinstance(response, Exception):
|
|
1901
|
+
raise response
|
|
1902
|
+
logger.debug(f"Returning {len(responses)} responses from execute_call")
|
|
1903
|
+
return responses
|
|
1904
|
+
|
|
1905
|
+
|
|
1906
|
+
class JaxProcess(DistributedProcess):
|
|
1907
|
+
"""JAX-specific distributed process."""
|
|
1908
|
+
|
|
1909
|
+
@classmethod
|
|
1910
|
+
def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
|
|
1911
|
+
"""Get JAX-specific distributed environment variables.
|
|
1912
|
+
|
|
1913
|
+
JAX uses a coordinator address and process ID for distributed setup.
|
|
1914
|
+
"""
|
|
1915
|
+
port = settings.get("port") or 1234 # JAX default coordinator port
|
|
1916
|
+
env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
|
|
1917
|
+
|
|
1918
|
+
# JAX distributed environment variables
|
|
1919
|
+
env_vars.update(
|
|
1920
|
+
{
|
|
1921
|
+
# Coordinator is the first worker
|
|
1922
|
+
"JAX_COORDINATOR_ADDRESS": f"{worker_ips[0]}:{port}",
|
|
1923
|
+
# Process ID is global rank
|
|
1924
|
+
"JAX_PROCESS_ID": str(node_rank * num_local_procs + local_rank),
|
|
1925
|
+
# Total number of processes
|
|
1926
|
+
"JAX_NUM_PROCESSES": str(len(worker_ips) * num_local_procs),
|
|
1927
|
+
# Local device IDs (for GPU/TPU)
|
|
1928
|
+
"JAX_LOCAL_DEVICE_IDS": str(local_rank),
|
|
1929
|
+
}
|
|
1930
|
+
)
|
|
1931
|
+
return env_vars
|
|
1932
|
+
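A worker could hand these values directly to jax.distributed.initialize; a minimal sketch (illustrative only, not part of this package; assumes the variables set above are present in the process environment):

# Illustrative sketch, not part of kubetorch: initialize JAX's distributed
# runtime from the env vars set by JaxProcess.get_distributed_env_vars.
import os

import jax


def init_jax_from_env():
    coord = os.environ["JAX_COORDINATOR_ADDRESS"]
    jax.distributed.initialize(
        coordinator_address=coord,
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )
    print(f"process {jax.process_index()} of {jax.process_count()} joined {coord}")
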
|
|
1933
|
+
@classmethod
|
|
1934
|
+
def get_auto_num_processes(cls):
|
|
1935
|
+
"""Auto-detect based on available accelerators for JAX."""
|
|
1936
|
+
try:
|
|
1937
|
+
import jax
|
|
1938
|
+
|
|
1939
|
+
# JAX can use TPUs, GPUs, or CPUs
|
|
1940
|
+
devices = jax.devices()
|
|
1941
|
+
return len(devices)
|
|
1942
|
+
except Exception:
|
|
1943
|
+
return 1
|
|
1944
|
+
|
|
1945
|
+
# JAX doesn't have a global process group to destroy like PyTorch
|
|
1946
|
+
# Cleanup is mostly handled automatically
|
|
1947
|
+
# def proc_cleanup(self):
|
|
1948
|
+
|
|
1949
|
+
|
|
1950
|
+
class TensorflowProcess(DistributedProcess):
|
|
1951
|
+
"""TensorFlow-specific distributed process."""
|
|
1952
|
+
|
|
1953
|
+
def proc_cleanup(self):
|
|
1954
|
+
"""TensorFlow-specific cleanup."""
|
|
1955
|
+
try:
|
|
1956
|
+
import tensorflow as tf
|
|
1957
|
+
|
|
1958
|
+
# Clear the default graph and reset the session
|
|
1959
|
+
tf.keras.backend.clear_session()
|
|
1960
|
+
logger.info("TensorFlow process cleanup completed.")
|
|
1961
|
+
except ImportError:
|
|
1962
|
+
logger.info("TensorFlow not available for cleanup")
|
|
1963
|
+
except Exception as e:
|
|
1964
|
+
logger.info(f"Failed during TensorFlow cleanup: {e}")
|
|
1965
|
+
# Call parent cleanup for debugging sessions
|
|
1966
|
+
super().proc_cleanup()
|
|
1967
|
+
|
|
1968
|
+
@classmethod
|
|
1969
|
+
def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
|
|
1970
|
+
"""Get TensorFlow-specific distributed environment variables.
|
|
1971
|
+
|
|
1972
|
+
TensorFlow uses TF_CONFIG for distributed training configuration.
|
|
1973
|
+
"""
|
|
1974
|
+
import json
|
|
1975
|
+
|
|
1976
|
+
port = settings.get("port") or 2222 # TensorFlow default port
|
|
1977
|
+
env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
|
|
1978
|
+
|
|
1979
|
+
# Build TF_CONFIG for MultiWorkerMirroredStrategy
|
|
1980
|
+
worker_addresses = [f"{ip}:{port}" for ip in worker_ips]
|
|
1981
|
+
|
|
1982
|
+
tf_config = {
|
|
1983
|
+
"cluster": {"worker": worker_addresses},
|
|
1984
|
+
"task": {"type": "worker", "index": node_rank},
|
|
1985
|
+
}
|
|
1986
|
+
|
|
1987
|
+
env_vars.update(
|
|
1988
|
+
{
|
|
1989
|
+
"TF_CONFIG": json.dumps(tf_config),
|
|
1990
|
+
# Additional TF env vars for performance
|
|
1991
|
+
"TF_FORCE_GPU_ALLOW_GROWTH": "true",
|
|
1992
|
+
"TF_GPU_THREAD_MODE": "gpu_private",
|
|
1993
|
+
}
|
|
1994
|
+
)
|
|
1995
|
+
return env_vars
|
|
1996
|
+
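For a two-worker cluster this produces a TF_CONFIG of roughly the following shape (hypothetical IPs), which tf.distribute.MultiWorkerMirroredStrategy reads from the environment at startup:

{"cluster": {"worker": ["10.0.0.1:2222", "10.0.0.2:2222"]}, "task": {"type": "worker", "index": 0}}
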
|
|
1997
|
+
@classmethod
|
|
1998
|
+
def get_auto_num_processes(cls):
|
|
1999
|
+
"""Auto-detect based on available GPUs for TensorFlow."""
|
|
2000
|
+
try:
|
|
2001
|
+
import tensorflow as tf
|
|
2002
|
+
|
|
2003
|
+
gpus = tf.config.list_physical_devices("GPU")
|
|
2004
|
+
if gpus:
|
|
2005
|
+
return len(gpus)
|
|
2006
|
+
except Exception:
|
|
2007
|
+
pass
|
|
2008
|
+
return 1
|
|
2009
|
+
|
|
2010
|
+
|
|
2011
|
+
class MonarchProcess(DistributedProcess):
|
|
2012
|
+
"""Monarch-specific distributed process for single-controller actor framework.
|
|
2013
|
+
|
|
2014
|
+
Similar to Ray, Monarch uses a single controller (rank 0) that manages distributed
|
|
2015
|
+
actors across worker nodes. Each node runs a process_allocator service.
|
|
2016
|
+
"""
|
|
2017
|
+
|
|
2018
|
+
def __init__(self, local_rank, request_queue, response_queue, max_threads=4, **kwargs):
|
|
2019
|
+
super().__init__(local_rank, request_queue, response_queue, max_threads, **kwargs)
|
|
2020
|
+
self.allocator = None
|
|
2021
|
+
|
|
2022
|
+
# Monarch imports will be done in run() on the main thread
|
|
2023
|
+
self.RemoteAllocator = None
|
|
2024
|
+
self.StaticRemoteAllocInitializer = None
|
|
2025
|
+
|
|
2026
|
+
def _create_allocator_for_controller(self):
|
|
2027
|
+
"""Create a RemoteAllocator for the controller (rank 0)."""
|
|
2028
|
+
|
|
2029
|
+
try:
|
|
2030
|
+
# Try to import if not already available
|
|
2031
|
+
if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
|
|
2032
|
+
try:
|
|
2033
|
+
from monarch._src.actor.allocator import RemoteAllocator, StaticRemoteAllocInitializer
|
|
2034
|
+
|
|
2035
|
+
self.RemoteAllocator = RemoteAllocator
|
|
2036
|
+
self.StaticRemoteAllocInitializer = StaticRemoteAllocInitializer
|
|
2037
|
+
logger.debug("Monarch components imported")
|
|
2038
|
+
except ImportError as e:
|
|
2039
|
+
logger.error(f"Failed to import Monarch: {e}")
|
|
2040
|
+
logger.error("Make sure torchmonarch is installed: pip install torchmonarch")
|
|
2041
|
+
import traceback
|
|
2042
|
+
|
|
2043
|
+
logger.error(traceback.format_exc())
|
|
2044
|
+
return None
|
|
2045
|
+
except Exception as e:
|
|
2046
|
+
logger.error(f"Unexpected error importing Monarch: {e}")
|
|
2047
|
+
import traceback
|
|
2048
|
+
|
|
2049
|
+
logger.error(traceback.format_exc())
|
|
2050
|
+
return None
|
|
2051
|
+
|
|
2052
|
+
if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
|
|
2053
|
+
logger.error("Monarch components not available. Cannot create allocator.")
|
|
2054
|
+
return None
|
|
2055
|
+
|
|
2056
|
+
# Get worker addresses from POD_IPS
|
|
2057
|
+
pod_ips = os.environ.get("POD_IPS", "").split(",")
|
|
2058
|
+
if not pod_ips or pod_ips == [""]:
|
|
2059
|
+
logger.warning("No POD_IPS found, using localhost")
|
|
2060
|
+
pod_ips = ["127.0.0.1"]
|
|
2061
|
+
|
|
2062
|
+
# Use tcp! format for channel addresses (Monarch's format)
|
|
2063
|
+
# Format: tcp!{ip}:{port} not tcp://{ip}:{port}
|
|
2064
|
+
worker_addresses = [f"tcp!{ip}:26600" for ip in pod_ips]
|
|
2065
|
+
logger.info(f"Creating Monarch allocator with {len(worker_addresses)} workers")
|
|
2066
|
+
logger.debug(f"Worker addresses type: {type(worker_addresses)}")
|
|
2067
|
+
logger.debug(f"First address: {worker_addresses[0] if worker_addresses else 'none'}")
|
|
2068
|
+
logger.debug(f"First address type: {type(worker_addresses[0]) if worker_addresses else 'none'}")
|
|
2069
|
+
|
|
2070
|
+
# Simple check - don't add complex waiting logic
|
|
2071
|
+
|
|
2072
|
+
# Create initializer with all workers using pre-imported classes
|
|
2073
|
+
# StaticRemoteAllocInitializer takes addresses as positional args
|
|
2074
|
+
logger.debug(f"About to create StaticRemoteAllocInitializer with args: {worker_addresses}")
|
|
2075
|
+
try:
|
|
2076
|
+
initializer = self.StaticRemoteAllocInitializer(*worker_addresses)
|
|
2077
|
+
except Exception as e:
|
|
2078
|
+
logger.error(f"Failed to create StaticRemoteAllocInitializer: {e}")
|
|
2079
|
+
import traceback
|
|
2080
|
+
|
|
2081
|
+
logger.debug(f"Traceback: {traceback.format_exc()}")
|
|
2082
|
+
raise
|
|
2083
|
+
|
|
2084
|
+
# Return configured allocator using pre-imported class
|
|
2085
|
+
# RemoteAllocator takes world_id and initializer
|
|
2086
|
+
# Use stable world_id based on service name
|
|
2087
|
+
# This allows coordinator failover and process restarts to work correctly
|
|
2088
|
+
service_name = os.environ.get("KT_SERVICE_NAME", "monarch-default")
|
|
2089
|
+
world_id = service_name
|
|
2090
|
+
try:
|
|
2091
|
+
allocator = self.RemoteAllocator(world_id=world_id, initializer=initializer)
|
|
2092
|
+
logger.info(f"Created allocator with world_id={world_id}")
|
|
2093
|
+
return allocator
|
|
2094
|
+
except Exception as e:
|
|
2095
|
+
logger.error(f"Failed to create RemoteAllocator: {e}")
|
|
2096
|
+
import traceback
|
|
2097
|
+
|
|
2098
|
+
logger.debug(f"Traceback: {traceback.format_exc()}")
|
|
2099
|
+
raise
|
|
2100
|
+
|
|
2101
|
+
except ImportError as e:
|
|
2102
|
+
logger.error(f"Could not import Monarch for allocator creation: {e}")
|
|
2103
|
+
import traceback
|
|
2104
|
+
|
|
2105
|
+
logger.error(traceback.format_exc())
|
|
2106
|
+
return None
|
|
2107
|
+
except Exception as e:
|
|
2108
|
+
logger.error(f"Failed to create Monarch allocator: {e}")
|
|
2109
|
+
import traceback
|
|
2110
|
+
|
|
2111
|
+
logger.error(traceback.format_exc())
|
|
2112
|
+
return None
|
|
2113
|
+
|
|
2114
|
+
def run_user_function(self, callable_obj, method_name, params):
|
|
2115
|
+
"""Run user function with Monarch-specific setup."""
|
|
2116
|
+
import asyncio
|
|
2117
|
+
import inspect
|
|
2118
|
+
|
|
2119
|
+
# Get the rank from environment
|
|
2120
|
+
rank = int(os.environ.get("NODE_RANK", "0"))
|
|
2121
|
+
logger.debug(f"Running user function on rank {rank}")
|
|
2122
|
+
|
|
2123
|
+
# Get the method to call
|
|
2124
|
+
if method_name and hasattr(callable_obj, method_name):
|
|
2125
|
+
user_method = getattr(callable_obj, method_name)
|
|
2126
|
+
else:
|
|
2127
|
+
user_method = callable_obj
|
|
2128
|
+
|
|
2129
|
+
logger.info(f"User method: {user_method}")
|
|
2130
|
+
|
|
2131
|
+
# Prepare arguments
|
|
2132
|
+
args = params.get("args", []) if params else []
|
|
2133
|
+
kwargs = params.get("kwargs", {}) if params else {}
|
|
2134
|
+
|
|
2135
|
+
# Only create and inject allocator for controller (rank 0)
|
|
2136
|
+
# Workers will run the user function with allocator=None
|
|
2137
|
+
if rank == 0:
|
|
2138
|
+
logger.info("Rank 0 (controller) - will create allocator")
|
|
2139
|
+
# Controller (rank 0) - create allocator if needed
|
|
2140
|
+
if self.allocator is None:
|
|
2141
|
+
logger.debug("Creating allocator...")
|
|
2142
|
+
self.allocator = self._create_allocator_for_controller()
|
|
2143
|
+
if self.allocator is None:
|
|
2144
|
+
logger.error("Failed to create allocator - returned None!")
|
|
2145
|
+
else:
|
|
2146
|
+
logger.debug("Allocator created successfully")
|
|
2147
|
+
|
|
2148
|
+
# Inject allocator if function accepts it
|
|
2149
|
+
try:
|
|
2150
|
+
sig = inspect.signature(user_method)
|
|
2151
|
+
if "allocator" in sig.parameters:
|
|
2152
|
+
logger.debug("Injecting allocator into controller function")
|
|
2153
|
+
kwargs["allocator"] = self.allocator
|
|
2154
|
+
except Exception as e:
|
|
2155
|
+
logger.warning(f"Could not inspect function signature: {e}")
|
|
2156
|
+
else:
|
|
2157
|
+
# Workers get None for allocator parameter if the function expects it
|
|
2158
|
+
try:
|
|
2159
|
+
sig = inspect.signature(user_method)
|
|
2160
|
+
if "allocator" in sig.parameters:
|
|
2161
|
+
logger.info(f"Worker {rank}: Setting allocator=None")
|
|
2162
|
+
kwargs["allocator"] = None
|
|
2163
|
+
except Exception:
|
|
2164
|
+
pass
|
|
2165
|
+
|
|
2166
|
+
# Run the function (we're already on the main thread of this process)
|
|
2167
|
+
logger.info(f"Rank {rank}: Running user function with args={args}, kwargs keys={list(kwargs.keys())}")
|
|
2168
|
+
|
|
2169
|
+
try:
|
|
2170
|
+
if asyncio.iscoroutinefunction(user_method):
|
|
2171
|
+
result = asyncio.run(user_method(*args, **kwargs))
|
|
2172
|
+
else:
|
|
2173
|
+
result = user_method(*args, **kwargs)
|
|
2174
|
+
if inspect.isawaitable(result):
|
|
2175
|
+
result = asyncio.run(result)
|
|
2176
|
+
|
|
2177
|
+
# If the result is an exception dict (from the user's try/except),
|
|
2178
|
+
# convert it back to a proper exception and raise it
|
|
2179
|
+
if isinstance(result, dict) and "status" in result and result["status"] == "error":
|
|
2180
|
+
error_msg = result.get("error", "Unknown error")
|
|
2181
|
+
traceback_str = result.get("traceback", "")
|
|
2182
|
+
logger.error(f"User function returned error dict: {error_msg}")
|
|
2183
|
+
logger.error(f"Traceback from user function: {traceback_str}")
|
|
2184
|
+
# Raise a RuntimeError with the original error message
|
|
2185
|
+
raise RuntimeError(f"Monarch execution failed: {error_msg}\n{traceback_str}")
|
|
2186
|
+
|
|
2187
|
+
return result
|
|
2188
|
+
except Exception as e:
|
|
2189
|
+
logger.error(f"Exception in run_user_function: {e}")
|
|
2190
|
+
import traceback
|
|
2191
|
+
|
|
2192
|
+
logger.error(f"Full traceback: {traceback.format_exc()}")
|
|
2193
|
+
# Re-raise the exception to be handled by the caller
|
|
2194
|
+
raise
|
|
2195
|
+
|
|
2196
|
+
def run(self):
|
|
2197
|
+
"""Override run to handle requests on main thread for Monarch."""
|
|
2198
|
+
logger.debug("MonarchProcess starting on main thread")
|
|
2199
|
+
|
|
2200
|
+
# Import Monarch on the main thread of this subprocess
|
|
2201
|
+
# This is the right place since run() executes on the main thread
|
|
2202
|
+
if self.RemoteAllocator is None or self.StaticRemoteAllocInitializer is None:
|
|
2203
|
+
try:
|
|
2204
|
+
from monarch._src.actor.allocator import RemoteAllocator, StaticRemoteAllocInitializer
|
|
2205
|
+
|
|
2206
|
+
self.RemoteAllocator = RemoteAllocator
|
|
2207
|
+
self.StaticRemoteAllocInitializer = StaticRemoteAllocInitializer
|
|
2208
|
+
except Exception as e:
|
|
2209
|
+
logger.error(f"Failed to import Monarch in run(): {e}")
|
|
2210
|
+
import traceback
|
|
2211
|
+
|
|
2212
|
+
logger.error(traceback.format_exc())
|
|
2213
|
+
|
|
2214
|
+
# Monarch requires main thread execution, so we don't use ThreadPoolExecutor
|
|
2215
|
+
try:
|
|
2216
|
+
while True:
|
|
2217
|
+
try:
|
|
2218
|
+
# Block waiting for next request
|
|
2219
|
+
request = self._request_queue.get(timeout=1)
|
|
2220
|
+
|
|
2221
|
+
# Special sentinel value to signal shutdown
|
|
2222
|
+
if request == "SHUTDOWN":
|
|
2223
|
+
break
|
|
2224
|
+
|
|
2225
|
+
# Handle request directly on main thread (not in thread pool)
|
|
2226
|
+
self.handle_request(request)
|
|
2227
|
+
|
|
2228
|
+
except queue.Empty:
|
|
2229
|
+
continue
|
|
2230
|
+
except Exception as e:
|
|
2231
|
+
if "Empty" not in str(e.__class__.__name__):
|
|
2232
|
+
logger.error(f"Error getting request from queue: {e}")
|
|
2233
|
+
continue
|
|
2234
|
+
|
|
2235
|
+
except (KeyboardInterrupt, BdbQuit):
|
|
2236
|
+
logger.debug("MonarchProcess interrupted")
|
|
2237
|
+
finally:
|
|
2238
|
+
logger.debug("MonarchProcess shutting down")
|
|
2239
|
+
self.proc_cleanup()
|
|
2240
|
+
|
|
2241
|
+
def handle_request(self, request):
|
|
2242
|
+
"""Handle request using Monarch-specific logic."""
|
|
2243
|
+
try:
|
|
2244
|
+
# Use parent's handle_request but override the actual function execution
|
|
2245
|
+
request_unique_id = request["request_unique_id"]
|
|
2246
|
+
method_name = request["method_name"]
|
|
2247
|
+
params = request["params"]
|
|
2248
|
+
deployed_as_of = request["deployed_as_of"]
|
|
2249
|
+
request_id = request["request_id"]
|
|
2250
|
+
distributed_env_vars = request["distributed_env_vars"]
|
|
2251
|
+
|
|
2252
|
+
# Set environment variables
|
|
2253
|
+
for key, value in distributed_env_vars.items():
|
|
2254
|
+
os.environ[key] = value
|
|
2255
|
+
|
|
2256
|
+
# Set request context
|
|
2257
|
+
token = request_id_ctx_var.set(request_id)
|
|
2258
|
+
|
|
2259
|
+
try:
|
|
2260
|
+
# Load callable
|
|
2261
|
+
callable_obj = load_callable(
|
|
2262
|
+
deployed_as_of=deployed_as_of,
|
|
2263
|
+
distributed_subprocess=True,
|
|
2264
|
+
reload_cleanup_fn=self.proc_cleanup,
|
|
2265
|
+
)
|
|
2266
|
+
|
|
2267
|
+
# Run with our simplified Monarch logic
|
|
2268
|
+
result = self.run_user_function(callable_obj, method_name, params)
|
|
2269
|
+
|
|
2270
|
+
# Send response
|
|
2271
|
+
self._response_queue.put({"request_unique_id": request_unique_id, "result": result})
|
|
2272
|
+
|
|
2273
|
+
except Exception as e:
|
|
2274
|
+
# Package and send error
|
|
2275
|
+
packaged_exception = package_exception(e)
|
|
2276
|
+
self._response_queue.put(
|
|
2277
|
+
{
|
|
2278
|
+
"request_unique_id": request_unique_id,
|
|
2279
|
+
"result": packaged_exception,
|
|
2280
|
+
}
|
|
2281
|
+
)
|
|
2282
|
+
finally:
|
|
2283
|
+
request_id_ctx_var.reset(token)
|
|
2284
|
+
|
|
2285
|
+
except Exception as e:
|
|
2286
|
+
logger.error(f"Error in Monarch request handling: {e}")
|
|
2287
|
+
self._response_queue.put(
|
|
2288
|
+
{
|
|
2289
|
+
"request_unique_id": request.get("request_unique_id", "unknown"),
|
|
2290
|
+
"result": Exception(f"Fatal error: {e}"),
|
|
2291
|
+
}
|
|
2292
|
+
)
|
|
2293
|
+
|
|
2294
|
+
def proc_cleanup(self):
|
|
2295
|
+
"""Monarch-specific cleanup."""
|
|
2296
|
+
try:
|
|
2297
|
+
# Stop allocator service
|
|
2298
|
+
if getattr(self, "allocator_proc", None):
|
|
2299
|
+
self.allocator_proc.terminate()
|
|
2300
|
+
try:
|
|
2301
|
+
self.allocator_proc.wait(timeout=5)
|
|
2302
|
+
except subprocess.TimeoutExpired:
|
|
2303
|
+
self.allocator_proc.kill()
|
|
2304
|
+
self.allocator_proc = None
|
|
2305
|
+
logger.info("Stopped process_allocator service")
|
|
2306
|
+
|
|
2307
|
+
# Cleanup any Monarch resources
|
|
2308
|
+
# Monarch doesn't have a global shutdown like Ray
|
|
2309
|
+
logger.debug("Monarch process cleanup completed")
|
|
2310
|
+
|
|
2311
|
+
except Exception as e:
|
|
2312
|
+
logger.error(f"Error during Monarch cleanup: {e}")
|
|
2313
|
+
|
|
2314
|
+
# Call parent cleanup
|
|
2315
|
+
super().proc_cleanup()
|
|
2316
|
+
|
|
2317
|
+
@classmethod
|
|
2318
|
+
def get_distributed_env_vars(cls, worker_ips, node_rank, local_rank, num_local_procs, **settings):
|
|
2319
|
+
"""Get Monarch-specific environment variables."""
|
|
2320
|
+
env_vars = super().get_distributed_env_vars(worker_ips, node_rank, local_rank, num_local_procs, **settings)
|
|
2321
|
+
|
|
2322
|
+
# Monarch uses these for discovery
|
|
2323
|
+
env_vars.update(
|
|
2324
|
+
{
|
|
2325
|
+
"HYPERACTOR_MESH_BOOTSTRAP_ADDR": "tcp://localhost:26600",
|
|
2326
|
+
"HYPERACTOR_MESH_INDEX": str(node_rank),
|
|
2327
|
+
# Keep POD_IPS for allocator creation
|
|
2328
|
+
"POD_IPS": ",".join(worker_ips),
|
|
2329
|
+
}
|
|
2330
|
+
)
|
|
2331
|
+
|
|
2332
|
+
return env_vars
|
|
2333
|
+
|
|
2334
|
+
@classmethod
|
|
2335
|
+
def get_auto_num_processes(cls):
|
|
2336
|
+
"""Monarch uses one process per node (like Ray)."""
|
|
2337
|
+
return 1
|
|
2338
|
+
|
|
2339
|
+
|
|
2340
|
+
# Similar to Ray, Monarch needs special handling as a single-controller framework
|
|
2341
|
+
class MonarchDistributed(DistributedSupervisor):
|
|
2342
|
+
"""Monarch distributed supervisor for single-controller actor framework."""
|
|
2343
|
+
|
|
2344
|
+
def __init__(
|
|
2345
|
+
self,
|
|
2346
|
+
restart_procs=True,
|
|
2347
|
+
max_threads=4,
|
|
2348
|
+
quorum_timeout=300,
|
|
2349
|
+
quorum_workers=None,
|
|
2350
|
+
**kwargs,
|
|
2351
|
+
):
|
|
2352
|
+
# Monarch doesn't use DNS monitoring like SPMD frameworks
|
|
2353
|
+
super().__init__(
|
|
2354
|
+
quorum_workers=quorum_workers,
|
|
2355
|
+
quorum_timeout=quorum_timeout,
|
|
2356
|
+
monitor_members=False, # Disable DNS monitoring like Ray
|
|
2357
|
+
)
|
|
2358
|
+
self.restart_procs = restart_procs
|
|
2359
|
+
self.max_threads = max_threads
|
|
2360
|
+
self.process_pool = None
|
|
2361
|
+
self.remote_worker_pool = None
|
|
2362
|
+
|
|
2363
|
+
def setup(self, deployed_as_of: Optional[str] = None):
|
|
2364
|
+
"""Setup Monarch distributed environment."""
|
|
2365
|
+
# Set multiprocessing to spawn
|
|
2366
|
+
if multiprocessing.get_start_method() != "spawn":
|
|
2367
|
+
multiprocessing.set_start_method("spawn", force=True)
|
|
2368
|
+
|
|
2369
|
+
# Start process_allocator service (like Ray starts its server)
|
|
2370
|
+
self._start_allocator_service()
|
|
2371
|
+
|
|
2372
|
+
if self.restart_procs:
|
|
2373
|
+
logger.debug("restart_procs is True, restarting Monarch processes")
|
|
2374
|
+
self.cleanup()
|
|
2375
|
+
|
|
2376
|
+
if self.process_pool is None:
|
|
2377
|
+
logger.debug("Setting up Monarch distributed environment")
|
|
2378
|
+
|
|
2379
|
+
# Create process pool with MonarchProcess
|
|
2380
|
+
logger.info("Creating DistributedProcessPool with MonarchProcess class")
|
|
2381
|
+
self.process_pool = DistributedProcessPool(
|
|
2382
|
+
process_class=MonarchProcess,
|
|
2383
|
+
num_processes=1, # One process per node for Monarch
|
|
2384
|
+
max_threads_per_proc=self.max_threads,
|
|
2385
|
+
)
|
|
2386
|
+
|
|
2387
|
+
# Start the process
|
|
2388
|
+
logger.info("Starting MonarchProcess pool...")
|
|
2389
|
+
self.process_pool.start()
|
|
2390
|
+
logger.info(f"Started MonarchProcess pool: {self.process_pool}")
|
|
2391
|
+
|
|
2392
|
+
# Create remote worker pool for coordination
|
|
2393
|
+
self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
|
|
2394
|
+
self.remote_worker_pool.start()
|
|
2395
|
+
|
|
2396
|
+
logger.debug("Finished setting up Monarch distributed processes")
|
|
2397
|
+
|
|
2398
|
+
def _start_allocator_service(self):
|
|
2399
|
+
"""Start the process_allocator service if available."""
|
|
2400
|
+
try:
|
|
2401
|
+
# Check if process_allocator is already running
|
|
2402
|
+
import subprocess
|
|
2403
|
+
|
|
2404
|
+
check_result = subprocess.run(["pgrep", "-f", "process_allocator"], capture_output=True)
|
|
2405
|
+
if check_result.returncode == 0:
|
|
2406
|
+
logger.info("process_allocator already running")
|
|
2407
|
+
return
|
|
2408
|
+
|
|
2409
|
+
# Try to find process_allocator binary
|
|
2410
|
+
import shutil
|
|
2411
|
+
|
|
2412
|
+
allocator_path = shutil.which("process_allocator")
|
|
2413
|
+
|
|
2414
|
+
if not allocator_path:
|
|
2415
|
+
# Check common installation paths
|
|
2416
|
+
import sys
|
|
2417
|
+
|
|
2418
|
+
possible_paths = [
|
|
2419
|
+
"/opt/conda/bin/process_allocator",
|
|
2420
|
+
os.path.join(sys.prefix, "bin", "process_allocator"),
|
|
2421
|
+
]
|
|
2422
|
+
for path in possible_paths:
|
|
2423
|
+
if os.path.exists(path) and os.access(path, os.X_OK):
|
|
2424
|
+
allocator_path = path
|
|
2425
|
+
break
|
|
2426
|
+
|
|
2427
|
+
if not allocator_path:
|
|
2428
|
+
logger.warning(
|
|
2429
|
+
"process_allocator binary not found. "
|
|
2430
|
+
"Please ensure torchmonarch is properly installed or "
|
|
2431
|
+
"start process_allocator manually in your Docker image."
|
|
2432
|
+
)
|
|
2433
|
+
return
|
|
2434
|
+
|
|
2435
|
+
# Start process_allocator with the correct arguments
|
|
2436
|
+
# Based on monarch/python/monarch/tools/components/hyperactor.py
|
|
2437
|
+
allocator_cmd = [
|
|
2438
|
+
allocator_path,
|
|
2439
|
+
"--port=26600",
|
|
2440
|
+
"--program=monarch_bootstrap",
|
|
2441
|
+
]
|
|
2442
|
+
|
|
2443
|
+
logger.info(f"Starting process_allocator: {' '.join(allocator_cmd)}")
|
|
2444
|
+
|
|
2445
|
+
# Start in background
|
|
2446
|
+
self.allocator_proc = subprocess.Popen(
|
|
2447
|
+
allocator_cmd,
|
|
2448
|
+
stdout=subprocess.DEVNULL,
|
|
2449
|
+
stderr=subprocess.PIPE,
|
|
2450
|
+
start_new_session=True,
|
|
2451
|
+
)
|
|
2452
|
+
|
|
2453
|
+
# Give it a moment to start
|
|
2454
|
+
import time
|
|
2455
|
+
|
|
2456
|
+
time.sleep(2)
|
|
2457
|
+
|
|
2458
|
+
# Check if it's still running
|
|
2459
|
+
if self.allocator_proc.poll() is None:
|
|
2460
|
+
logger.info(f"process_allocator started successfully (PID: {self.allocator_proc.pid})")
|
|
2461
|
+
else:
|
|
2462
|
+
stderr = self.allocator_proc.stderr.read().decode() if self.allocator_proc.stderr else ""
|
|
2463
|
+
logger.error(f"process_allocator failed to start: {stderr}")
|
|
2464
|
+
self.allocator_proc = None
|
|
2465
|
+
|
|
2466
|
+
except Exception as e:
|
|
2467
|
+
logger.warning(f"Could not start process_allocator: {e}")
|
|
2468
|
+
# Continue anyway - user may have started it differently
|
|
2469
|
+
|
|
2470
|
+
def cleanup(self):
|
|
2471
|
+
"""Cleanup Monarch distributed environment."""
|
|
2472
|
+
logger.debug("Cleaning up Monarch distributed processes")
|
|
2473
|
+
|
|
2474
|
+
# Stop DNS monitoring (though it's disabled for Monarch)
|
|
2475
|
+
super().cleanup()
|
|
2476
|
+
|
|
2477
|
+
# Stop process_allocator if we started it
|
|
2478
|
+
if hasattr(self, "allocator_proc") and self.allocator_proc:
|
|
2479
|
+
try:
|
|
2480
|
+
self.allocator_proc.terminate()
|
|
2481
|
+
self.allocator_proc.wait(timeout=5)
|
|
2482
|
+
logger.info("Stopped process_allocator")
|
|
2483
|
+
except Exception as e:
|
|
2484
|
+
logger.debug(f"Error stopping process_allocator: {e}")
|
|
2485
|
+
|
|
2486
|
+
if self.process_pool:
|
|
2487
|
+
self.process_pool.stop()
|
|
2488
|
+
self.process_pool = None
|
|
2489
|
+
|
|
2490
|
+
if self.remote_worker_pool:
|
|
2491
|
+
self.remote_worker_pool.stop()
|
|
2492
|
+
self.remote_worker_pool = None
|
|
2493
|
+
|
|
2494
|
+
logger.debug("Finished cleaning up Monarch distributed processes")
|
|
2495
|
+
|
|
2496
|
+
@staticmethod
|
|
2497
|
+
def intercept_call():
|
|
2498
|
+
"""Monarch intercepts calls like Ray."""
|
|
2499
|
+
return True
|
|
2500
|
+
|
|
2501
|
+
def call_distributed(
|
|
2502
|
+
self,
|
|
2503
|
+
request,
|
|
2504
|
+
cls_or_fn_name: str,
|
|
2505
|
+
method_name: Optional[str] = None,
|
|
2506
|
+
params: Optional[Dict] = None,
|
|
2507
|
+
distributed_subcall: bool = False,
|
|
2508
|
+
debug_port: int = False,
|
|
2509
|
+
deployed_as_of: Optional[str] = None,
|
|
2510
|
+
):
|
|
2511
|
+
"""Monarch distributed call - executes on controller node (rank 0)."""
|
|
2512
|
+
logger.info("MonarchDistributed.call_distributed called")
|
|
2513
|
+
|
|
2514
|
+
# Ensure setup has been called
|
|
2515
|
+
if self.process_pool is None:
|
|
2516
|
+
logger.info("Process pool not initialized, calling setup()")
|
|
2517
|
+
self.setup(deployed_as_of=deployed_as_of)
|
|
2518
|
+
|
|
2519
|
+
request_id = request.headers.get("X-Request-ID", "-")
|
|
2520
|
+
serialization = request.headers.get("X-Serialization", "json")
|
|
2521
|
+
|
|
2522
|
+
# If deployed_as_of is None, generate a consistent timestamp
|
|
2523
|
+
if deployed_as_of is None:
|
|
2524
|
+
from datetime import datetime, timezone
|
|
2525
|
+
|
|
2526
|
+
deployed_as_of = datetime.now(timezone.utc).isoformat()
|
|
2527
|
+
|
|
2528
|
+
# Start DNS monitoring for worker discovery
|
|
2529
|
+
self.start_dns_monitoring()
|
|
2530
|
+
|
|
2531
|
+
# Check for any pending changes before we start
|
|
2532
|
+
self.check_for_membership_changes()
|
|
2533
|
+
|
|
2534
|
+
# Get pod IPs with quorum handling
|
|
2535
|
+
pod_ips = self.pod_ips()
|
|
2536
|
+
|
|
2537
|
+
# Handle case where no pods are found
|
|
2538
|
+
if not pod_ips:
|
|
2539
|
+
logger.error(
|
|
2540
|
+
f"No pods found for service {os.environ.get('KT_SERVICE')}. "
|
|
2541
|
+
"This may indicate the pods aren't ready yet. Consider increasing quorum_timeout in .distribute() call."
|
|
2542
|
+
)
|
|
2543
|
+
raise RuntimeError(
|
|
2544
|
+
"No pods found for Monarch distributed setup. " "Consider increasing quorum_timeout parameter."
|
|
2545
|
+
)
|
|
2546
|
+
|
|
2547
|
+
logger.info(f"Found {len(pod_ips)} pod(s) for Monarch distributed setup: {pod_ips}")
|
|
2548
|
+
|
|
2549
|
+
# Store critical environment variables
|
|
2550
|
+
self.distributed_env_vars = {}
|
|
2551
|
+
critical_env_vars = [
|
|
2552
|
+
"KT_SERVICE",
|
|
2553
|
+
"KT_SERVICE_NAME",
|
|
2554
|
+
"KT_FILE_PATH",
|
|
2555
|
+
"KT_MODULE_NAME",
|
|
2556
|
+
"KT_CLS_OR_FN_NAME",
|
|
2557
|
+
]
|
|
2558
|
+
for env_var in critical_env_vars:
|
|
2559
|
+
if env_var in os.environ:
|
|
2560
|
+
self.distributed_env_vars[env_var] = os.environ[env_var]
|
|
2561
|
+
|
|
2562
|
+
# Update distributed env vars with current cluster IPs
|
|
2563
|
+
self.distributed_env_vars["POD_IPS"] = ",".join(pod_ips)
|
|
2564
|
+
self.distributed_env_vars["WORLD_SIZE"] = str(len(pod_ips))
|
|
2565
|
+
self.distributed_env_vars["NODE_RANK"] = "0" # Controller is always rank 0
|
|
2566
|
+
|
|
2567
|
+
logger.debug("Sending call to Monarch subprocess (controller)")
|
|
2568
|
+
|
|
2569
|
+
# Monarch uses only one process per node, call index 0
|
|
2570
|
+
result = self.process_pool.call(
|
|
2571
|
+
idx=0,
|
|
2572
|
+
method_name=method_name,
|
|
2573
|
+
params=params,
|
|
2574
|
+
deployed_as_of=deployed_as_of,
|
|
2575
|
+
request_id=request_id,
|
|
2576
|
+
distributed_env_vars=self.distributed_env_vars,
|
|
2577
|
+
debug_port=debug_port,
|
|
2578
|
+
serialization=serialization,
|
|
2579
|
+
)
|
|
2580
|
+
|
|
2581
|
+
# Handle exceptions from subprocess
|
|
2582
|
+
if isinstance(result, JSONResponse):
|
|
2583
|
+
return result
|
|
2584
|
+
if isinstance(result, Exception):
|
|
2585
|
+
raise result
|
|
2586
|
+
|
|
2587
|
+
return result
|
|
2588
|
+
|
|
2589
|
+
|
|
2590
|
+
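Editor's note: the Monarch controller above hands its view of the cluster to the worker subprocess purely through environment variables (POD_IPS, WORLD_SIZE, NODE_RANK, plus the KT_* identifiers). The sketch below shows what a worker-side consumer of that contract could look like; only the variable names come from this diff, while the helper name and fallback defaults are hypothetical.

    # Hypothetical worker-side reader of the env vars set by MonarchDistributed above.
    # Only the names POD_IPS, WORLD_SIZE, and NODE_RANK come from this diff.
    import os

    def read_cluster_env():
        raw_ips = os.environ.get("POD_IPS", "")
        pod_ips = raw_ips.split(",") if raw_ips else []
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        node_rank = int(os.environ.get("NODE_RANK", "0"))
        return {"pod_ips": pod_ips, "world_size": world_size, "node_rank": node_rank}
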
+RAY_START_PROC = None
+
+
+class RayDistributed(DistributedSupervisor):
+    def __init__(
+        self,
+        restart_procs=True,
+        max_threads=4,
+        quorum_timeout=300,
+        quorum_workers=None,
+        monitor_members=False,
+    ):
+        """Ray distributed supervisor - only runs on head node (single controller).
+
+        Args:
+            restart_procs: Whether to restart processes on each call
+            max_threads: Maximum threads per process
+            quorum_timeout: Timeout in seconds for Ray cluster nodes to become ready (default 300s/5min)
+        """
+        # Ray manages its own membership, so we don't monitor DNS changes
+        super().__init__(
+            quorum_timeout=quorum_timeout,
+            quorum_workers=quorum_workers,
+            monitor_members=monitor_members,
+        )
+        self.restart_procs = restart_procs
+        self.distributed_env_vars = None
+        self.process_pool = None  # Using pool even for single process for consistency
+        self.remote_worker_pool = None  # Pool for async HTTP calls to remote workers
+        self.max_threads = max_threads
+        self.quorum_timeout = quorum_timeout
+
+    def setup(self, deployed_as_of: Optional[str] = None):
+        # Set multiprocessing to spawn if not already
+        if multiprocessing.get_start_method() != "spawn":
+            multiprocessing.set_start_method("spawn", force=True)
+
+        # Start the Ray server here; if we let KubeRay start it in the pod template,
+        # it's hard to wait for it to start properly and we lose the ability to restart it if needed.
+        global RAY_START_PROC
+
+        # Check if Ray is actually running, not just if our global variable is None
+        # (the global variable gets reset when HTTP server restarts)
+        ray_running = self._is_ray_running()
+
+        if not ray_running:
+            patch_sys_path()
+
+            kuberay_start_cmd = os.environ.get("KUBERAY_GEN_RAY_START_CMD")
+            if kuberay_start_cmd:
+                full_cmd = f"ulimit -n 65536; {kuberay_start_cmd}"
+                logger.info(f"Starting Ray server with command: {full_cmd}")
+
+                try:
+                    # Start Ray as a non-blocking subprocess
+                    RAY_START_PROC = subprocess.Popen(
+                        full_cmd,
+                        shell=True,
+                        stdout=subprocess.PIPE,
+                        stderr=subprocess.STDOUT,
+                        universal_newlines=True,
+                        bufsize=1,
+                        env=os.environ.copy(),
+                    )
+
+                    # Start a thread to stream Ray logs
+                    def stream_ray_logs():
+                        try:
+                            for line in RAY_START_PROC.stdout:
+                                logger.info(f"[Ray] {line.strip()}")
+                        except Exception as e:
+                            logger.error(f"Error streaming Ray logs: {e}")
+
+                    import threading
+
+                    log_thread = threading.Thread(target=stream_ray_logs, daemon=True)
+                    log_thread.start()
+
+                    logger.info(f"Ray server started with PID: {RAY_START_PROC.pid}")
+
+                    # Give Ray a moment to start
+                    time.sleep(2)
+
+                except Exception as e:
+                    logger.error(f"Failed to start Ray server: {e}")
+                    RAY_START_PROC = None
+                    raise
+            else:
+                logger.warning("KUBERAY_GEN_RAY_START_CMD environment variable not found")
+
+        logger.debug("Ray distributed supervisor setup completed (pod discovery will be done lazily)")
+
+        # Only the head node runs the subprocess
+        this_pod_ip = os.environ["POD_IP"]
+        if not os.environ["POD_NAME"].endswith("-head"):
+            logger.info(f"Ray worker node {this_pod_ip}, skipping subprocess setup")
+            return
+
+        logger.info(f"Ray head node {this_pod_ip}, setting up subprocess")
+
+        # Set Ray environment variables
+        self.distributed_env_vars = {"RAY_HEAD_NODE_IP": this_pod_ip}
+
+        # Include critical environment variables so Ray workers can find and load the callable
+        critical_env_vars = [
+            "PYTHONPATH",
+            "KT_FILE_PATH",
+            "KT_MODULE_NAME",
+            "KT_CLS_OR_FN_NAME",
+        ]
+        for env_var in critical_env_vars:
+            if env_var in os.environ:
+                self.distributed_env_vars[env_var] = os.environ[env_var]
+
+        # Cleanup will remove the process pool if found, so we need to check if it was previously initialized
+        previously_initialized = self.remote_worker_pool is not None
+
+        if self.restart_procs:
+            logger.debug("restart_procs is True, restarting Ray distributed process")
+            self.cleanup()
+
+            if previously_initialized:
+                pod_ips = self.pod_ips()
+                this_pod_ip = os.environ["POD_IP"]
+
+                # Send reload requests to other pods if needed
+                self._reload_image_on_other_pods(pod_ips, this_pod_ip, deployed_as_of)
+
+        if self.process_pool is None:
+            logger.debug("Setting up Ray distributed process")
+            self.process_pool = DistributedProcessPool(
+                process_class=RayProcess,
+                num_processes=1,  # Ray only needs one process
+                max_threads_per_proc=self.max_threads,
+            )
+            self.process_pool.start()
+
+            # Start remote worker pool for async HTTP calls if needed.
+            # Use a reasonable default max_workers since we don't know cluster size yet
+            self.remote_worker_pool = RemoteWorkerPool(quorum_timeout=self.quorum_timeout)
+            self.remote_worker_pool.start(max_workers=100)  # Default size
+
+        logger.debug("Finished setting up Ray distributed process and remote worker pool")
+
+    def cleanup(self):
+        """Clean up Ray distributed process."""
+        logger.debug("Cleaning up Ray distributed process")
+
+        # Stop DNS monitoring first
+        super().cleanup()
+
+        if self.process_pool:
+            self.process_pool.stop()
+            self.process_pool = None
+
+        if self.remote_worker_pool:
+            self.remote_worker_pool.stop()
+            self.remote_worker_pool = None
+
+        logger.debug("Finished cleaning up Ray distributed process")
+
+    @staticmethod
+    def intercept_call():
+        return True
+
+    def call_distributed(
+        self,
+        request,
+        cls_or_fn_name: str,
+        method_name: Optional[str] = None,
+        params: Optional[Dict] = None,
+        distributed_subcall: bool = False,
+        debug_port: int = False,
+        deployed_as_of: Optional[str] = None,
+    ):
+        """Ray distributed call - only executes on head node."""
+        request_id = request.headers.get("X-Request-ID", "-")
+        serialization = request.headers.get("X-Serialization", "json")
+
+        # If deployed_as_of is None, generate a consistent timestamp
+        # to use across all workers to prevent reload inconsistencies
+        if deployed_as_of is None:
+            from datetime import datetime, timezone
+
+            deployed_as_of = datetime.now(timezone.utc).isoformat()
+
+        if not os.environ["POD_NAME"].endswith("-head"):
+            # This should never happen, because the service only points to the head node. Raise an error if it does.
+            raise RuntimeError(
+                f"Ray distributed call attempted on non-head node {os.environ['POD_NAME']}. "
+                "This should only be called on the head node."
+            )
+
+        # Start DNS monitoring for the head node
+        self.start_dns_monitoring()
+
+        # Check for any pending changes before we start
+        self.check_for_membership_changes()
+
+        # The pod_ips() method now handles waiting for quorum
+        pod_ips = self.pod_ips()
+
+        # Handle case where no pods are found
+        if not pod_ips:
+            logger.error(
+                f"No pods found for service {os.environ.get('KT_SERVICE')}. "
+                "This may indicate the pods aren't ready yet. Consider increasing quorum_timeout in .distribute() call."
+            )
+            raise RuntimeError(
+                "No pods found for Ray distributed setup. " "Consider increasing quorum_timeout parameter."
+            )
+
+        logger.info(f"Found {len(pod_ips)} pod(s) for distributed setup: {pod_ips}")
+
+        # Update distributed env vars with current cluster IPs
+        self.distributed_env_vars["POD_IPS"] = ",".join(pod_ips)
+
+        logger.debug("Sending call to Ray subprocess")
+        # Ray uses only one process, so always call index 0
+        result = self.process_pool.call(
+            idx=0,
+            method_name=method_name,
+            params=params,
+            deployed_as_of=deployed_as_of,
+            request_id=request_id,
+            distributed_env_vars=self.distributed_env_vars,
+            debug_port=debug_port,
+            serialization=serialization,
+        )
+
+        # Handle exceptions from subprocess
+        if isinstance(result, JSONResponse):
+            return result
+        if isinstance(result, Exception):
+            raise result
+
+        return result
+
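Editor's note: call_distributed above is driven entirely by request headers, X-Request-ID and X-Serialization, and the reload helper that follows adds X-Deployed-As-Of. A rough caller-side sketch of such a request is shown here; the header names and the 32300 default port come from this diff, while the hostname and URL path are hypothetical placeholders rather than a documented endpoint.

    # Illustrative request only; endpoint path and hostname are placeholders.
    import uuid

    import httpx

    headers = {
        "X-Request-ID": str(uuid.uuid4()),
        "X-Serialization": "json",
    }
    response = httpx.post("http://head-pod:32300/call", json={"args": []}, headers=headers)
    print(response.status_code)
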
+    def _reload_image_on_other_pods(self, pod_ips, this_pod_ip, deployed_as_of):
+        """Send /_reload_image requests to all other pods in parallel, with retries for pods that aren't ready."""
+        other_pod_ips = [ip for ip in pod_ips if ip != this_pod_ip]
+
+        if not other_pod_ips:
+            logger.debug("No other pods to reload")
+            return
+
+        logger.info(f"Sending reload requests to {len(other_pod_ips)} other pods: {other_pod_ips}")
+
+        server_port = os.environ.get("KT_SERVER_PORT", "32300")
+        total_timeout = self.quorum_timeout  # Use configurable quorum timeout
+        retry_interval = 2  # Wait 2 seconds between retry attempts
+        start_time = time.time()
+
+        successful_pods = set()
+        remaining_pods = set(other_pod_ips)
+
+        while remaining_pods and (time.time() - start_time) < total_timeout:
+            logger.debug(f"Attempting to reload {len(remaining_pods)} remaining pods: {list(remaining_pods)}")
+
+            def reload_pod(pod_ip):
+                """Send reload request to a single pod."""
+                try:
+                    # Use a proper HTTP client session to avoid Content-Length issues
+                    with httpx.Client(timeout=None) as client:
+                        url = f"http://{pod_ip}:{server_port}/_reload_image"
+                        # First try a quick health check to see if pod is ready
+                        health_url = f"http://{pod_ip}:{server_port}/health"
+                        health_response = client.get(health_url, timeout=5)
+
+                        if health_response.status_code != 200:
+                            logger.debug(f"Pod {pod_ip} health check failed, will retry later")
+                            return False
+
+                        # Pod is healthy, send reload request (no timeout, installs can be long-running)
+                        response = client.post(url, headers={"X-Deployed-As-Of": deployed_as_of})
+                        if response.status_code == 200:
+                            logger.debug(f"Successfully reloaded image on pod {pod_ip}")
+                            return True
+                        else:
+                            logger.warning(f"Pod {pod_ip} reload returned status {response.status_code}")
+                            return False
+
+                except Exception as e:
+                    logger.debug(f"Failed to reload image on pod {pod_ip}: {e}")
+                    raise
+
+            # Try to reload all remaining pods in parallel
+            current_attempt_pods = list(remaining_pods)
+
+            with ThreadPoolExecutor(max_workers=min(len(current_attempt_pods), 10)) as executor:
+                # Submit reload tasks for remaining pods
+                future_to_pod = {executor.submit(reload_pod, pod_ip): pod_ip for pod_ip in current_attempt_pods}
+
+                # Process completed futures
+                for future in as_completed(future_to_pod, timeout=None):
+                    pod_ip = future_to_pod[future]
+                    try:
+                        success = future.result()
+                        if success:
+                            successful_pods.add(pod_ip)
+                            remaining_pods.discard(pod_ip)
+                    except Exception as e:
+                        logger.debug(f"Reload task for pod {pod_ip} failed: {e}")
+
+            if remaining_pods:
+                elapsed = time.time() - start_time
+                remaining_time = total_timeout - elapsed
+                if remaining_time > retry_interval:
+                    logger.info(f"Waiting {retry_interval}s before retrying {len(remaining_pods)} pods...")
+                    time.sleep(retry_interval)
+                else:
+                    logger.warning("Timeout approaching, stopping retry attempts")
+                    break
+
+        # Log final results
+        if successful_pods:
+            logger.info(f"Successfully reloaded {len(successful_pods)} pod images: {list(successful_pods)}")
+
+        if remaining_pods:
+            logger.warning(f"Failed to reload {len(remaining_pods)} pod images after timeout: {list(remaining_pods)}")
+
+    def _is_ray_running(self):
+        """Check if Ray is actually running by trying to connect to the Ray GCS port."""
+        try:
+            import socket
+
+            # Ray GCS runs on port 6379 by default
+            ray_port = 6379
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.settimeout(1)  # 1 second timeout
+            result = sock.connect_ex(("127.0.0.1", ray_port))
+            sock.close()
+
+            if result == 0:
+                logger.debug("Ray GCS port 6379 is accessible, Ray appears to be running")
+                return True
+            else:
+                logger.debug("Ray GCS port 6379 is not accessible, Ray is not running")
+                return False
+
+        except Exception as e:
+            logger.debug(f"Error checking if Ray is running: {e}")
+            return False
+
+
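Editor's note: _is_ray_running above is a one-shot TCP probe of the default Ray GCS port. When such a probe is used as a readiness gate, it is typically wrapped in a polling loop with a deadline; a minimal sketch under that assumption follows. The helper name and timeout values are hypothetical and not part of this package.

    # Hypothetical polling wrapper around the one-shot GCS port probe above.
    import time

    def wait_for_ray(supervisor, timeout: float = 60.0, interval: float = 2.0) -> bool:
        """Poll the supervisor's Ray probe until it succeeds or the deadline passes."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            if supervisor._is_ray_running():
                return True
            time.sleep(interval)
        return False
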
+def distributed_supervisor_factory(distribution_type, *args, **kwargs):
+    """
+    Factory function to create a distributed supervisor based on the specified type.
+
+    Args:
+        distribution_type (str): The type of distributed supervisor to create.
+            Options include 'ray', 'monarch', 'pytorch', 'jax', 'tensorflow', or None for generic SPMD.
+        *args: Positional arguments to pass to the supervisor constructor.
+        **kwargs: Keyword arguments to pass to the supervisor constructor.
+            Common kwargs include:
+            - quorum_timeout: Timeout in seconds for workers to become ready (default 30 for SPMD, 300 for Ray/Monarch)
+
+    Returns:
+        DistributedSupervisor: An instance of the specified distributed supervisor.
+    """
+    if distribution_type == "ray":
+        # Ray uses its own supervisor, not SPMD
+        return RayDistributed(*args, **kwargs)
+    elif distribution_type == "monarch":
+        # Monarch is similar to Ray - single controller framework
+        return MonarchDistributed(*args, **kwargs)
+
+    # All other types use SPMDDistributedSupervisor with different process classes
+    if distribution_type == "pytorch":
+        return SPMDDistributedSupervisor(process_class=PyTorchProcess, *args, **kwargs)
+    elif distribution_type == "jax":
+        return SPMDDistributedSupervisor(process_class=JaxProcess, *args, **kwargs)
+    elif distribution_type == "tensorflow" or distribution_type == "tf":
+        return SPMDDistributedSupervisor(process_class=TensorflowProcess, *args, **kwargs)
+    elif distribution_type is None or distribution_type == "spmd":
+        # Default to base DistributedProcess - no framework-specific dependencies
+        return SPMDDistributedSupervisor(process_class=DistributedProcess, *args, **kwargs)
+    else:
+        raise ValueError(f"Unsupported distributed type: {distribution_type}")
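Editor's note: a short usage sketch of the factory above. The "pytorch" distribution type and the quorum_timeout keyword come straight from the docstring; the call site itself is hypothetical and not part of this diff.

    # Hypothetical call site for distributed_supervisor_factory; illustrative only.
    supervisor = distributed_supervisor_factory("pytorch", quorum_timeout=60)
    # The serving process would then drive supervisor.setup() and
    # supervisor.call_distributed(...) as shown in the classes above;
    # those calls are omitted here because their request objects come from the HTTP server.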