gen_worker-0.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
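
For orientation before the `worker.py` listing below: the worker discovers user functions by importing the modules named in `user_module_names` (default `["functions"]`) and registering any function that carries the `_is_worker_function` marker and is annotated with `msgspec.Struct` input and output types. A minimal sketch of such a user module follows. The import paths and the exact `worker_function` decorator signature are assumptions (this diff only shows that `decorators.py` exports `ResourceRequirements` and that discovery reads the `_is_worker_function` / `_worker_resources` attributes), so treat it as illustrative rather than as the package's documented API.

# functions.py - hypothetical user module; everything except ActionContext,
# msgspec.Struct payloads, and the runs/<run_id>/outputs/ prefix is an assumption.
from typing import Optional

import msgspec

from gen_worker.decorators import worker_function  # assumed export
from gen_worker.worker import ActionContext


class EchoInput(msgspec.Struct):
    text: str


class EchoOutput(msgspec.Struct):
    text: str
    output_ref: Optional[str] = None


@worker_function()  # the decorator must set _is_worker_function (and optionally _worker_resources)
def echo(ctx: ActionContext, payload: EchoInput) -> EchoOutput:
    # Progress events are forwarded to the scheduler as job.progress payloads.
    ctx.progress(0.5, stage="echoing")
    # Outputs must live under runs/<run_id>/outputs/; saving requires FILE_API_BASE_URL and FILE_API_TOKEN.
    asset = ctx.save_bytes(f"runs/{ctx.run_id}/outputs/echo.txt", payload.text.encode("utf-8"))
    ctx.progress(1.0)
    return EchoOutput(text=payload.text, output_ref=asset.ref)
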
gen_worker/worker.py
ADDED
@@ -0,0 +1,1720 @@
import grpc
import logging
import time
import json
import urllib.request
import urllib.parse
import urllib.error
import random
import threading
import os
import signal
import queue
import psutil
import importlib
import inspect
import functools
import typing
import socket
import ipaddress
from typing import Any, Callable, Dict, Optional, TypeVar, Iterator, List, Tuple
from types import ModuleType
import hashlib
import msgspec
try:
    import torch
except Exception:  # pragma: no cover - optional dependency
    torch = None
import asyncio

# JWT verification for worker auth (scheduler-issued)
import jwt
_jwt_algorithms: Optional[ModuleType]
try:
    import jwt.algorithms as _jwt_algorithms
except Exception:  # pragma: no cover - optional crypto backend
    _jwt_algorithms = None
RSAAlgorithm: Optional[Any] = getattr(_jwt_algorithms, "RSAAlgorithm", None) if _jwt_algorithms else None
# Use relative imports within the package
from .pb import worker_scheduler_pb2 as _pb
from .pb import worker_scheduler_pb2_grpc as _pb_grpc

pb: Any = _pb
pb_grpc: Any = _pb_grpc

WorkerSchedulerMessage = Any
WorkerEvent = Any
WorkerResources = Any
WorkerRegistration = Any
LoadModelCommand = Any
LoadModelResult = Any
UnloadModelResult = Any
TaskExecutionRequest = Any
TaskExecutionResult = Any
from .decorators import ResourceRequirements  # Import ResourceRequirements for type hints if needed
from .errors import RetryableError, FatalError

from .model_interface import ModelManagementInterface
from .downloader import CozyHubDownloader, ModelDownloader
from .types import Asset

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)  # Use __name__ for logger

# Type variables for generic function signatures
I = TypeVar('I')  # Input type
O = TypeVar('O')  # Output type

# Generic type for action functions
ActionFunc = Callable[[Any, I], O]

HEARTBEAT_INTERVAL = 10  # seconds


def _encode_ref_for_url(ref: str) -> str:
    ref = ref.strip().lstrip("/")
    parts = [urllib.parse.quote(p, safe="") for p in ref.split("/") if p]
    return "/".join(parts)


def _infer_mime_type(ref: str, head: bytes) -> str:
    # Prefer magic bytes when available.
    if head.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if head.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if head.startswith(b"GIF87a") or head.startswith(b"GIF89a"):
        return "image/gif"
    if len(head) >= 12 and head[0:4] == b"RIFF" and head[8:12] == b"WEBP":
        return "image/webp"

    # Fall back to extension.
    import mimetypes

    guessed, _ = mimetypes.guess_type(ref)
    return guessed or "application/octet-stream"


def _default_output_prefix(run_id: str) -> str:
    return f"runs/{run_id}/outputs/"


def _require_file_api_base_url() -> str:
    base = os.getenv("FILE_API_BASE_URL", "").strip()
    if not base:
        base = os.getenv("ORCHESTRATOR_HTTP_URL", "").strip()
    if not base:
        base = os.getenv("COZY_HUB_URL", "").strip()
    if not base:
        raise RuntimeError("FILE_API_BASE_URL is required for file operations")
    return base.rstrip("/")


def _require_file_api_token() -> str:
    token = os.getenv("FILE_API_TOKEN", "").strip()
    if not token:
        token = os.getenv("COZY_HUB_TOKEN", "").strip()
    if not token:
        raise RuntimeError("FILE_API_TOKEN is required for file operations")
    return token


def _http_request(
    method: str,
    url: str,
    token: str,
    body: Optional[bytes] = None,
    content_type: Optional[str] = None,
) -> urllib.request.Request:
    req = urllib.request.Request(url, data=body, method=method)
    req.add_header("Authorization", f"Bearer {token}")
    tenant_id = os.getenv("TENANT_ID", "").strip()
    if tenant_id:
        req.add_header("X-Cozy-Tenant-Id", tenant_id)
    if content_type:
        req.add_header("Content-Type", content_type)
    return req


def _is_private_ip_str(ip_str: str) -> bool:
    try:
        ip = ipaddress.ip_address(ip_str)
    except Exception:
        return True
    return bool(
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
        or ip.is_unspecified
    )


def _url_is_blocked(url_str: str) -> bool:
    try:
        u = urllib.parse.urlparse(url_str)
    except Exception:
        return True
    if u.scheme not in ("http", "https"):
        return True
    host = (u.hostname or "").strip()
    if not host:
        return True
    try:
        infos = socket.getaddrinfo(host, None)
    except Exception:
        return True
    for info in infos:
        sockaddr = info[4]
        if not sockaddr:
            continue
        ip_str = str(sockaddr[0])
        if _is_private_ip_str(ip_str):
            return True
    return False
|
|
177
|
+
|
|
178
|
+
class _JWKSCache:
|
|
179
|
+
def __init__(self, url: str, ttl_seconds: int = 300) -> None:
|
|
180
|
+
self._url = url
|
|
181
|
+
self._ttl_seconds = max(ttl_seconds, 0)
|
|
182
|
+
self._lock = threading.Lock()
|
|
183
|
+
self._fetched_at = 0.0
|
|
184
|
+
self._keys: Dict[str, Any] = {}
|
|
185
|
+
|
|
186
|
+
def _fetch(self) -> None:
|
|
187
|
+
if RSAAlgorithm is None:
|
|
188
|
+
raise RuntimeError(
|
|
189
|
+
"PyJWT RSA support is unavailable (missing cryptography). "
|
|
190
|
+
"Install gen-worker with a JWT/RSA-capable build of PyJWT."
|
|
191
|
+
)
|
|
192
|
+
with urllib.request.urlopen(self._url, timeout=5) as resp:
|
|
193
|
+
body = resp.read()
|
|
194
|
+
payload = json.loads(body.decode("utf-8"))
|
|
195
|
+
keys: Dict[str, Any] = {}
|
|
196
|
+
for jwk in payload.get("keys", []):
|
|
197
|
+
kid = jwk.get("kid")
|
|
198
|
+
if not kid:
|
|
199
|
+
continue
|
|
200
|
+
try:
|
|
201
|
+
keys[kid] = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
|
202
|
+
except Exception:
|
|
203
|
+
continue
|
|
204
|
+
self._keys = keys
|
|
205
|
+
self._fetched_at = time.time()
|
|
206
|
+
|
|
207
|
+
def _needs_refresh(self) -> bool:
|
|
208
|
+
if not self._keys:
|
|
209
|
+
return True
|
|
210
|
+
if self._ttl_seconds <= 0:
|
|
211
|
+
return False
|
|
212
|
+
return (time.time() - self._fetched_at) > self._ttl_seconds
|
|
213
|
+
|
|
214
|
+
def get_key(self, kid: Optional[str]) -> Optional[Any]:
|
|
215
|
+
with self._lock:
|
|
216
|
+
if self._needs_refresh():
|
|
217
|
+
self._fetch()
|
|
218
|
+
if kid and kid in self._keys:
|
|
219
|
+
return self._keys[kid]
|
|
220
|
+
# refresh on miss (rotation)
|
|
221
|
+
self._fetch()
|
|
222
|
+
if kid and kid in self._keys:
|
|
223
|
+
return self._keys[kid]
|
|
224
|
+
return None

class ActionContext:
    """Context object passed to action functions, allowing cancellation."""
    def __init__(
        self,
        run_id: str,
        emitter: Optional[Callable[[Dict[str, Any]], None]] = None,
        tenant_id: Optional[str] = None,
        user_id: Optional[str] = None,
        timeout_ms: Optional[int] = None,
    ) -> None:
        self._run_id = run_id
        self._tenant_id = tenant_id
        self._user_id = user_id
        self._timeout_ms = timeout_ms
        self._started_at = time.time()
        self._deadline: Optional[float] = None
        if timeout_ms is not None and timeout_ms > 0:
            self._deadline = self._started_at + (timeout_ms / 1000.0)
        self._canceled = False
        self._cancel_event = threading.Event()
        self._emitter = emitter

    @property
    def run_id(self) -> str:
        return self._run_id

    @property
    def tenant_id(self) -> Optional[str]:
        return self._tenant_id

    @property
    def user_id(self) -> Optional[str]:
        return self._user_id

    @property
    def timeout_ms(self) -> Optional[int]:
        return self._timeout_ms

    @property
    def deadline(self) -> Optional[float]:
        return self._deadline

    def time_remaining_s(self) -> Optional[float]:
        if self._deadline is None:
            return None
        return max(0.0, self._deadline - time.time())

    def is_canceled(self) -> bool:
        """Check if the action was canceled."""
        return self._canceled

    def cancel(self) -> None:
        """Mark the action as canceled."""
        if not self._canceled:
            self._canceled = True
            self._cancel_event.set()
            logger.info(f"Action {self.run_id} marked for cancellation.")

    def done(self) -> threading.Event:
        """Returns an event that is set when the action is cancelled."""
        return self._cancel_event

    def emit(self, event_type: str, payload: Optional[Dict[str, Any]] = None) -> None:
        """Emit a progress/event payload (best-effort)."""
        if not self._emitter:
            logger.debug(f"emit({event_type}) dropped: no emitter configured")
            return
        event = {
            "run_id": self._run_id,
            "type": event_type,
            "payload": payload or {},
            "timestamp": time.time(),
        }
        self._emitter(event)

    def progress(self, progress: float, stage: Optional[str] = None) -> None:
        payload: Dict[str, Any] = {"progress": progress}
        if stage is not None:
            payload["stage"] = stage
        self.emit("job.progress", payload)

    def log(self, message: str, level: str = "info") -> None:
        self.emit("job.log", {"message": message, "level": level})

    def save_bytes(self, ref: str, data: bytes) -> Asset:
        if not isinstance(data, (bytes, bytearray)):
            raise TypeError("save_bytes expects bytes")
        data = bytes(data)
        max_bytes = int(os.getenv("WORKER_MAX_OUTPUT_FILE_BYTES", str(200 * 1024 * 1024)))
        if max_bytes > 0 and len(data) > max_bytes:
            raise ValueError("output file too large")
        ref = ref.strip().lstrip("/")
        if not ref.startswith(_default_output_prefix(self.run_id)):
            raise ValueError(f"ref must start with '{_default_output_prefix(self.run_id)}'")

        base = _require_file_api_base_url()
        token = _require_file_api_token()
        url = f"{base}/api/v1/file/{_encode_ref_for_url(ref)}"
        # Default behavior is upsert: PUT to the tenant file store.
        req = _http_request("PUT", url, token, body=data, content_type="application/octet-stream")
        try:
            with urllib.request.urlopen(req, timeout=30) as resp:
                body = resp.read()
                if resp.status < 200 or resp.status >= 300:
                    raise RuntimeError(f"file save failed ({resp.status})")
                try:
                    meta = json.loads(body.decode("utf-8"))
                except Exception:
                    meta = {}
        except urllib.error.HTTPError as e:
            raise RuntimeError(f"file save failed ({getattr(e, 'code', 'unknown')})") from e

        return Asset(
            ref=ref,
            tenant_id=self.tenant_id,
            local_path=None,
            mime_type=str(meta.get("mime_type") or "") or None,
            size_bytes=int(meta.get("size_bytes") or 0) or len(data),
            sha256=str(meta.get("sha256") or "") or None,
        )

    def save_file(self, ref: str, local_path: str) -> Asset:
        with open(local_path, "rb") as f:
            data = f.read()
        return self.save_bytes(ref, data)

    def save_bytes_create(self, ref: str, data: bytes) -> Asset:
        if not isinstance(data, (bytes, bytearray)):
            raise TypeError("save_bytes_create expects bytes")
        data = bytes(data)
        max_bytes = int(os.getenv("WORKER_MAX_OUTPUT_FILE_BYTES", str(200 * 1024 * 1024)))
        if max_bytes > 0 and len(data) > max_bytes:
            raise ValueError("output file too large")
        ref = ref.strip().lstrip("/")
        if not ref.startswith(_default_output_prefix(self.run_id)):
            raise ValueError(f"ref must start with '{_default_output_prefix(self.run_id)}'")

        base = _require_file_api_base_url()
        token = _require_file_api_token()
        url = f"{base}/api/v1/file/{_encode_ref_for_url(ref)}"
        req = _http_request("POST", url, token, body=data, content_type="application/octet-stream")
        try:
            with urllib.request.urlopen(req, timeout=30) as resp:
                body = resp.read()
                if resp.status < 200 or resp.status >= 300:
                    raise RuntimeError(f"file save failed ({resp.status})")
                try:
                    meta = json.loads(body.decode("utf-8"))
                except Exception:
                    meta = {}
        except urllib.error.HTTPError as e:
            if getattr(e, "code", None) == 409:
                raise RuntimeError("output path already exists") from e
            raise RuntimeError(f"file save failed ({getattr(e, 'code', 'unknown')})") from e

        return Asset(
            ref=ref,
            tenant_id=self.tenant_id,
            local_path=None,
            mime_type=str(meta.get("mime_type") or "") or None,
            size_bytes=int(meta.get("size_bytes") or 0) or len(data),
            sha256=str(meta.get("sha256") or "") or None,
        )

    def save_file_create(self, ref: str, local_path: str) -> Asset:
        with open(local_path, "rb") as f:
            data = f.read()
        return self.save_bytes_create(ref, data)

    def save_bytes_overwrite(self, ref: str, data: bytes) -> Asset:
        # Back-compat alias: overwrite is the default save_bytes behavior.
        return self.save_bytes(ref, data)

    def save_file_overwrite(self, ref: str, local_path: str) -> Asset:
        return self.save_file(ref, local_path)

# Define the interceptor class correctly
class _AuthInterceptor(grpc.StreamStreamClientInterceptor):
    def __init__(self, token: str) -> None:
        self._token = token

    def intercept_stream_stream(self, continuation: Any, client_call_details: Any, request_iterator: Any) -> Any:
        metadata = list(client_call_details.metadata or [])
        metadata.append(('authorization', f'Bearer {self._token}'))
        new_details = client_call_details._replace(metadata=metadata)
        return continuation(new_details, request_iterator)
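
# Editor's note (illustrative, not part of worker.py): the Worker class below is
# presumably constructed and started by gen_worker/entrypoint.py, which is not shown
# in this section. A minimal construction, using only parameters visible here, would be:
#
#     worker = Worker(scheduler_addr="localhost:8080", user_module_names=["functions"])
#     worker.run()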

class Worker:
    """Worker implementation that connects to the scheduler via gRPC."""

    def __init__(
        self,
        scheduler_addr: str = "localhost:8080",
        scheduler_addrs: Optional[List[str]] = None,
        user_module_names: List[str] = ["functions"],  # Add new parameter for user modules
        worker_id: Optional[str] = None,
        auth_token: Optional[str] = None,
        use_tls: bool = False,
        reconnect_delay: int = 5,
        max_reconnect_attempts: int = 0,  # 0 means infinite retries
        model_manager: Optional[ModelManagementInterface] = None,  # Optional model manager
        downloader: Optional[ModelDownloader] = None,  # Optional model downloader
    ) -> None:
        """Initialize a new worker.

        Args:
            scheduler_addr: Address of the scheduler service.
            scheduler_addrs: Optional list of seed scheduler addresses.
            user_module_names: List of Python module names containing user-defined @worker_function functions.
            worker_id: Unique ID for this worker (generated if not provided).
            auth_token: Optional authentication token.
            use_tls: Whether to use TLS for the connection.
            reconnect_delay: Seconds to wait between reconnection attempts.
            max_reconnect_attempts: Max reconnect attempts (0 = infinite).
            model_manager: Optional model manager.
            downloader: Optional model downloader.
        """
        self.scheduler_addr = scheduler_addr
        self.scheduler_addrs = self._normalize_scheduler_addrs(scheduler_addr, scheduler_addrs)
        self.user_module_names = user_module_names  # Store module names
        self.worker_id = worker_id or f"py-worker-{os.getpid()}"
        self.auth_token = auth_token
        self.use_tls = use_tls
        self.reconnect_delay = reconnect_delay
        self.max_reconnect_attempts = max_reconnect_attempts
        self.max_input_bytes = int(os.getenv("WORKER_MAX_INPUT_BYTES", "0"))
        self.max_output_bytes = int(os.getenv("WORKER_MAX_OUTPUT_BYTES", "0"))

        self._jwks_url = os.getenv("SCHEDULER_JWKS_URL", "").strip()
        self._jwks_ttl_seconds = int(os.getenv("SCHEDULER_JWKS_TTL_SECONDS", "300"))
        self._jwt_issuer = os.getenv("SCHEDULER_JWT_ISSUER", "").strip()
        self._jwt_audience = os.getenv("SCHEDULER_JWT_AUDIENCE", "").strip()
        self._jwks_cache: Optional[_JWKSCache] = _JWKSCache(self._jwks_url, self._jwks_ttl_seconds) if self._jwks_url else None

        self.deployment_id = os.getenv("DEPLOYMENT_ID", "")  # Read DEPLOYMENT_ID env var
        if not self.deployment_id:
            logger.warning("DEPLOYMENT_ID environment variable not set for this worker!")

        self.tenant_id = os.getenv("TENANT_ID", "")
        self.runpod_pod_id = os.getenv("RUNPOD_POD_ID", "")  # Read injected pod ID
        if not self.runpod_pod_id:
            logger.warning("RUNPOD_POD_ID environment variable not set for this worker!")

        logger.info(f"RUNPOD_POD_ID: {self.runpod_pod_id}")

        self._actions: Dict[str, Callable[[ActionContext, Optional[Any], bytes], bytes]] = {}
        self._active_tasks: Dict[str, ActionContext] = {}
        self._active_tasks_lock = threading.Lock()
        self._active_function_counts: Dict[str, int] = {}
        self.max_concurrency = int(os.getenv("WORKER_MAX_CONCURRENCY", "0"))
        self._drain_timeout_seconds = int(os.getenv("WORKER_DRAIN_TIMEOUT_SECONDS", "0"))
        self._draining = False
        self._discovered_resources: Dict[str, ResourceRequirements] = {}  # Store resources per function
        self._function_schemas: Dict[str, Tuple[bytes, bytes]] = {}  # func_name -> (input_schema_json, output_schema_json)

        self._gpu_busy_lock = threading.Lock()
        self._is_gpu_busy = False

        self._channel: Optional[Any] = None
        self._stub: Optional[Any] = None
        self._stream: Optional[Any] = None
        self._running = False
        self._stop_event = threading.Event()
        self._reconnect_count = 0
        self._outgoing_queue: queue.Queue[Any] = queue.Queue()
        self._leader_hint: Optional[str] = None

        self._receive_thread: Optional[threading.Thread] = None
        self._heartbeat_thread: Optional[threading.Thread] = None

        self._reconnect_delay_base = max(0, reconnect_delay)
        self._reconnect_delay_max = int(os.getenv("RECONNECT_MAX_DELAY", "60"))
        self._reconnect_jitter = float(os.getenv("RECONNECT_JITTER_SECONDS", "1.0"))

        resolved_model_manager = model_manager
        if resolved_model_manager is None:
            model_manager_path = os.getenv("MODEL_MANAGER_CLASS", "").strip()
            if model_manager_path:
                try:
                    module_path, _, class_name = model_manager_path.partition(":")
                    if not module_path or not class_name:
                        raise ValueError("MODEL_MANAGER_CLASS must be in module:Class format")
                    module = importlib.import_module(module_path)
                    manager_cls = getattr(module, class_name)
                    resolved_model_manager = manager_cls()
                    logger.info(f"Loaded ModelManager from MODEL_MANAGER_CLASS={model_manager_path}")
                except Exception as e:
                    logger.exception(f"Failed to load MODEL_MANAGER_CLASS '{model_manager_path}': {e}")
        self._model_manager = resolved_model_manager
        self._downloader = downloader
        if self._downloader is None:
            base_url = os.getenv("COZY_HUB_URL", "").strip()
            token = os.getenv("COZY_HUB_TOKEN", "").strip() or None
            if base_url:
                self._downloader = CozyHubDownloader(base_url, token=token)
        self._supported_model_ids_from_scheduler: Optional[List[str]] = None  # To store IDs from scheduler
        self._model_init_done_event = threading.Event()  # To signal model init is complete

        if self._model_manager:
            logger.info(f"ModelManager of type '{type(self._model_manager).__name__}' provided.")
        else:
            logger.info("No ModelManager provided. Worker operating in simple mode regarding models.")
            self._model_init_done_event.set()  # No model init to wait for if no manager
        if self._downloader:
            logger.info(f"ModelDownloader of type '{type(self._downloader).__name__}' configured.")

        logger.info(f"Created worker: ID={self.worker_id}, DeploymentID={self.deployment_id or 'N/A'}, Scheduler={scheduler_addr}")

        # Discover functions before setting signals? Maybe after. Let's do it here.
        self._discover_and_register_functions()

        self._verify_auth_token()

        signal.signal(signal.SIGINT, self._handle_interrupt)
        signal.signal(signal.SIGTERM, self._handle_interrupt)

    @staticmethod
    def _normalize_scheduler_addrs(primary: str, addrs: Optional[List[str]]) -> List[str]:
        unique: List[str] = []
        for addr in [primary] + (addrs or []):
            addr = (addr or "").strip()
            if addr and addr not in unique:
                unique.append(addr)
        return unique

    @staticmethod
    def _extract_leader_addr(details: Optional[str]) -> Optional[str]:
        if not details:
            return None
        if details.startswith("not_leader:"):
            leader = details.split("not_leader:", 1)[1].strip()
            return leader or None
        return None

    def _set_scheduler_addr(self, addr: str) -> None:
        addr = addr.strip()
        if not addr:
            return
        self.scheduler_addr = addr
        if addr not in self.scheduler_addrs:
            self.scheduler_addrs.insert(0, addr)

    def _iter_scheduler_addrs(self) -> Iterator[str]:
        seen = set()
        for addr in self.scheduler_addrs:
            addr = addr.strip()
            if not addr or addr in seen:
                continue
            seen.add(addr)
            yield addr

    def _verify_auth_token(self) -> None:
        if not self.auth_token or not self._jwks_cache:
            return
        try:
            header = jwt.get_unverified_header(self.auth_token)
            kid = header.get("kid")
            key = self._jwks_cache.get_key(kid)
            if not key:
                raise ValueError("JWKS key not found for token")
            options = {"verify_aud": bool(self._jwt_audience)}
            jwt.decode(
                self.auth_token,
                key=key,
                algorithms=["RS256"],
                audience=self._jwt_audience or None,
                issuer=self._jwt_issuer or None,
                options=options,
            )
            logger.info("Worker auth token verified against scheduler JWKS.")
        except Exception as e:
            logger.error(f"Worker auth token verification failed: {e}")
            raise

    def _format_error(self, message: str, retryable: bool) -> str:
        return json.dumps({
            "message": message,
            "retryable": retryable,
        })

    def _emit_progress_event(self, event: Dict[str, Any]) -> None:
        try:
            run_id = event.get("run_id") or ""
            event_type = event.get("type") or ""
            payload = event.get("payload") or {}
            if "timestamp" not in payload:
                payload = dict(payload)
                payload["timestamp"] = event.get("timestamp", time.time())
            payload_json = json.dumps(payload).encode("utf-8")
            msg = pb.WorkerSchedulerMessage(
                worker_event=pb.WorkerEvent(
                    run_id=run_id,
                    event_type=event_type,
                    payload_json=payload_json,
                )
            )
            self._send_message(msg)
        except Exception:
            logger.exception("Failed to emit progress event")


    def _set_gpu_busy_status(self, busy: bool, func_name_for_log: str = "") -> None:
        with self._gpu_busy_lock:
            if self._is_gpu_busy == busy:
                return
            self._is_gpu_busy = busy
            if func_name_for_log:
                logger.info(f"GPU status changed to {busy} due to function '{func_name_for_log}'.")
            else:
                logger.info(f"GPU status changed to {busy}.")


    def _get_gpu_busy_status(self) -> bool:
        with self._gpu_busy_lock:
            return self._is_gpu_busy


    def _discover_and_register_functions(self) -> None:
        """Discover and register functions marked with @worker_function."""
        logger.info(f"Discovering worker functions in modules: {self.user_module_names}...")
        discovered_count = 0
        for module_name in self.user_module_names:
            try:
                module = importlib.import_module(module_name)
                logger.debug(f"Inspecting module: {module_name}")
                for name, obj in inspect.getmembers(module):
                    if inspect.isfunction(obj) and hasattr(obj, '_is_worker_function'):
                        if getattr(obj, '_is_worker_function') is True:
                            # Found a decorated function
                            original_func = obj  # Keep reference to the actual decorated function
                            func_name = original_func.__name__  # Use the real function name

                            if func_name in self._actions:
                                logger.warning(f"Function '{func_name}' from module '{module_name}' conflicts with an already registered function. Skipping.")
                                continue

                            resources: ResourceRequirements = getattr(original_func, '_worker_resources', ResourceRequirements())
                            self._discovered_resources[func_name] = resources

                            expects_pipeline_flag = resources.expects_pipeline_arg
                            payload_type = self._infer_payload_type(original_func, expects_pipeline_flag)
                            return_type = self._infer_return_type(original_func)
                            if payload_type is None or return_type is None:
                                logger.error(
                                    "Skipping function '%s' due to invalid or missing payload type annotation.",
                                    func_name,
                                )
                                continue

                            try:
                                input_schema = msgspec.json.schema(payload_type)
                                output_schema = msgspec.json.schema(return_type)
                                self._function_schemas[func_name] = (
                                    json.dumps(input_schema, separators=(",", ":"), sort_keys=True).encode("utf-8"),
                                    json.dumps(output_schema, separators=(",", ":"), sort_keys=True).encode("utf-8"),
                                )
                            except Exception as exc:
                                logger.error("Failed to generate msgspec JSON schema for '%s': %s", func_name, exc)
                                continue

                            # Create the wrapper for gRPC/msgpack interaction
                            def create_wrapper(
                                captured_func: Callable[..., Any],
                                captured_name: str,
                                captured_payload_type: type[msgspec.Struct],
                                captured_return_type: type[msgspec.Struct],
                                func_expects_pipeline: bool = False,
                            ) -> Callable[[ActionContext, Optional[Any], bytes], bytes]:
                                @functools.wraps(captured_func)  # Preserve metadata of original user func
                                def wrapper(ctx: ActionContext, pipeline_instance: Optional[Any], input_bytes: bytes) -> bytes:
                                    try:
                                        input_obj = msgspec.msgpack.decode(input_bytes, type=captured_payload_type)
                                        self._materialize_assets(ctx.run_id, input_obj)
                                        # Pass the context and deserialized input to the *original* user function
                                        if func_expects_pipeline:  # Only pass pipeline if function expects it
                                            if pipeline_instance is None:
                                                err_msg = f"Function '{captured_name}' expected a pipeline argument, but None was provided by the Worker core."
                                                logger.error(err_msg)
                                                raise ValueError(err_msg)
                                            result = captured_func(ctx, pipeline_instance, input_obj)
                                        else:
                                            result = captured_func(ctx, input_obj)  # For functions not needing a model

                                        if ctx.is_canceled():
                                            raise InterruptedError("Task was canceled during execution")
                                        if not isinstance(result, captured_return_type):
                                            raise TypeError(
                                                f"Function {captured_name} returned {type(result)!r}, "
                                                f"expected {captured_return_type!r}"
                                            )
                                        # Ensure result is bytes after msgspec msgpack serialization
                                        packed_result = msgspec.msgpack.encode(result)
                                        if not isinstance(packed_result, bytes):
                                            raise TypeError(
                                                f"Function {captured_name} did not return msgspec-serializable data resulting in bytes"
                                            )
                                        return packed_result
                                    except InterruptedError as ie:  # Catch cancellation specifically
                                        logger.warning(f"Function {captured_name} run {ctx.run_id} was interrupted.")
                                        raise  # Re-raise to be handled in _execute_function
                                    except Exception as e:
                                        logger.exception(f"Error during execution of function {captured_name} (run_id: {ctx.run_id})")
                                        raise  # Re-raise to be caught by _execute_function
                                return wrapper

                            self._actions[func_name] = create_wrapper(
                                original_func,
                                func_name,
                                payload_type,
                                return_type,
                                func_expects_pipeline=expects_pipeline_flag,
                            )
                            logger.info(f"Registered function: '{func_name}' from module '{module_name}' with resources: {resources}")
                            discovered_count += 1

            except ImportError:
                logger.error(f"Could not import user module: {module_name}")
            except Exception as e:
                logger.exception(f"Error during discovery in module {module_name}: {e}")

        if discovered_count == 0:
            logger.warning(f"No functions decorated with @worker_function found in specified modules: {self.user_module_names}")
        else:
            logger.info(f"Discovery complete. Found {discovered_count} worker functions.")

    def _infer_payload_type(
        self,
        func: Callable[..., Any],
        expects_pipeline: bool,
    ) -> Optional[type[msgspec.Struct]]:
        signature = inspect.signature(func)
        params = list(signature.parameters.values())
        expected_params = 3 if expects_pipeline else 2
        if len(params) != expected_params:
            logger.error(
                "Function '%s' has %d parameters but expected %d.",
                func.__name__,
                len(params),
                expected_params,
            )
            return None

        payload_param = params[-1]
        try:
            type_hints = typing.get_type_hints(func, globalns=func.__globals__)
        except Exception as exc:
            logger.error("Failed to resolve type hints for '%s': %s", func.__name__, exc)
            return None

        payload_type = type_hints.get(payload_param.name)
        if payload_type is None:
            logger.error("Function '%s' is missing a payload type annotation.", func.__name__)
            return None

        if not isinstance(payload_type, type) or not issubclass(payload_type, msgspec.Struct):
            logger.error(
                "Function '%s' payload type must be a msgspec.Struct, got %r.",
                func.__name__,
                payload_type,
            )
            return None

        return payload_type

    def _infer_return_type(self, func: Callable[..., Any]) -> Optional[type[msgspec.Struct]]:
        try:
            type_hints = typing.get_type_hints(func, globalns=func.__globals__)
        except Exception as exc:
            logger.error("Failed to resolve return type hints for '%s': %s", func.__name__, exc)
            return None

        return_type = type_hints.get("return")
        if return_type is None:
            logger.error("Function '%s' is missing a return type annotation.", func.__name__)
            return None

        if not isinstance(return_type, type) or not issubclass(return_type, msgspec.Struct):
            logger.error(
                "Function '%s' return type must be a msgspec.Struct, got %r.",
                func.__name__,
                return_type,
            )
            return None

        return return_type

    def _send_message(self, message: WorkerSchedulerMessage) -> None:
        """Add a message to the outgoing queue."""
        if self._running and not self._stop_event.is_set():
            try:
                self._outgoing_queue.put_nowait(message)
            except queue.Full:
                logger.error("Outgoing message queue is full. Message dropped!")
        else:
            logger.warning("Attempted to send message while worker is stopping or stopped.")

    def _materialize_assets(self, run_id: str, obj: Any) -> None:
        if isinstance(obj, Asset):
            self._materialize_asset(run_id, obj)
            return
        if isinstance(obj, list):
            for it in obj:
                self._materialize_assets(run_id, it)
            return
        if isinstance(obj, dict):
            for it in obj.values():
                self._materialize_assets(run_id, it)
            return
        fields = getattr(obj, "__struct_fields__", None)
        if fields and isinstance(fields, (tuple, list)):
            for name in fields:
                try:
                    self._materialize_assets(run_id, getattr(obj, name))
                except Exception:
                    continue

    def _materialize_asset(self, run_id: str, asset: Asset) -> None:
        if asset.local_path:
            return
        ref = (asset.ref or "").strip()
        if not ref:
            return

        base_dir = os.getenv("WORKER_RUN_DIR", "/tmp/cozy").rstrip("/")
        local_inputs_dir = os.path.join(base_dir, run_id, "inputs")
        os.makedirs(local_inputs_dir, exist_ok=True)
        cache_dir = os.getenv("WORKER_CACHE_DIR", os.path.join(base_dir, "cache")).rstrip("/")
        os.makedirs(cache_dir, exist_ok=True)

        max_bytes = int(os.getenv("WORKER_MAX_INPUT_FILE_BYTES", str(200 * 1024 * 1024)))

        # External URL inputs (download directly into the run folder).
        if ref.startswith("http://") or ref.startswith("https://"):
            if _url_is_blocked(ref):
                raise RuntimeError("input url blocked")
            ext = os.path.splitext(urllib.parse.urlparse(ref).path)[1] or os.path.splitext(ref)[1]
            name_hash = hashlib.sha256(ref.encode("utf-8")).hexdigest()[:32]
            local_path = os.path.join(local_inputs_dir, f"{name_hash}{ext}")
            size, sha256_hex, mime = self._download_url_to_file(ref, local_path, max_bytes)
            asset.local_path = local_path
            if not asset.tenant_id:
                asset.tenant_id = self.tenant_id
            asset.mime_type = mime
            asset.size_bytes = size
            asset.sha256 = sha256_hex
            return

        # Cozy Hub file ref (tenant scoped) - use orchestrator file API with HEAD+cache.
        base = _require_file_api_base_url()
        token = _require_file_api_token()
        url = f"{base}/api/v1/file/{_encode_ref_for_url(ref)}"

        head_req = _http_request("HEAD", url, token)
        with urllib.request.urlopen(head_req, timeout=10) as resp:
            if resp.status < 200 or resp.status >= 300:
                raise RuntimeError(f"failed to stat asset ({resp.status})")
            sha256_hex = (resp.headers.get("X-Cozy-SHA256") or "").strip()
            size_hdr = (resp.headers.get("X-Cozy-Size-Bytes") or "").strip()
            mime = (resp.headers.get("X-Cozy-Mime-Type") or "").strip()
        size = int(size_hdr) if size_hdr.isdigit() else 0
        if max_bytes > 0 and size > max_bytes:
            raise RuntimeError("input file too large")

        ext = os.path.splitext(ref)[1]
        if not ext and mime:
            guessed = {
                "image/png": ".png",
                "image/jpeg": ".jpg",
                "image/webp": ".webp",
                "image/gif": ".gif",
            }.get(mime)
            ext = guessed or ""

        if not sha256_hex:
            sha256_hex = hashlib.sha256(ref.encode("utf-8")).hexdigest()
        cache_name = f"{sha256_hex[:32]}{ext}"
        cache_path = os.path.join(cache_dir, cache_name)

        if not os.path.exists(cache_path):
            get_req = _http_request("GET", url, token)
            with urllib.request.urlopen(get_req, timeout=30) as resp:
                if resp.status < 200 or resp.status >= 300:
                    raise RuntimeError(f"failed to download asset ({resp.status})")
                _size, _sha = self._stream_to_file(resp, cache_path, max_bytes)
            if not size:
                size = _size
            if not sha256_hex:
                sha256_hex = _sha

        local_path = os.path.join(local_inputs_dir, cache_name)
        if not os.path.exists(local_path):
            try:
                os.link(cache_path, local_path)
            except Exception:
                try:
                    import shutil

                    shutil.copyfile(cache_path, local_path)
                except Exception:
                    local_path = cache_path

        if not mime:
            with open(local_path, "rb") as f:
                head = f.read(512)
            mime = _infer_mime_type(ref, head)

        asset.local_path = local_path
        if not asset.tenant_id:
            asset.tenant_id = self.tenant_id
        asset.mime_type = mime or None
        asset.size_bytes = size or None
        asset.sha256 = sha256_hex or None

    def _download_url_to_file(self, src: str, dst: str, max_bytes: int) -> Tuple[int, str, Optional[str]]:
        attempts = int(os.getenv("WORKER_DOWNLOAD_RETRIES", "3"))
        attempt = 0
        last_err: Optional[Exception] = None
        while attempt < max(1, attempts):
            attempt += 1
            try:
                client = urllib.request.build_opener()
                req = urllib.request.Request(src, method="GET")
                with client.open(req, timeout=30) as resp:
                    size, sha = self._stream_to_file(resp, dst, max_bytes)
                with open(dst, "rb") as f:
                    head = f.read(512)
                mime = _infer_mime_type(src, head)
                return size, sha, mime
            except Exception as e:
                last_err = e
                if attempt >= max(1, attempts):
                    break
                sleep_s = min(10.0, 0.5 * (2 ** (attempt - 1))) + random.random() * 0.2
                time.sleep(sleep_s)
        raise RuntimeError(f"failed to download url: {last_err}")

    def _stream_to_file(self, src: Any, dst: str, max_bytes: int) -> Tuple[int, str]:
        tmp = f"{dst}.tmp-{os.getpid()}-{threading.get_ident()}-{random.randint(0, 1_000_000)}"
        total = 0
        h = hashlib.sha256()
        try:
            with open(tmp, "wb") as out:
                while True:
                    chunk = src.read(1024 * 1024)
                    if not chunk:
                        break
                    total += len(chunk)
                    if total > max_bytes:
                        raise RuntimeError("input file too large")
                    h.update(chunk)
                    out.write(chunk)
            os.replace(tmp, dst)
        finally:
            try:
                if os.path.exists(tmp):
                    os.remove(tmp)
            except Exception:
                pass
        return total, h.hexdigest()

    def connect(self) -> bool:
        """Connect to the scheduler.

        Returns:
            bool: True if connection was successful, False otherwise.
        """
        attempted: set[str] = set()
        while True:
            addr = None
            if self._leader_hint and self._leader_hint not in attempted:
                addr = self._leader_hint
                self._leader_hint = None
            else:
                for candidate in self._iter_scheduler_addrs():
                    if candidate not in attempted:
                        addr = candidate
                        break
            if not addr:
                break
            attempted.add(addr)
            self._set_scheduler_addr(addr)
            if self._connect_once():
                return True
        return False

    def _connect_once(self) -> bool:
        try:
            if self.use_tls:
                # TODO: Add proper credential loading if needed
                creds = grpc.ssl_channel_credentials()
                self._channel = grpc.secure_channel(self.scheduler_addr, creds)
            else:
                self._channel = grpc.insecure_channel(self.scheduler_addr)

            interceptors = []
            if self.auth_token:
                interceptors.append(_AuthInterceptor(self.auth_token))

            if interceptors:
                self._channel = grpc.intercept_channel(self._channel, *interceptors)

            self._stub = pb_grpc.SchedulerWorkerServiceStub(self._channel)

            # Start the bidirectional stream
            request_iterator = self._outgoing_message_iterator()
            self._stream = self._stub.ConnectWorker(request_iterator)

            logger.info(f"Attempting to connect to scheduler at {self.scheduler_addr}...")

            # Send initial registration immediately
            self._register_worker(is_heartbeat=False)

            # Start the receive loop in a separate thread *after* stream is initiated
            self._receive_thread = threading.Thread(target=self._receive_loop, daemon=True)
            self._receive_thread.start()

            # Start heartbeat thread
            self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
            self._heartbeat_thread.start()

            logger.info(f"Successfully connected to scheduler at {self.scheduler_addr}")
            self._reconnect_count = 0
            return True

        except grpc.RpcError as e:
            # Access code() and details() methods for RpcError
            code = e.code() if hasattr(e, 'code') and callable(e.code) else grpc.StatusCode.UNKNOWN
            details = e.details() if hasattr(e, 'details') and callable(e.details) else str(e)
            leader = self._extract_leader_addr(details)
            if code == grpc.StatusCode.FAILED_PRECONDITION and leader:
                logger.warning(f"Scheduler returned not_leader for {self.scheduler_addr}; redirecting to {leader}")
                self._leader_hint = leader
                self._set_scheduler_addr(leader)
            else:
                logger.error(f"Failed to connect to scheduler: {code} - {details}")
            self._close_connection()
            return False
        except Exception as e:
            logger.exception(f"Unexpected error connecting to scheduler: {e}")
            self._close_connection()
            return False

    def _outgoing_message_iterator(self) -> Iterator[WorkerSchedulerMessage]:
        """Yields messages from the outgoing queue to send to the scheduler."""
        while not self._stop_event.is_set():
            try:
                # Block for a short time to allow stopping gracefully
                message = self._outgoing_queue.get(timeout=0.1)
                yield message
                # self._outgoing_queue.task_done()  # Not needed if not joining queue
            except queue.Empty:
                continue
            except Exception as e:
                if not self._stop_event.is_set():
                    logger.exception(f"Error in outgoing message iterator: {e}")
                    self._handle_connection_error()
                break  # Exit iterator on error

    def _heartbeat_loop(self) -> None:
        """Periodically sends heartbeat messages."""
        while not self._stop_event.wait(HEARTBEAT_INTERVAL):
            try:
                self._register_worker(is_heartbeat=True)
                logger.debug("Sent heartbeat to scheduler")
            except Exception as e:
                if not self._stop_event.is_set():
                    logger.error(f"Error sending heartbeat: {e}")
                    self._handle_connection_error()
                break  # Stop heartbeating on error

    def _register_worker(self, is_heartbeat: bool = False) -> None:
        """Create and send a registration/heartbeat message."""
        try:
            mem = psutil.virtual_memory()
            cpu_cores = os.cpu_count() or 0

            gpu_count = 0
            gpu_total_mem = 0
            vram_models = []
            gpu_used_mem = 0
            gpu_free_mem = 0
            gpu_name = ""
            gpu_driver = ""

            if torch and torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                if gpu_count > 0:
                    try:
                        props = torch.cuda.get_device_properties(0)
                        gpu_total_mem = props.total_memory
                        gpu_used_mem = torch.cuda.memory_allocated(0)
                        gpu_name = props.name
                        gpu_driver = torch.version.cuda or ""
                        try:
                            free_mem, total_mem = torch.cuda.mem_get_info(0)
                            gpu_total_mem = total_mem
                            gpu_used_mem = total_mem - free_mem
                            gpu_free_mem = free_mem
                        except Exception:
                            pass
                        logger.debug(f"GPU: {props.name}, VRAM total={gpu_total_mem}, used={gpu_used_mem}, cuda={torch.version.cuda}")
                    except Exception as gpu_err:
                        logger.warning(f"Could not get GPU properties: {gpu_err}")

            fake_gpu_count = os.getenv("WORKER_FAKE_GPU_COUNT")
            if fake_gpu_count:
                try:
                    gpu_count = int(fake_gpu_count)
                    if gpu_count > 0:
                        fake_mem = int(os.getenv("WORKER_FAKE_GPU_MEMORY_BYTES", str(24 * 1024 * 1024 * 1024)))
                        gpu_total_mem = fake_mem
                        gpu_used_mem = 0
                        gpu_free_mem = fake_mem
                        gpu_name = os.getenv("WORKER_FAKE_GPU_NAME", "FakeGPU")
                        gpu_driver = os.getenv("WORKER_FAKE_GPU_DRIVER", "fake")
                except ValueError:
                    logger.warning("Invalid WORKER_FAKE_GPU_COUNT; ignoring fake GPU override.")

            supports_model_loading_flag = False
            # current_models = []
            if self._model_manager:
                vram_models = self._model_manager.get_vram_loaded_models()
                supports_model_loading_flag = True

            function_concurrency = {}
            for func_name, req in self._discovered_resources.items():
                if req and req.max_concurrency:
                    function_concurrency[func_name] = int(req.max_concurrency)

            cuda_version = os.getenv("WORKER_CUDA_VERSION", "").strip()
            torch_version = os.getenv("WORKER_TORCH_VERSION", "").strip()
            if torch is not None:
                if not torch_version:
                    torch_version = getattr(torch, "__version__", "") or ""
                if not cuda_version:
                    cuda_version = getattr(torch.version, "cuda", "") or ""
            if not cuda_version:
                cuda_version = os.getenv("CUDA_VERSION", "").strip() or os.getenv("NVIDIA_CUDA_VERSION", "").strip()

            function_schemas = []
            for fname, (in_schema, out_schema) in self._function_schemas.items():
                try:
                    function_schemas.append(
                        pb.FunctionSchema(
                            name=fname,
                            input_schema_json=in_schema,
                            output_schema_json=out_schema,
                        )
                    )
                except Exception:
                    continue

            resources = pb.WorkerResources(
                worker_id=self.worker_id,
                deployment_id=self.deployment_id,
                # tenant_id=self.tenant_id,
                runpod_pod_id=self.runpod_pod_id,
                gpu_is_busy=self._get_gpu_busy_status(),
                cpu_cores=cpu_cores,
                memory_bytes=mem.total,
                gpu_count=gpu_count,
                gpu_memory_bytes=gpu_total_mem,
                gpu_memory_used_bytes=gpu_used_mem,
                gpu_memory_free_bytes=gpu_free_mem,
                gpu_name=gpu_name,
                gpu_driver=gpu_driver,
                max_concurrency=self.max_concurrency,
                function_concurrency=function_concurrency,
                cuda_version=cuda_version,
                torch_version=torch_version,
                available_functions=list(self._actions.keys()),
                available_models=vram_models,
                supports_model_loading=supports_model_loading_flag,
                function_schemas=function_schemas,
            )
            registration = pb.WorkerRegistration(
                resources=resources,
                is_heartbeat=is_heartbeat
            )
            message = pb.WorkerSchedulerMessage(worker_registration=registration)
            # logger.info(f"DEBUG: Preparing to send registration. Resource object: {resources}")
            # logger.info(f"DEBUG: Value being sent for runpod_pod_id: '{resources.runpod_pod_id}'")
            self._send_message(message)
        except Exception as e:
            logger.error(f"Failed to create or send registration/heartbeat: {e}")

    def run(self) -> None:
        """Run the worker, connecting to the scheduler and processing tasks."""
        if self._running:
            logger.warning("Worker is already running")
            return

        self._running = True
        self._stop_event.clear()
        self._reconnect_count = 0  # Reset reconnect count on new run
        self._draining = False

        while self._running and not self._stop_event.is_set():
            self._reconnect_count += 1
            logger.info(f"Connection attempt {self._reconnect_count}...")
            if self.connect():
                # Successfully connected, wait for stop signal or disconnection
                logger.info("Connection successful. Worker running.")
                self._stop_event.wait()  # Wait here until stopped or disconnected
                logger.info("Worker run loop received stop/disconnect signal.")
                # If stopped normally (self.stop() called), _running will be False
                # If disconnected, connect() failed, threads stopped, _handle_connection_error called _stop_event.set()
            else:
                # Connection failed
                if self.max_reconnect_attempts > 0 and self._reconnect_count >= self.max_reconnect_attempts:
                    logger.error("Failed to connect after maximum attempts. Stopping worker.")
                    self._running = False  # Ensure loop terminates
                    break

            if self._running and not self._stop_event.is_set():
                backoff = self._reconnect_delay_base * (2 ** max(self._reconnect_count - 1, 0))
                if self._reconnect_delay_max > 0:
                    backoff = min(backoff, self._reconnect_delay_max)
                jitter = random.uniform(0, self._reconnect_jitter) if self._reconnect_jitter > 0 else 0
                delay = backoff + jitter
                logger.info(f"Connection attempt {self._reconnect_count} failed. Retrying in {delay:.2f} seconds...")
                # Wait for delay, but break if stop event is set during wait
                if self._stop_event.wait(delay):
                    logger.info("Stop requested during reconnect delay.")
                    break  # Exit if stopped while waiting
            # After a failed attempt or disconnect, clear stop event for next retry
            if self._running:
                self._stop_event.clear()

        # Cleanup after loop exits (either max attempts reached or manual stop)
        self.stop()
|
|
1257
|
+
|
|
1258
|
+
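    # NOTE (editor): illustrative sketch, not part of this package. The retry delay
    # in run() is capped exponential backoff plus uniform jitter, i.e.
    # delay = min(base * 2**(attempt - 1), cap) + uniform(0, jitter). A standalone
    # version of the same formula (names here are hypothetical) might look like:
    #
    #   import random
    #
    #   def backoff_delay(attempt: int, base: float, cap: float, jitter: float) -> float:
    #       delay = base * (2 ** max(attempt - 1, 0))
    #       if cap > 0:
    #           delay = min(delay, cap)
    #       return delay + (random.uniform(0, jitter) if jitter > 0 else 0.0)
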
    def _handle_interrupt(self, sig: int, frame: Optional[Any]) -> None:
        """Handle interrupt signal (Ctrl+C)."""
        logger.info(f"Received signal {sig}, shutting down gracefully.")
        self.stop()

    def stop(self) -> None:
        """Stop the worker and clean up resources."""
        if not self._running and not self._stop_event.is_set():  # Check if already stopped or stopping
            # Avoid multiple stop calls piling up
            # logger.debug("Stop called but worker already stopped or stopping.")
            return

        logger.info("Stopping worker...")
        self._draining = True
        self._running = False  # Signal loops to stop
        self._stop_event.set()  # Wake up any waiting threads

        # Cancel any active tasks
        active_task_ids = []
        if self._drain_timeout_seconds > 0:
            deadline = time.time() + self._drain_timeout_seconds
            while time.time() < deadline:
                with self._active_tasks_lock:
                    remaining = len(self._active_tasks)
                if remaining == 0:
                    break
                time.sleep(0.2)

        with self._active_tasks_lock:
            active_task_ids = list(self._active_tasks.keys())
            for run_id in active_task_ids:
                ctx = self._active_tasks.get(run_id)
                if ctx:
                    logger.debug(f"Cancelling active task {run_id} during stop.")
                    ctx.cancel()
            # Don't clear here, allow _execute_function to finish and remove

        # Wait for threads (give them a chance to finish)
        # Stop heartbeat first
        if self._heartbeat_thread and self._heartbeat_thread.is_alive():
            logger.debug("Joining heartbeat thread...")
            self._heartbeat_thread.join(timeout=1.0)

        # The outgoing iterator might be blocked on queue.get, stop_event wakes it

        # Close the gRPC connection (this might interrupt the receive loop)
        self._close_connection()

        # Wait for receive thread
        if self._receive_thread and self._receive_thread.is_alive():
            logger.debug("Joining receive thread...")
            self._receive_thread.join(timeout=2.0)

        # Clear outgoing queue after threads are stopped
        logger.debug("Clearing outgoing message queue...")
        while not self._outgoing_queue.empty():
            try:
                self._outgoing_queue.get_nowait()
            except queue.Empty:
                break

        logger.info("Worker stopped.")
        # Reset stop event in case run() is called again
        self._stop_event.clear()

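    # NOTE (editor): illustrative sketch, not part of this package. stop() drains by
    # polling the active-task count until a deadline and then cancelling whatever is
    # still running. The same pattern in isolation (hypothetical names) would be:
    #
    #   import time, threading
    #
    #   def drain(tasks: dict, lock: threading.Lock, timeout: float) -> None:
    #       deadline = time.time() + timeout
    #       while time.time() < deadline:
    #           with lock:
    #               if not tasks:
    #                   return
    #           time.sleep(0.2)
    #       with lock:
    #           for ctx in tasks.values():
    #               ctx.cancel()  # cooperative: the task observes the flag and exits
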
    def _close_connection(self) -> None:
        """Close the gRPC channel and reset state."""
        if self._stream:
            try:
                # Attempt to cancel the stream from the client side
                # This might help the server side release resources quicker
                # Note: Behavior might vary depending on server implementation
                if hasattr(self._stream, 'cancel') and callable(self._stream.cancel):
                    self._stream.cancel()
                    logger.debug("gRPC stream cancelled.")
            except Exception as e:
                logger.warning(f"Error cancelling gRPC stream: {e}")
            self._stream = None

        if self._channel:
            try:
                self._channel.close()
                logger.debug("gRPC channel closed.")
            except Exception as e:
                logger.error(f"Error closing gRPC channel: {e}")
            self._channel = None
            self._stub = None

    def _receive_loop(self) -> None:
        """Loop to receive messages from the scheduler via the stream."""
        logger.info("Receive loop started.")
        try:
            if not self._stream:
                logger.error("Receive loop started without a valid stream.")
                # Don't call _handle_connection_error here, connect should have failed
                return

            for message in self._stream:
                # Check stop event *before* processing
                if self._stop_event.is_set():
                    logger.debug("Stop event set during iteration, exiting receive loop.")
                    break
                try:
                    self._process_message(message)
                except Exception as e:
                    # Log errors processing individual messages but continue loop
                    logger.exception(f"Error processing message: {e}")

        except grpc.RpcError as e:
            # RpcError indicates a problem with the gRPC connection itself
            code = e.code() if hasattr(e, 'code') and callable(e.code) else grpc.StatusCode.UNKNOWN
            details = e.details() if hasattr(e, 'details') and callable(e.details) else str(e)

            if self._stop_event.is_set():
                # If stopping, cancellation is expected
                if code == grpc.StatusCode.CANCELLED:
                    logger.info("gRPC stream cancelled gracefully during shutdown.")
                else:
                    logger.warning(f"gRPC error during shutdown: {code} - {details}")
            elif code == grpc.StatusCode.FAILED_PRECONDITION:
                leader = self._extract_leader_addr(details)
                if leader:
                    logger.warning(f"Scheduler redirect received; reconnecting to leader at {leader}")
                    self._leader_hint = leader
                    self._set_scheduler_addr(leader)
                self._handle_connection_error()
            elif code == grpc.StatusCode.CANCELLED:
                logger.warning("gRPC stream unexpectedly cancelled by server or network.")
                self._handle_connection_error()
            elif code in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED, grpc.StatusCode.INTERNAL):
                logger.warning(f"gRPC connection lost ({code}). Attempting reconnect.")
                self._handle_connection_error()
            else:
                logger.error(f"Unhandled gRPC error in receive loop: {code} - {details}")
                self._handle_connection_error()  # Attempt reconnect on unknown errors too
        except Exception as e:
            # Catch-all for non-gRPC errors in the loop
            if not self._stop_event.is_set():
                logger.exception(f"Unexpected error in receive loop: {e}")
                self._handle_connection_error()  # Attempt reconnect
        finally:
            logger.info("Receive loop finished.")

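    # NOTE (editor): illustrative sketch, not part of this package. _receive_loop maps
    # gRPC status codes to a reconnect decision: CANCELLED during shutdown is expected,
    # FAILED_PRECONDITION carries a leader redirect, and UNAVAILABLE / DEADLINE_EXCEEDED /
    # INTERNAL are treated as transient. A minimal classifier of that policy
    # (hypothetical helper name) could read:
    #
    #   import grpc
    #
    #   TRANSIENT = {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED,
    #                grpc.StatusCode.INTERNAL, grpc.StatusCode.CANCELLED}
    #
    #   def should_reconnect(code: grpc.StatusCode, shutting_down: bool) -> bool:
    #       if shutting_down:
    #           return False
    #       return code in TRANSIENT or code == grpc.StatusCode.FAILED_PRECONDITION
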
    def _handle_connection_error(self) -> None:
        """Handles actions needed when a connection error occurs during run."""
        if self._running and not self._stop_event.is_set():
            logger.warning("Connection error detected. Signaling main loop to reconnect...")
            self._close_connection()  # Ensure resources are closed before reconnect attempt
            self._stop_event.set()  # Signal run loop to attempt reconnection
        # else:  # Already stopping or stopped
        #     logger.debug("Connection error detected but worker is already stopping.")

    def _process_message(self, message: WorkerSchedulerMessage) -> None:
        """Process a single message received from the scheduler."""
        msg_type = message.WhichOneof('msg')
        # logger.debug(f"Received message of type: {msg_type}")

        if msg_type == 'run_request':
            self._handle_run_request(message.run_request)
        elif msg_type == 'load_model_cmd':
            self._handle_load_model_cmd(message.load_model_cmd)
        elif msg_type == 'unload_model_cmd':
            # TODO: Implement model unloading logic
            model_id = message.unload_model_cmd.model_id
            logger.warning(f"Received unload_model_cmd for {model_id}, but not yet implemented.")
            result = pb.UnloadModelResult(model_id=model_id, success=False, error_message="Model unloading not implemented")
            self._send_message(pb.WorkerSchedulerMessage(unload_model_result=result))
        elif msg_type == 'interrupt_run_cmd':
            run_id = message.interrupt_run_cmd.run_id
            self._handle_interrupt_request(run_id)
        # Add handling for other message types if needed (e.g., config updates)
        elif msg_type == 'deployment_model_config':
            if self._model_manager:
                logger.info(f"Received DeploymentModelConfig: {message.deployment_model_config.supported_model_ids}")
                self._supported_model_ids_from_scheduler = list(message.deployment_model_config.supported_model_ids)
                self._model_init_done_event.clear()  # Clear before starting new init
                model_init_thread = threading.Thread(target=self._process_deployment_config_async_wrapper, daemon=True)
                model_init_thread.start()
            else:
                logger.info("Received DeploymentModelConfig, but no model manager configured. Ignoring.")
                self._model_init_done_event.set()  # Signal completion as there's nothing to do
        elif msg_type is None:
            logger.warning("Received empty message from scheduler.")
        else:
            logger.warning(f"Received unhandled message type: {msg_type}")

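    # NOTE (editor): illustrative sketch, not part of this package. _process_message
    # dispatches on the protobuf oneof via WhichOneof('msg'). A table-driven equivalent
    # of the same dispatch (the dict itself is hypothetical; the handlers are the
    # methods above) might look like:
    #
    #   handlers = {
    #       'run_request': lambda m: self._handle_run_request(m.run_request),
    #       'interrupt_run_cmd': lambda m: self._handle_interrupt_request(m.interrupt_run_cmd.run_id),
    #   }
    #   which = message.WhichOneof('msg')
    #   handler = handlers.get(which)
    #   if handler:
    #       handler(message)
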
    def _process_deployment_config_async_wrapper(self) -> None:
        if not self._model_manager or self._supported_model_ids_from_scheduler is None:
            self._model_init_done_event.set()
            return

        loop = None
        try:
            # Get or create an event loop for this thread
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            loop.run_until_complete(
                self._model_manager.process_supported_models_config(
                    self._supported_model_ids_from_scheduler,
                    self._downloader
                )
            )
            logger.info("Model configuration and downloads (if any) processed.")
        except Exception as e:
            logger.exception(f"Error during model_manager.process_supported_models_config: {e}")
        finally:
            if loop and not loop.is_running() and not loop.is_closed():  # Clean up loop if we created it
                loop.close()
            self._model_init_done_event.set()  # Signal completion or failure

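    # NOTE (editor): illustrative sketch, not part of this package. The wrapper above
    # runs an async model-manager call from a plain worker thread by creating a private
    # event loop. The general pattern (hypothetical helper name) is:
    #
    #   import asyncio
    #
    #   def run_coroutine_in_thread(coro):
    #       loop = asyncio.new_event_loop()
    #       try:
    #           asyncio.set_event_loop(loop)
    #           return loop.run_until_complete(coro)
    #       finally:
    #           loop.close()
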
    def _handle_load_model_cmd(self, cmd: LoadModelCommand) -> None:
        model_id = cmd.model_id
        logger.info(f"Received LoadModelCommand for: {model_id}")
        success = False
        error_msg = ""
        if not self._model_manager:
            error_msg = "LoadModelCommand: No model manager configured on worker."
            logger.error(error_msg)
        else:
            try:
                # Wait for initial model downloads if they haven't finished
                if not self._model_init_done_event.is_set():
                    logger.info(f"LoadModelCmd ({model_id}): Waiting for initial model setup...")
                    # Timeout for this wait, can be adjusted
                    if not self._model_init_done_event.wait(timeout=300.0):  # 5 minutes
                        raise TimeoutError("Timeout waiting for model initialization before VRAM load.")

                logger.info(f"Model Memory Manager attempting to load '{model_id}' into VRAM...")
                # load_model_into_vram is async
                success = asyncio.run(self._model_manager.load_model_into_vram(model_id))
                if success:
                    logger.info(f"Model '{model_id}' loaded to VRAM by Model Memory Manager.")
                else:
                    error_msg = f"MMM.load_model_into_vram failed for '{model_id}'."
                    logger.error(error_msg)
            except Exception as e:
                error_msg = f"Exception in mmm.load_model_into_vram for '{model_id}': {e}"
                logger.exception(error_msg)

        result = pb.LoadModelResult(model_id=model_id, success=success, error_message=error_msg)
        self._send_message(pb.WorkerSchedulerMessage(load_model_result=result))

    def _handle_run_request(self, request: TaskExecutionRequest) -> None:
        """Handle a task execution request from the scheduler."""
        run_id = request.run_id
        function_name = request.function_name
        input_payload = request.input_payload
        required_model_id_for_exec = ""
        timeout_ms = int(getattr(request, "timeout_ms", 0) or 0)
        tenant_id = str(getattr(request, "tenant_id", "") or "") or (self.tenant_id or "")
        user_id = str(getattr(request, "user_id", "") or "")

        if request.required_models and len(request.required_models) > 0:
            required_model_id_for_exec = request.required_models[0]

        logger.info(f"Received Task request: run_id={run_id}, function={function_name}, model='{required_model_id_for_exec or 'None'}'")

        func_wrapper = self._actions.get(function_name)
        if not func_wrapper:
            error_msg = f"Unknown function requested: {function_name}"
            logger.error(error_msg)
            self._send_task_result(run_id, False, None, error_msg)
            return
        if self.max_input_bytes > 0 and len(input_payload) > self.max_input_bytes:
            error_msg = f"Input payload too large: {len(input_payload)} bytes (max {self.max_input_bytes})"
            logger.error(error_msg)
            self._send_task_result(run_id, False, None, error_msg)
            return
        if self._draining:
            error_msg = "Worker is draining; refusing new tasks"
            logger.warning(error_msg)
            self._send_task_result(run_id, False, None, error_msg)
            return

        ctx = ActionContext(
            run_id,
            emitter=self._emit_progress_event,
            tenant_id=tenant_id or None,
            user_id=user_id or None,
            timeout_ms=timeout_ms if timeout_ms > 0 else None,
        )
        # Add to active tasks *before* starting thread
        with self._active_tasks_lock:
            # Double-check if task is already active (race condition mitigation)
            if run_id in self._active_tasks:
                error_msg = f"Task with run_id {run_id} is already active (race condition?)."
                logger.error(error_msg)
                return  # Avoid starting duplicate thread
            if self.max_concurrency > 0 and len(self._active_tasks) >= self.max_concurrency:
                error_msg = f"Worker concurrency limit reached ({self.max_concurrency})."
                logger.error(error_msg)
                self._send_task_result(run_id, False, None, error_msg)
                return
            resource_req = self._discovered_resources.get(function_name)
            func_limit = resource_req.max_concurrency if resource_req and resource_req.max_concurrency else 0
            if func_limit > 0 and self._active_function_counts.get(function_name, 0) >= func_limit:
                error_msg = f"Function concurrency limit reached for {function_name} ({func_limit})."
                logger.error(error_msg)
                self._send_task_result(run_id, False, None, error_msg)
                return
            self._active_tasks[run_id] = ctx
            if func_limit > 0:
                self._active_function_counts[function_name] = self._active_function_counts.get(function_name, 0) + 1

        # Execute function in a separate thread to avoid blocking the receive loop
        thread = threading.Thread(
            target=self._execute_function,
            args=(ctx, function_name, func_wrapper, input_payload, required_model_id_for_exec),
            daemon=True,
        )
        thread.start()

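    # NOTE (editor): illustrative sketch, not part of this package. _handle_run_request
    # enforces two admission limits under one lock: a global cap (max_concurrency) and a
    # per-function cap. The counting idiom in isolation (hypothetical names) is:
    #
    #   def try_admit(counts: dict, name: str, limit: int) -> bool:
    #       # caller holds the lock protecting `counts`
    #       if limit > 0 and counts.get(name, 0) >= limit:
    #           return False
    #       counts[name] = counts.get(name, 0) + 1
    #       return True
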
    def _handle_interrupt_request(self, run_id: str) -> None:
        """Handle a request to interrupt/cancel a running task."""
        logger.info(f"Received interrupt request for run_id={run_id}")
        with self._active_tasks_lock:
            ctx = self._active_tasks.get(run_id)
            if ctx:
                ctx.cancel()  # Set internal flag and event
            else:
                logger.warning(f"Could not interrupt task {run_id}: Not found in active tasks.")

    def _execute_function(
        self,
        ctx: ActionContext,
        function_name: str,
        func_to_execute: Callable[[ActionContext, Optional[Any], bytes], bytes],
        input_payload: bytes,
        required_model_id: str,
    ) -> None:
        """Execute the registered function and send the result/error back."""
        run_id = ctx.run_id
        output_payload: Optional[bytes] = None
        error_message: str = ""
        success = False

        # Determine if this function requires GPU and manage worker's GPU state
        func_requires_gpu = False
        func_expects_pipeline = False  # Default when no resource requirements are discovered
        resource_req = self._discovered_resources.get(function_name)
        if resource_req:
            func_requires_gpu = resource_req.requires_gpu
            func_expects_pipeline = resource_req.expects_pipeline_arg

        # Variable to track if this specific thread execution set the GPU busy
        this_thread_set_gpu_busy = False
        if func_requires_gpu:
            with self._gpu_busy_lock:  # Lock to check and set self._is_gpu_busy atomically
                if not self._is_gpu_busy:
                    self._is_gpu_busy = True
                    this_thread_set_gpu_busy = True
                    logger.info(f"Worker GPU marked as BUSY by task {run_id} ({function_name}).")
                else:
                    logger.warning(f"Task {run_id} ({function_name}) requires GPU, but worker GPU was already marked busy. Proceeding...")

        active_pipeline_instance = None  # To hold the pipeline for the user function
        try:
            if ctx.is_canceled():
                raise InterruptedError("Task cancelled before execution")

            if func_expects_pipeline:
                if not required_model_id and resource_req and resource_req.model_name:
                    required_model_id = str(resource_req.model_name)
                if not required_model_id:
                    raise ValueError(f"Function '{function_name}' expects a pipeline argument, but no model ID was provided.")

                if not self._model_manager:
                    raise RuntimeError(f"Function '{function_name}' expects a pipeline argument, but no model manager configured on worker.")

                if not self._model_init_done_event.is_set():
                    logger.info(f"Task {run_id} ({function_name}) waiting for initial model setup...")
                    if not self._model_init_done_event.wait(timeout=300.0):  # 5 min timeout
                        raise TimeoutError(f"Timeout waiting for model initialization for task {run_id}")
                    logger.info(f"Initial model setup complete. Proceeding for task {run_id}.")

                logger.info(f"Task {run_id} ({function_name}) getting active pipeline for model '{required_model_id}'...")
                # get_active_pipeline is async
                active_pipeline_instance = asyncio.run(self._model_manager.get_active_pipeline(required_model_id))
                if not active_pipeline_instance:
                    raise RuntimeError(f"ModelManager failed to provide active pipeline for '{required_model_id}' for task {run_id}.")

                logger.info(f"Task {run_id} ({function_name}) obtained pipeline for model '{required_model_id}'.")

            # Execute the function wrapper (which handles deserialization/serialization)
            output_payload = func_to_execute(ctx, active_pipeline_instance, input_payload)
            # Check for cancellation *during* execution (func should check ctx.is_canceled)
            if ctx.is_canceled():
                raise InterruptedError("Task was cancelled during execution")

            if output_payload is not None and self.max_output_bytes > 0:
                if len(output_payload) > self.max_output_bytes:
                    raise ValueError(f"Output payload too large: {len(output_payload)} bytes (max {self.max_output_bytes})")

            success = True
            logger.info(f"Task {run_id} completed successfully.")

        except InterruptedError as e:
            error_message = self._format_error(str(e) or "Task was canceled", retryable=False)
            logger.warning(f"Task {run_id} was canceled: {error_message}")
            success = False  # Explicitly set success to False on cancellation
        except RetryableError as e:
            error_message = self._format_error(f"{type(e).__name__}: {str(e)}", retryable=True)
            logger.error(f"Task {run_id} ({function_name}) retryable failure: {error_message}")
            success = False
        except FatalError as e:
            error_message = self._format_error(f"{type(e).__name__}: {str(e)}", retryable=False)
            logger.error(f"Task {run_id} ({function_name}) fatal failure: {error_message}")
            success = False
        except (ValueError, RuntimeError, TimeoutError) as ve_rte_to:  # Catch specific errors we raise
            retryable = isinstance(ve_rte_to, TimeoutError)
            error_message = self._format_error(f"{type(ve_rte_to).__name__}: {str(ve_rte_to)}", retryable=retryable)
            logger.error(f"Task {run_id} ({function_name}) failed pre-execution or during model acquisition: {error_message}")
            success = False
        except Exception as e:
            error_message = self._format_error(f"{type(e).__name__}: {str(e)}", retryable=False)
            logger.exception(f"Error executing function for run_id={run_id}: {error_message}")
            success = False
        finally:
            # Release the GPU if this thread set it busy
            if this_thread_set_gpu_busy:
                with self._gpu_busy_lock:  # Lock to set self._is_gpu_busy
                    self._is_gpu_busy = False
                logger.info(f"Worker GPU marked as NOT BUSY by task {run_id} ({function_name}).")

            # Always send a result back, regardless of success, failure, or cancellation
            self._send_task_result(run_id, success, output_payload, error_message)
            # Remove from active tasks *after* sending result
            with self._active_tasks_lock:
                if run_id in self._active_tasks:
                    del self._active_tasks[run_id]
                    resource_req = self._discovered_resources.get(function_name)
                    func_limit = resource_req.max_concurrency if resource_req and resource_req.max_concurrency else 0
                    if func_limit > 0:
                        current = self._active_function_counts.get(function_name, 0) - 1
                        if current <= 0:
                            self._active_function_counts.pop(function_name, None)
                        else:
                            self._active_function_counts[function_name] = current
                # else:  # Might have been removed by stop() already
                #     logger.warning(f"Task {run_id} not found in active tasks during cleanup.")

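    # NOTE (editor): illustrative sketch, not part of this package. Cancellation in
    # _execute_function is cooperative: the context is checked before and after the
    # user function, and long-running functions are expected to poll ctx.is_canceled()
    # themselves, e.g. (hypothetical user function):
    #
    #   def my_action(ctx, pipeline, payload: bytes) -> bytes:
    #       for step in range(100):
    #           if ctx.is_canceled():
    #               raise InterruptedError("cancelled mid-run")
    #           ...  # do one unit of work
    #       return b"done"
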
    def _send_task_result(self, run_id: str, success: bool, output_payload: Optional[bytes], error_message: str) -> None:
        """Send a task execution result back to the scheduler via the queue."""
        try:
            result = pb.TaskExecutionResult(
                run_id=run_id,
                success=success,
                output_payload=(output_payload or b'') if success else b'',  # Default to b'' if None
                error_message=error_message if not success else ""
            )
            msg = pb.WorkerSchedulerMessage(run_result=result)
            self._send_message(msg)
            logger.debug(f"Queued task result for run_id={run_id}, success={success}")
        except Exception as e:
            # This shouldn't generally fail unless message creation has issues
            logger.error(f"Failed to create or queue task result for run_id={run_id}: {e}")