gen_worker-0.1.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
gen_worker/worker.py ADDED
@@ -0,0 +1,1720 @@
+ import grpc
+ import logging
+ import time
+ import json
+ import urllib.request
+ import urllib.parse
+ import urllib.error
+ import random
+ import threading
+ import os
+ import signal
+ import queue
+ import psutil
+ import importlib
+ import inspect
+ import functools
+ import typing
+ import socket
+ import ipaddress
+ from typing import Any, Callable, Dict, Optional, TypeVar, Iterator, List, Tuple
+ from types import ModuleType
+ import hashlib
+ import msgspec
+ try:
+     import torch
+ except Exception:  # pragma: no cover - optional dependency
+     torch = None
+ import asyncio
+
+ # JWT verification for worker auth (scheduler-issued)
+ import jwt
+ _jwt_algorithms: Optional[ModuleType]
+ try:
+     import jwt.algorithms as _jwt_algorithms
+ except Exception:  # pragma: no cover - optional crypto backend
+     _jwt_algorithms = None
+ RSAAlgorithm: Optional[Any] = getattr(_jwt_algorithms, "RSAAlgorithm", None) if _jwt_algorithms else None
+ # Use relative imports within the package
+ from .pb import worker_scheduler_pb2 as _pb
+ from .pb import worker_scheduler_pb2_grpc as _pb_grpc
+
+ pb: Any = _pb
+ pb_grpc: Any = _pb_grpc
+
+ WorkerSchedulerMessage = Any
+ WorkerEvent = Any
+ WorkerResources = Any
+ WorkerRegistration = Any
+ LoadModelCommand = Any
+ LoadModelResult = Any
+ UnloadModelResult = Any
+ TaskExecutionRequest = Any
+ TaskExecutionResult = Any
+ from .decorators import ResourceRequirements  # Import ResourceRequirements for type hints if needed
+ from .errors import RetryableError, FatalError
+
+ from .model_interface import ModelManagementInterface
+ from .downloader import CozyHubDownloader, ModelDownloader
+ from .types import Asset
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)  # Use __name__ for logger
+
+ # Type variables for generic function signatures
+ I = TypeVar('I')  # Input type
+ O = TypeVar('O')  # Output type
+
+ # Generic type for action functions
+ ActionFunc = Callable[[Any, I], O]
+
+ HEARTBEAT_INTERVAL = 10  # seconds
+
+
+ def _encode_ref_for_url(ref: str) -> str:
+     ref = ref.strip().lstrip("/")
+     parts = [urllib.parse.quote(p, safe="") for p in ref.split("/") if p]
+     return "/".join(parts)
+
+
+ def _infer_mime_type(ref: str, head: bytes) -> str:
+     # Prefer magic bytes when available.
+     if head.startswith(b"\x89PNG\r\n\x1a\n"):
+         return "image/png"
+     if head.startswith(b"\xff\xd8\xff"):
+         return "image/jpeg"
+     if head.startswith(b"GIF87a") or head.startswith(b"GIF89a"):
+         return "image/gif"
+     if len(head) >= 12 and head[0:4] == b"RIFF" and head[8:12] == b"WEBP":
+         return "image/webp"
+
+     # Fall back to extension.
+     import mimetypes
+
+     guessed, _ = mimetypes.guess_type(ref)
+     return guessed or "application/octet-stream"
+
+
+ def _default_output_prefix(run_id: str) -> str:
+     return f"runs/{run_id}/outputs/"
+
+
+ def _require_file_api_base_url() -> str:
+     base = os.getenv("FILE_API_BASE_URL", "").strip()
+     if not base:
+         base = os.getenv("ORCHESTRATOR_HTTP_URL", "").strip()
+     if not base:
+         base = os.getenv("COZY_HUB_URL", "").strip()
+     if not base:
+         raise RuntimeError("FILE_API_BASE_URL is required for file operations")
+     return base.rstrip("/")
+
+
+ def _require_file_api_token() -> str:
+     token = os.getenv("FILE_API_TOKEN", "").strip()
+     if not token:
+         token = os.getenv("COZY_HUB_TOKEN", "").strip()
+     if not token:
+         raise RuntimeError("FILE_API_TOKEN is required for file operations")
+     return token
+
+
+ def _http_request(
+     method: str,
+     url: str,
+     token: str,
+     body: Optional[bytes] = None,
+     content_type: Optional[str] = None,
+ ) -> urllib.request.Request:
+     req = urllib.request.Request(url, data=body, method=method)
+     req.add_header("Authorization", f"Bearer {token}")
+     tenant_id = os.getenv("TENANT_ID", "").strip()
+     if tenant_id:
+         req.add_header("X-Cozy-Tenant-Id", tenant_id)
+     if content_type:
+         req.add_header("Content-Type", content_type)
+     return req
+
+
+ def _is_private_ip_str(ip_str: str) -> bool:
+     try:
+         ip = ipaddress.ip_address(ip_str)
+     except Exception:
+         return True
+     return bool(
+         ip.is_private
+         or ip.is_loopback
+         or ip.is_link_local
+         or ip.is_multicast
+         or ip.is_reserved
+         or ip.is_unspecified
+     )
+
+
+ def _url_is_blocked(url_str: str) -> bool:
+     try:
+         u = urllib.parse.urlparse(url_str)
+     except Exception:
+         return True
+     if u.scheme not in ("http", "https"):
+         return True
+     host = (u.hostname or "").strip()
+     if not host:
+         return True
+     try:
+         infos = socket.getaddrinfo(host, None)
+     except Exception:
+         return True
+     for info in infos:
+         sockaddr = info[4]
+         if not sockaddr:
+             continue
+         ip_str = str(sockaddr[0])
+         if _is_private_ip_str(ip_str):
+             return True
+     return False
+
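These two helpers form a fail-closed SSRF guard for externally supplied input URLs: anything that is not plain http(s), has no host, fails to parse or resolve, or resolves to a private, loopback, link-local, multicast, reserved, or unspecified address is refused. A standalone sketch of the same pattern (the probe URLs are illustrative):

    import ipaddress
    import socket
    import urllib.parse

    def url_is_blocked(url: str) -> bool:
        # Fail closed: refuse anything that is not resolvable, public http(s).
        parsed = urllib.parse.urlparse(url)
        if parsed.scheme not in ("http", "https") or not parsed.hostname:
            return True
        try:
            infos = socket.getaddrinfo(parsed.hostname, None)
        except OSError:
            return True
        for info in infos:
            ip = ipaddress.ip_address(info[4][0])
            if (ip.is_private or ip.is_loopback or ip.is_link_local
                    or ip.is_multicast or ip.is_reserved or ip.is_unspecified):
                return True
        return False

    print(url_is_blocked("http://127.0.0.1/admin"))  # True: loopback
    print(url_is_blocked("ftp://example.com/x"))     # True: scheme not allowed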
+ class _JWKSCache:
+     def __init__(self, url: str, ttl_seconds: int = 300) -> None:
+         self._url = url
+         self._ttl_seconds = max(ttl_seconds, 0)
+         self._lock = threading.Lock()
+         self._fetched_at = 0.0
+         self._keys: Dict[str, Any] = {}
+
+     def _fetch(self) -> None:
+         if RSAAlgorithm is None:
+             raise RuntimeError(
+                 "PyJWT RSA support is unavailable (missing cryptography). "
+                 "Install gen-worker with a JWT/RSA-capable build of PyJWT."
+             )
+         with urllib.request.urlopen(self._url, timeout=5) as resp:
+             body = resp.read()
+         payload = json.loads(body.decode("utf-8"))
+         keys: Dict[str, Any] = {}
+         for jwk in payload.get("keys", []):
+             kid = jwk.get("kid")
+             if not kid:
+                 continue
+             try:
+                 keys[kid] = RSAAlgorithm.from_jwk(json.dumps(jwk))
+             except Exception:
+                 continue
+         self._keys = keys
+         self._fetched_at = time.time()
+
+     def _needs_refresh(self) -> bool:
+         if not self._keys:
+             return True
+         if self._ttl_seconds <= 0:
+             return False
+         return (time.time() - self._fetched_at) > self._ttl_seconds
+
+     def get_key(self, kid: Optional[str]) -> Optional[Any]:
+         with self._lock:
+             if self._needs_refresh():
+                 self._fetch()
+             if kid and kid in self._keys:
+                 return self._keys[kid]
+             # refresh on miss (rotation)
+             self._fetch()
+             if kid and kid in self._keys:
+                 return self._keys[kid]
+             return None
+
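_JWKSCache fetches the scheduler's JWKS once, turns each JWK into an RSA public key via PyJWT's RSAAlgorithm.from_jwk, and refreshes when the TTL lapses or a kid is missing (which covers key rotation). A hedged usage sketch mirroring Worker._verify_auth_token below; the JWKS URL is a placeholder for what SCHEDULER_JWKS_URL would supply:

    def verify_scheduler_token(token: str) -> dict:
        cache = _JWKSCache("https://scheduler.example.com/.well-known/jwks.json")
        kid = jwt.get_unverified_header(token).get("kid")
        key = cache.get_key(kid)  # fetches, then refetches once on a kid miss
        if key is None:
            raise ValueError("no JWKS key found for token kid")
        return jwt.decode(token, key=key, algorithms=["RS256"])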
+ class ActionContext:
+     """Context object passed to action functions, allowing cancellation."""
+     def __init__(
+         self,
+         run_id: str,
+         emitter: Optional[Callable[[Dict[str, Any]], None]] = None,
+         tenant_id: Optional[str] = None,
+         user_id: Optional[str] = None,
+         timeout_ms: Optional[int] = None,
+     ) -> None:
+         self._run_id = run_id
+         self._tenant_id = tenant_id
+         self._user_id = user_id
+         self._timeout_ms = timeout_ms
+         self._started_at = time.time()
+         self._deadline: Optional[float] = None
+         if timeout_ms is not None and timeout_ms > 0:
+             self._deadline = self._started_at + (timeout_ms / 1000.0)
+         self._canceled = False
+         self._cancel_event = threading.Event()
+         self._emitter = emitter
+
+     @property
+     def run_id(self) -> str:
+         return self._run_id
+
+     @property
+     def tenant_id(self) -> Optional[str]:
+         return self._tenant_id
+
+     @property
+     def user_id(self) -> Optional[str]:
+         return self._user_id
+
+     @property
+     def timeout_ms(self) -> Optional[int]:
+         return self._timeout_ms
+
+     @property
+     def deadline(self) -> Optional[float]:
+         return self._deadline
+
+     def time_remaining_s(self) -> Optional[float]:
+         if self._deadline is None:
+             return None
+         return max(0.0, self._deadline - time.time())
+
+     def is_canceled(self) -> bool:
+         """Check if the action was canceled."""
+         return self._canceled
+
+     def cancel(self) -> None:
+         """Mark the action as canceled."""
+         if not self._canceled:
+             self._canceled = True
+             self._cancel_event.set()
+             logger.info(f"Action {self.run_id} marked for cancellation.")
+
+     def done(self) -> threading.Event:
+         """Returns an event that is set when the action is cancelled."""
+         return self._cancel_event
+
+     def emit(self, event_type: str, payload: Optional[Dict[str, Any]] = None) -> None:
+         """Emit a progress/event payload (best-effort)."""
+         if not self._emitter:
+             logger.debug(f"emit({event_type}) dropped: no emitter configured")
+             return
+         event = {
+             "run_id": self._run_id,
+             "type": event_type,
+             "payload": payload or {},
+             "timestamp": time.time(),
+         }
+         self._emitter(event)
+
+     def progress(self, progress: float, stage: Optional[str] = None) -> None:
+         payload: Dict[str, Any] = {"progress": progress}
+         if stage is not None:
+             payload["stage"] = stage
+         self.emit("job.progress", payload)
+
+     def log(self, message: str, level: str = "info") -> None:
+         self.emit("job.log", {"message": message, "level": level})
+
+     def save_bytes(self, ref: str, data: bytes) -> Asset:
+         if not isinstance(data, (bytes, bytearray)):
+             raise TypeError("save_bytes expects bytes")
+         data = bytes(data)
+         max_bytes = int(os.getenv("WORKER_MAX_OUTPUT_FILE_BYTES", str(200 * 1024 * 1024)))
+         if max_bytes > 0 and len(data) > max_bytes:
+             raise ValueError("output file too large")
+         ref = ref.strip().lstrip("/")
+         if not ref.startswith(_default_output_prefix(self.run_id)):
+             raise ValueError(f"ref must start with '{_default_output_prefix(self.run_id)}'")
+
+         base = _require_file_api_base_url()
+         token = _require_file_api_token()
+         url = f"{base}/api/v1/file/{_encode_ref_for_url(ref)}"
+         # Default behavior is upsert: PUT to the tenant file store.
+         req = _http_request("PUT", url, token, body=data, content_type="application/octet-stream")
+         try:
+             with urllib.request.urlopen(req, timeout=30) as resp:
+                 body = resp.read()
+                 if resp.status < 200 or resp.status >= 300:
+                     raise RuntimeError(f"file save failed ({resp.status})")
+             try:
+                 meta = json.loads(body.decode("utf-8"))
+             except Exception:
+                 meta = {}
+         except urllib.error.HTTPError as e:
+             raise RuntimeError(f"file save failed ({getattr(e, 'code', 'unknown')})") from e
+
+         return Asset(
+             ref=ref,
+             tenant_id=self.tenant_id,
+             local_path=None,
+             mime_type=str(meta.get("mime_type") or "") or None,
+             size_bytes=int(meta.get("size_bytes") or 0) or len(data),
+             sha256=str(meta.get("sha256") or "") or None,
+         )
+
+     def save_file(self, ref: str, local_path: str) -> Asset:
+         with open(local_path, "rb") as f:
+             data = f.read()
+         return self.save_bytes(ref, data)
+
+     def save_bytes_create(self, ref: str, data: bytes) -> Asset:
+         if not isinstance(data, (bytes, bytearray)):
+             raise TypeError("save_bytes_create expects bytes")
+         data = bytes(data)
+         max_bytes = int(os.getenv("WORKER_MAX_OUTPUT_FILE_BYTES", str(200 * 1024 * 1024)))
+         if max_bytes > 0 and len(data) > max_bytes:
+             raise ValueError("output file too large")
+         ref = ref.strip().lstrip("/")
+         if not ref.startswith(_default_output_prefix(self.run_id)):
+             raise ValueError(f"ref must start with '{_default_output_prefix(self.run_id)}'")
+
+         base = _require_file_api_base_url()
+         token = _require_file_api_token()
+         url = f"{base}/api/v1/file/{_encode_ref_for_url(ref)}"
+         req = _http_request("POST", url, token, body=data, content_type="application/octet-stream")
+         try:
+             with urllib.request.urlopen(req, timeout=30) as resp:
+                 body = resp.read()
+                 if resp.status < 200 or resp.status >= 300:
+                     raise RuntimeError(f"file save failed ({resp.status})")
+             try:
+                 meta = json.loads(body.decode("utf-8"))
+             except Exception:
+                 meta = {}
+         except urllib.error.HTTPError as e:
+             if getattr(e, "code", None) == 409:
+                 raise RuntimeError("output path already exists") from e
+             raise RuntimeError(f"file save failed ({getattr(e, 'code', 'unknown')})") from e
+
+         return Asset(
+             ref=ref,
+             tenant_id=self.tenant_id,
+             local_path=None,
+             mime_type=str(meta.get("mime_type") or "") or None,
+             size_bytes=int(meta.get("size_bytes") or 0) or len(data),
+             sha256=str(meta.get("sha256") or "") or None,
+         )
+
+     def save_file_create(self, ref: str, local_path: str) -> Asset:
+         with open(local_path, "rb") as f:
+             data = f.read()
+         return self.save_bytes_create(ref, data)
+
+     def save_bytes_overwrite(self, ref: str, data: bytes) -> Asset:
+         # Back-compat alias: overwrite is the default save_bytes behavior.
+         return self.save_bytes(ref, data)
+
+     def save_file_overwrite(self, ref: str, local_path: str) -> Asset:
+         return self.save_file(ref, local_path)
+
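Taken together, ActionContext gives an action cancellation checks, progress and log events, and output persistence scoped to the run. A hypothetical action body (do_one_step, the payload type, and the stage name are illustrative, not part of the package):

    def render(ctx: ActionContext, payload: Any) -> Any:
        for step in range(4):
            if ctx.is_canceled():                  # set by an interrupt_run_cmd
                raise InterruptedError("canceled")
            png_bytes = do_one_step(payload)       # hypothetical unit of work
            ctx.progress((step + 1) / 4, stage="render")

        # Outputs must live under the run's own prefix. save_bytes PUTs (upsert);
        # save_bytes_create POSTs and maps HTTP 409 to "output path already exists".
        asset = ctx.save_bytes(f"runs/{ctx.run_id}/outputs/result.png", png_bytes)
        ctx.log(f"saved {asset.ref} ({asset.size_bytes} bytes)")
        return asset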
+ # Define the interceptor class correctly
+ class _AuthInterceptor(grpc.StreamStreamClientInterceptor):
+     def __init__(self, token: str) -> None:
+         self._token = token
+
+     def intercept_stream_stream(self, continuation: Any, client_call_details: Any, request_iterator: Any) -> Any:
+         metadata = list(client_call_details.metadata or [])
+         metadata.append(('authorization', f'Bearer {self._token}'))
+         new_details = client_call_details._replace(metadata=metadata)
+         return continuation(new_details, request_iterator)
+
+ class Worker:
+     """Worker implementation that connects to the scheduler via gRPC."""
+
+     def __init__(
+         self,
+         scheduler_addr: str = "localhost:8080",
+         scheduler_addrs: Optional[List[str]] = None,
+         user_module_names: List[str] = ["functions"],  # Add new parameter for user modules
+         worker_id: Optional[str] = None,
+         auth_token: Optional[str] = None,
+         use_tls: bool = False,
+         reconnect_delay: int = 5,
+         max_reconnect_attempts: int = 0,  # 0 means infinite retries
+         model_manager: Optional[ModelManagementInterface] = None,  # Optional model manager
+         downloader: Optional[ModelDownloader] = None,  # Optional model downloader
+     ) -> None:
+         """Initialize a new worker.
+
+         Args:
+             scheduler_addr: Address of the scheduler service.
+             scheduler_addrs: Optional list of seed scheduler addresses.
+             user_module_names: List of Python module names containing user-defined @worker_function functions.
+             worker_id: Unique ID for this worker (generated if not provided).
+             auth_token: Optional authentication token.
+             use_tls: Whether to use TLS for the connection.
+             reconnect_delay: Seconds to wait between reconnection attempts.
+             max_reconnect_attempts: Max reconnect attempts (0 = infinite).
+             model_manager: Optional model manager.
+             downloader: Optional model downloader.
+         """
+         self.scheduler_addr = scheduler_addr
+         self.scheduler_addrs = self._normalize_scheduler_addrs(scheduler_addr, scheduler_addrs)
+         self.user_module_names = user_module_names  # Store module names
+         self.worker_id = worker_id or f"py-worker-{os.getpid()}"
+         self.auth_token = auth_token
+         self.use_tls = use_tls
+         self.reconnect_delay = reconnect_delay
+         self.max_reconnect_attempts = max_reconnect_attempts
+         self.max_input_bytes = int(os.getenv("WORKER_MAX_INPUT_BYTES", "0"))
+         self.max_output_bytes = int(os.getenv("WORKER_MAX_OUTPUT_BYTES", "0"))
+
+         self._jwks_url = os.getenv("SCHEDULER_JWKS_URL", "").strip()
+         self._jwks_ttl_seconds = int(os.getenv("SCHEDULER_JWKS_TTL_SECONDS", "300"))
+         self._jwt_issuer = os.getenv("SCHEDULER_JWT_ISSUER", "").strip()
+         self._jwt_audience = os.getenv("SCHEDULER_JWT_AUDIENCE", "").strip()
+         self._jwks_cache: Optional[_JWKSCache] = _JWKSCache(self._jwks_url, self._jwks_ttl_seconds) if self._jwks_url else None
+
+         self.deployment_id = os.getenv("DEPLOYMENT_ID", "")  # Read DEPLOYMENT_ID env var
+         if not self.deployment_id:
+             logger.warning("DEPLOYMENT_ID environment variable not set for this worker!")
+
+         self.tenant_id = os.getenv("TENANT_ID", "")
+         self.runpod_pod_id = os.getenv("RUNPOD_POD_ID", "")  # Read injected pod ID
+         if not self.runpod_pod_id:
+             logger.warning("RUNPOD_POD_ID environment variable not set for this worker!")
+
+         logger.info(f"RUNPOD_POD_ID: {self.runpod_pod_id}")
+
+         self._actions: Dict[str, Callable[[ActionContext, Optional[Any], bytes], bytes]] = {}
+         self._active_tasks: Dict[str, ActionContext] = {}
+         self._active_tasks_lock = threading.Lock()
+         self._active_function_counts: Dict[str, int] = {}
+         self.max_concurrency = int(os.getenv("WORKER_MAX_CONCURRENCY", "0"))
+         self._drain_timeout_seconds = int(os.getenv("WORKER_DRAIN_TIMEOUT_SECONDS", "0"))
+         self._draining = False
+         self._discovered_resources: Dict[str, ResourceRequirements] = {}  # Store resources per function
+         self._function_schemas: Dict[str, Tuple[bytes, bytes]] = {}  # func_name -> (input_schema_json, output_schema_json)
+
+         self._gpu_busy_lock = threading.Lock()
+         self._is_gpu_busy = False
+
+         self._channel: Optional[Any] = None
+         self._stub: Optional[Any] = None
+         self._stream: Optional[Any] = None
+         self._running = False
+         self._stop_event = threading.Event()
+         self._reconnect_count = 0
+         self._outgoing_queue: queue.Queue[Any] = queue.Queue()
+         self._leader_hint: Optional[str] = None
+
+         self._receive_thread: Optional[threading.Thread] = None
+         self._heartbeat_thread: Optional[threading.Thread] = None
+
+         self._reconnect_delay_base = max(0, reconnect_delay)
+         self._reconnect_delay_max = int(os.getenv("RECONNECT_MAX_DELAY", "60"))
+         self._reconnect_jitter = float(os.getenv("RECONNECT_JITTER_SECONDS", "1.0"))
+
+         resolved_model_manager = model_manager
+         if resolved_model_manager is None:
+             model_manager_path = os.getenv("MODEL_MANAGER_CLASS", "").strip()
+             if model_manager_path:
+                 try:
+                     module_path, _, class_name = model_manager_path.partition(":")
+                     if not module_path or not class_name:
+                         raise ValueError("MODEL_MANAGER_CLASS must be in module:Class format")
+                     module = importlib.import_module(module_path)
+                     manager_cls = getattr(module, class_name)
+                     resolved_model_manager = manager_cls()
+                     logger.info(f"Loaded ModelManager from MODEL_MANAGER_CLASS={model_manager_path}")
+                 except Exception as e:
+                     logger.exception(f"Failed to load MODEL_MANAGER_CLASS '{model_manager_path}': {e}")
+         self._model_manager = resolved_model_manager
+         self._downloader = downloader
+         if self._downloader is None:
+             base_url = os.getenv("COZY_HUB_URL", "").strip()
+             token = os.getenv("COZY_HUB_TOKEN", "").strip() or None
+             if base_url:
+                 self._downloader = CozyHubDownloader(base_url, token=token)
+         self._supported_model_ids_from_scheduler: Optional[List[str]] = None  # To store IDs from scheduler
+         self._model_init_done_event = threading.Event()  # To signal model init is complete
+
+         if self._model_manager:
+             logger.info(f"ModelManager of type '{type(self._model_manager).__name__}' provided.")
+         else:
+             logger.info("No ModelManager provided. Worker operating in simple mode regarding models.")
+             self._model_init_done_event.set()  # No model init to wait for if no manager
+         if self._downloader:
+             logger.info(f"ModelDownloader of type '{type(self._downloader).__name__}' configured.")
+
+         logger.info(f"Created worker: ID={self.worker_id}, DeploymentID={self.deployment_id or 'N/A'}, Scheduler={scheduler_addr}")
+
+         # Discover functions before setting signals? Maybe after. Let's do it here.
+         self._discover_and_register_functions()
+
+         self._verify_auth_token()
+
+         signal.signal(signal.SIGINT, self._handle_interrupt)
+         signal.signal(signal.SIGTERM, self._handle_interrupt)
+
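A minimal, hypothetical entrypoint, assuming a local module my_functions that defines @worker_function handlers (the address, module name, and env var are placeholders):

    import os
    from gen_worker.worker import Worker

    worker = Worker(
        scheduler_addr="scheduler.internal:8080",
        user_module_names=["my_functions"],
        auth_token=os.environ.get("WORKER_AUTH_TOKEN"),
    )
    worker.run()  # blocks: connects, registers, heartbeats, reconnects with backoff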
+     @staticmethod
+     def _normalize_scheduler_addrs(primary: str, addrs: Optional[List[str]]) -> List[str]:
+         unique: List[str] = []
+         for addr in [primary] + (addrs or []):
+             addr = (addr or "").strip()
+             if addr and addr not in unique:
+                 unique.append(addr)
+         return unique
+
+     @staticmethod
+     def _extract_leader_addr(details: Optional[str]) -> Optional[str]:
+         if not details:
+             return None
+         if details.startswith("not_leader:"):
+             leader = details.split("not_leader:", 1)[1].strip()
+             return leader or None
+         return None
+
+     def _set_scheduler_addr(self, addr: str) -> None:
+         addr = addr.strip()
+         if not addr:
+             return
+         self.scheduler_addr = addr
+         if addr not in self.scheduler_addrs:
+             self.scheduler_addrs.insert(0, addr)
+
+     def _iter_scheduler_addrs(self) -> Iterator[str]:
+         seen = set()
+         for addr in self.scheduler_addrs:
+             addr = addr.strip()
+             if not addr or addr in seen:
+                 continue
+             seen.add(addr)
+             yield addr
+
+     def _verify_auth_token(self) -> None:
+         if not self.auth_token or not self._jwks_cache:
+             return
+         try:
+             header = jwt.get_unverified_header(self.auth_token)
+             kid = header.get("kid")
+             key = self._jwks_cache.get_key(kid)
+             if not key:
+                 raise ValueError("JWKS key not found for token")
+             options = {"verify_aud": bool(self._jwt_audience)}
+             jwt.decode(
+                 self.auth_token,
+                 key=key,
+                 algorithms=["RS256"],
+                 audience=self._jwt_audience or None,
+                 issuer=self._jwt_issuer or None,
+                 options=options,
+             )
+             logger.info("Worker auth token verified against scheduler JWKS.")
+         except Exception as e:
+             logger.error(f"Worker auth token verification failed: {e}")
+             raise
+
+     def _format_error(self, message: str, retryable: bool) -> str:
+         return json.dumps({
+             "message": message,
+             "retryable": retryable,
+         })
+
+     def _emit_progress_event(self, event: Dict[str, Any]) -> None:
+         try:
+             run_id = event.get("run_id") or ""
+             event_type = event.get("type") or ""
+             payload = event.get("payload") or {}
+             if "timestamp" not in payload:
+                 payload = dict(payload)
+                 payload["timestamp"] = event.get("timestamp", time.time())
+             payload_json = json.dumps(payload).encode("utf-8")
+             msg = pb.WorkerSchedulerMessage(
+                 worker_event=pb.WorkerEvent(
+                     run_id=run_id,
+                     event_type=event_type,
+                     payload_json=payload_json,
+                 )
+             )
+             self._send_message(msg)
+         except Exception:
+             logger.exception("Failed to emit progress event")
+
+
+     def _set_gpu_busy_status(self, busy: bool, func_name_for_log: str = "") -> None:
+         with self._gpu_busy_lock:
+             if self._is_gpu_busy == busy:
+                 return
+             self._is_gpu_busy = busy
+             if func_name_for_log:
+                 logger.info(f"GPU status changed to {busy} due to function '{func_name_for_log}'.")
+             else:
+                 logger.info(f"GPU status changed to {busy}.")
+
+
+     def _get_gpu_busy_status(self) -> bool:
+         with self._gpu_busy_lock:
+             return self._is_gpu_busy
+
+
+     def _discover_and_register_functions(self) -> None:
+         """Discover and register functions marked with @worker_function."""
+         logger.info(f"Discovering worker functions in modules: {self.user_module_names}...")
+         discovered_count = 0
+         for module_name in self.user_module_names:
+             try:
+                 module = importlib.import_module(module_name)
+                 logger.debug(f"Inspecting module: {module_name}")
+                 for name, obj in inspect.getmembers(module):
+                     if inspect.isfunction(obj) and hasattr(obj, '_is_worker_function'):
+                         if getattr(obj, '_is_worker_function') is True:
+                             # Found a decorated function
+                             original_func = obj  # Keep reference to the actual decorated function
+                             func_name = original_func.__name__  # Use the real function name
+
+                             if func_name in self._actions:
+                                 logger.warning(f"Function '{func_name}' from module '{module_name}' conflicts with an already registered function. Skipping.")
+                                 continue
+
+                             resources: ResourceRequirements = getattr(original_func, '_worker_resources', ResourceRequirements())
+                             self._discovered_resources[func_name] = resources
+
+                             expects_pipeline_flag = resources.expects_pipeline_arg
+                             payload_type = self._infer_payload_type(original_func, expects_pipeline_flag)
+                             return_type = self._infer_return_type(original_func)
+                             if payload_type is None or return_type is None:
+                                 logger.error(
+                                     "Skipping function '%s' due to invalid or missing payload type annotation.",
+                                     func_name,
+                                 )
+                                 continue
+
+                             try:
+                                 input_schema = msgspec.json.schema(payload_type)
+                                 output_schema = msgspec.json.schema(return_type)
+                                 self._function_schemas[func_name] = (
+                                     json.dumps(input_schema, separators=(",", ":"), sort_keys=True).encode("utf-8"),
+                                     json.dumps(output_schema, separators=(",", ":"), sort_keys=True).encode("utf-8"),
+                                 )
+                             except Exception as exc:
+                                 logger.error("Failed to generate msgspec JSON schema for '%s': %s", func_name, exc)
+                                 continue
+
+                             # Create the wrapper for gRPC/msgpack interaction
+                             def create_wrapper(
+                                 captured_func: Callable[..., Any],
+                                 captured_name: str,
+                                 captured_payload_type: type[msgspec.Struct],
+                                 captured_return_type: type[msgspec.Struct],
+                                 func_expects_pipeline: bool = False,
+                             ) -> Callable[[ActionContext, Optional[Any], bytes], bytes]:
+                                 @functools.wraps(captured_func)  # Preserve metadata of original user func
+                                 def wrapper(ctx: ActionContext, pipeline_instance: Optional[Any], input_bytes: bytes) -> bytes:
+                                     try:
+                                         input_obj = msgspec.msgpack.decode(input_bytes, type=captured_payload_type)
+                                         self._materialize_assets(ctx.run_id, input_obj)
+                                         # Pass the context and deserialized input to the *original* user function
+                                         if func_expects_pipeline:  # Only pass pipeline if function expects it
+                                             if pipeline_instance is None:
+                                                 err_msg = f"Function '{captured_name}' expected a pipeline argument, but None was provided by the Worker core."
+                                                 logger.error(err_msg)
+                                                 raise ValueError(err_msg)
+                                             result = captured_func(ctx, pipeline_instance, input_obj)
+                                         else:
+                                             result = captured_func(ctx, input_obj)  # For functions not needing a model
+
+                                         if ctx.is_canceled():
+                                             raise InterruptedError("Task was canceled during execution")
+                                         if not isinstance(result, captured_return_type):
+                                             raise TypeError(
+                                                 f"Function {captured_name} returned {type(result)!r}, "
+                                                 f"expected {captured_return_type!r}"
+                                             )
+                                         # Ensure result is bytes after msgspec msgpack serialization
+                                         packed_result = msgspec.msgpack.encode(result)
+                                         if not isinstance(packed_result, bytes):
+                                             raise TypeError(
+                                                 f"Function {captured_name} did not return msgspec-serializable data resulting in bytes"
+                                             )
+                                         return packed_result
+                                     except InterruptedError as ie:  # Catch cancellation specifically
+                                         logger.warning(f"Function {captured_name} run {ctx.run_id} was interrupted.")
+                                         raise  # Re-raise to be handled in _execute_function
+                                     except Exception as e:
+                                         logger.exception(f"Error during execution of function {captured_name} (run_id: {ctx.run_id})")
+                                         raise  # Re-raise to be caught by _execute_function
+                                 return wrapper
+
+                             self._actions[func_name] = create_wrapper(
+                                 original_func,
+                                 func_name,
+                                 payload_type,
+                                 return_type,
+                                 func_expects_pipeline=expects_pipeline_flag,
+                             )
+                             logger.info(f"Registered function: '{func_name}' from module '{module_name}' with resources: {resources}")
+                             discovered_count += 1
+
+             except ImportError:
+                 logger.error(f"Could not import user module: {module_name}")
+             except Exception as e:
+                 logger.exception(f"Error during discovery in module {module_name}: {e}")
+
+         if discovered_count == 0:
+             logger.warning(f"No functions decorated with @worker_function found in specified modules: {self.user_module_names}")
+         else:
+             logger.info(f"Discovery complete. Found {discovered_count} worker functions.")
+
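Discovery only accepts functions whose payload parameter and return annotation are msgspec.Struct subclasses; the registration schemas are derived from those types with msgspec.json.schema. A sketch of a qualifying user module (names are illustrative, and the exact parameters of the worker_function decorator are not shown in this diff):

    # my_functions.py - hypothetical user module
    import msgspec
    from gen_worker.decorators import worker_function

    class EchoIn(msgspec.Struct):
        text: str

    class EchoOut(msgspec.Struct):
        echoed: str

    @worker_function
    def echo(ctx, payload: EchoIn) -> EchoOut:
        # Two parameters (ctx, payload), since expects_pipeline_arg is False here.
        return EchoOut(echoed=payload.text)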
+     def _infer_payload_type(
+         self,
+         func: Callable[..., Any],
+         expects_pipeline: bool,
+     ) -> Optional[type[msgspec.Struct]]:
+         signature = inspect.signature(func)
+         params = list(signature.parameters.values())
+         expected_params = 3 if expects_pipeline else 2
+         if len(params) != expected_params:
+             logger.error(
+                 "Function '%s' has %d parameters but expected %d.",
+                 func.__name__,
+                 len(params),
+                 expected_params,
+             )
+             return None
+
+         payload_param = params[-1]
+         try:
+             type_hints = typing.get_type_hints(func, globalns=func.__globals__)
+         except Exception as exc:
+             logger.error("Failed to resolve type hints for '%s': %s", func.__name__, exc)
+             return None
+
+         payload_type = type_hints.get(payload_param.name)
+         if payload_type is None:
+             logger.error("Function '%s' is missing a payload type annotation.", func.__name__)
+             return None
+
+         if not isinstance(payload_type, type) or not issubclass(payload_type, msgspec.Struct):
+             logger.error(
+                 "Function '%s' payload type must be a msgspec.Struct, got %r.",
+                 func.__name__,
+                 payload_type,
+             )
+             return None
+
+         return payload_type
+
+     def _infer_return_type(self, func: Callable[..., Any]) -> Optional[type[msgspec.Struct]]:
+         try:
+             type_hints = typing.get_type_hints(func, globalns=func.__globals__)
+         except Exception as exc:
+             logger.error("Failed to resolve return type hints for '%s': %s", func.__name__, exc)
+             return None
+
+         return_type = type_hints.get("return")
+         if return_type is None:
+             logger.error("Function '%s' is missing a return type annotation.", func.__name__)
+             return None
+
+         if not isinstance(return_type, type) or not issubclass(return_type, msgspec.Struct):
+             logger.error(
+                 "Function '%s' return type must be a msgspec.Struct, got %r.",
+                 func.__name__,
+                 return_type,
+             )
+             return None
+
+         return return_type
+
+     def _send_message(self, message: WorkerSchedulerMessage) -> None:
+         """Add a message to the outgoing queue."""
+         if self._running and not self._stop_event.is_set():
+             try:
+                 self._outgoing_queue.put_nowait(message)
+             except queue.Full:
+                 logger.error("Outgoing message queue is full. Message dropped!")
+         else:
+             logger.warning("Attempted to send message while worker is stopping or stopped.")
+
+     def _materialize_assets(self, run_id: str, obj: Any) -> None:
+         if isinstance(obj, Asset):
+             self._materialize_asset(run_id, obj)
+             return
+         if isinstance(obj, list):
+             for it in obj:
+                 self._materialize_assets(run_id, it)
+             return
+         if isinstance(obj, dict):
+             for it in obj.values():
+                 self._materialize_assets(run_id, it)
+             return
+         fields = getattr(obj, "__struct_fields__", None)
+         if fields and isinstance(fields, (tuple, list)):
+             for name in fields:
+                 try:
+                     self._materialize_assets(run_id, getattr(obj, name))
+                 except Exception:
+                     continue
+
+     def _materialize_asset(self, run_id: str, asset: Asset) -> None:
+         if asset.local_path:
+             return
+         ref = (asset.ref or "").strip()
+         if not ref:
+             return
+
+         base_dir = os.getenv("WORKER_RUN_DIR", "/tmp/cozy").rstrip("/")
+         local_inputs_dir = os.path.join(base_dir, run_id, "inputs")
+         os.makedirs(local_inputs_dir, exist_ok=True)
+         cache_dir = os.getenv("WORKER_CACHE_DIR", os.path.join(base_dir, "cache")).rstrip("/")
+         os.makedirs(cache_dir, exist_ok=True)
+
+         max_bytes = int(os.getenv("WORKER_MAX_INPUT_FILE_BYTES", str(200 * 1024 * 1024)))
+
+         # External URL inputs (download directly into the run folder).
+         if ref.startswith("http://") or ref.startswith("https://"):
+             if _url_is_blocked(ref):
+                 raise RuntimeError("input url blocked")
+             ext = os.path.splitext(urllib.parse.urlparse(ref).path)[1] or os.path.splitext(ref)[1]
+             name_hash = hashlib.sha256(ref.encode("utf-8")).hexdigest()[:32]
+             local_path = os.path.join(local_inputs_dir, f"{name_hash}{ext}")
+             size, sha256_hex, mime = self._download_url_to_file(ref, local_path, max_bytes)
+             asset.local_path = local_path
+             if not asset.tenant_id:
+                 asset.tenant_id = self.tenant_id
+             asset.mime_type = mime
+             asset.size_bytes = size
+             asset.sha256 = sha256_hex
+             return
+
+         # Cozy Hub file ref (tenant scoped) - use orchestrator file API with HEAD+cache.
+         base = _require_file_api_base_url()
+         token = _require_file_api_token()
+         url = f"{base}/api/v1/file/{_encode_ref_for_url(ref)}"
+
+         head_req = _http_request("HEAD", url, token)
+         with urllib.request.urlopen(head_req, timeout=10) as resp:
+             if resp.status < 200 or resp.status >= 300:
+                 raise RuntimeError(f"failed to stat asset ({resp.status})")
+             sha256_hex = (resp.headers.get("X-Cozy-SHA256") or "").strip()
+             size_hdr = (resp.headers.get("X-Cozy-Size-Bytes") or "").strip()
+             mime = (resp.headers.get("X-Cozy-Mime-Type") or "").strip()
+         size = int(size_hdr) if size_hdr.isdigit() else 0
+         if max_bytes > 0 and size > max_bytes:
+             raise RuntimeError("input file too large")
+
+         ext = os.path.splitext(ref)[1]
+         if not ext and mime:
+             guessed = {
+                 "image/png": ".png",
+                 "image/jpeg": ".jpg",
+                 "image/webp": ".webp",
+                 "image/gif": ".gif",
+             }.get(mime)
+             ext = guessed or ""
+
+         if not sha256_hex:
+             sha256_hex = hashlib.sha256(ref.encode("utf-8")).hexdigest()
+         cache_name = f"{sha256_hex[:32]}{ext}"
+         cache_path = os.path.join(cache_dir, cache_name)
+
+         if not os.path.exists(cache_path):
+             get_req = _http_request("GET", url, token)
+             with urllib.request.urlopen(get_req, timeout=30) as resp:
+                 if resp.status < 200 or resp.status >= 300:
+                     raise RuntimeError(f"failed to download asset ({resp.status})")
+                 _size, _sha = self._stream_to_file(resp, cache_path, max_bytes)
+             if not size:
+                 size = _size
+             if not sha256_hex:
+                 sha256_hex = _sha
+
+         local_path = os.path.join(local_inputs_dir, cache_name)
+         if not os.path.exists(local_path):
+             try:
+                 os.link(cache_path, local_path)
+             except Exception:
+                 try:
+                     import shutil
+
+                     shutil.copyfile(cache_path, local_path)
+                 except Exception:
+                     local_path = cache_path
+
+         if not mime:
+             with open(local_path, "rb") as f:
+                 head = f.read(512)
+             mime = _infer_mime_type(ref, head)
+
+         asset.local_path = local_path
+         if not asset.tenant_id:
+             asset.tenant_id = self.tenant_id
+         asset.mime_type = mime or None
+         asset.size_bytes = size or None
+         asset.sha256 = sha256_hex or None
+
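After materialization, an Asset that arrived carrying only a ref (or an external URL) gains a local path plus metadata. A hypothetical before/after; the field values shown are illustrative:

    asset = Asset(ref="inputs/cat.png", tenant_id=None, local_path=None,
                  mime_type=None, size_bytes=None, sha256=None)
    # After _materialize_asset(run_id, asset), one would expect, e.g.:
    #   asset.local_path -> "/tmp/cozy/<run_id>/inputs/<sha256-prefix>.png"
    #   asset.mime_type  -> "image/png"
    #   asset.size_bytes -> 34512
    #   asset.sha256     -> "9f2c..." (full hex digest)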
+     def _download_url_to_file(self, src: str, dst: str, max_bytes: int) -> Tuple[int, str, Optional[str]]:
+         attempts = int(os.getenv("WORKER_DOWNLOAD_RETRIES", "3"))
+         attempt = 0
+         last_err: Optional[Exception] = None
+         while attempt < max(1, attempts):
+             attempt += 1
+             try:
+                 client = urllib.request.build_opener()
+                 req = urllib.request.Request(src, method="GET")
+                 with client.open(req, timeout=30) as resp:
+                     size, sha = self._stream_to_file(resp, dst, max_bytes)
+                 with open(dst, "rb") as f:
+                     head = f.read(512)
+                 mime = _infer_mime_type(src, head)
+                 return size, sha, mime
+             except Exception as e:
+                 last_err = e
+                 if attempt >= max(1, attempts):
+                     break
+                 sleep_s = min(10.0, 0.5 * (2 ** (attempt - 1))) + random.random() * 0.2
+                 time.sleep(sleep_s)
+         raise RuntimeError(f"failed to download url: {last_err}")
+
+     def _stream_to_file(self, src: Any, dst: str, max_bytes: int) -> Tuple[int, str]:
+         tmp = f"{dst}.tmp-{os.getpid()}-{threading.get_ident()}-{random.randint(0, 1_000_000)}"
+         total = 0
+         h = hashlib.sha256()
+         try:
+             with open(tmp, "wb") as out:
+                 while True:
+                     chunk = src.read(1024 * 1024)
+                     if not chunk:
+                         break
+                     total += len(chunk)
+                     if total > max_bytes:
+                         raise RuntimeError("input file too large")
+                     h.update(chunk)
+                     out.write(chunk)
+             os.replace(tmp, dst)
+         finally:
+             try:
+                 if os.path.exists(tmp):
+                     os.remove(tmp)
+             except Exception:
+                 pass
+         return total, h.hexdigest()
+
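_stream_to_file writes to a uniquely named temp file, hashing and size-checking each 1 MiB chunk, then promotes it with os.replace so readers never observe a partial download; the finally block removes the temp file if anything failed before the rename. A condensed sketch of the same atomic-write pattern:

    import hashlib
    import os

    def atomic_write(dst: str, chunks) -> str:
        tmp = f"{dst}.tmp-{os.getpid()}"
        h = hashlib.sha256()
        try:
            with open(tmp, "wb") as out:
                for chunk in chunks:
                    h.update(chunk)
                    out.write(chunk)
            os.replace(tmp, dst)  # atomic within one filesystem
        finally:
            if os.path.exists(tmp):  # only true if the rename never happened
                os.remove(tmp)
        return h.hexdigest()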
+     def connect(self) -> bool:
+         """Connect to the scheduler.
+
+         Returns:
+             bool: True if connection was successful, False otherwise.
+         """
+         attempted: set[str] = set()
+         while True:
+             addr = None
+             if self._leader_hint and self._leader_hint not in attempted:
+                 addr = self._leader_hint
+                 self._leader_hint = None
+             else:
+                 for candidate in self._iter_scheduler_addrs():
+                     if candidate not in attempted:
+                         addr = candidate
+                         break
+             if not addr:
+                 break
+             attempted.add(addr)
+             self._set_scheduler_addr(addr)
+             if self._connect_once():
+                 return True
+         return False
+
+     def _connect_once(self) -> bool:
+         try:
+             if self.use_tls:
+                 # TODO: Add proper credential loading if needed
+                 creds = grpc.ssl_channel_credentials()
+                 self._channel = grpc.secure_channel(self.scheduler_addr, creds)
+             else:
+                 self._channel = grpc.insecure_channel(self.scheduler_addr)
+
+             interceptors = []
+             if self.auth_token:
+                 interceptors.append(_AuthInterceptor(self.auth_token))
+
+             if interceptors:
+                 self._channel = grpc.intercept_channel(self._channel, *interceptors)
+
+             self._stub = pb_grpc.SchedulerWorkerServiceStub(self._channel)
+
+             # Start the bidirectional stream
+             request_iterator = self._outgoing_message_iterator()
+             self._stream = self._stub.ConnectWorker(request_iterator)
+
+             logger.info(f"Attempting to connect to scheduler at {self.scheduler_addr}...")
+
+             # Send initial registration immediately
+             self._register_worker(is_heartbeat=False)
+
+             # Start the receive loop in a separate thread *after* stream is initiated
+             self._receive_thread = threading.Thread(target=self._receive_loop, daemon=True)
+             self._receive_thread.start()
+
+             # Start heartbeat thread
+             self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
+             self._heartbeat_thread.start()
+
+             logger.info(f"Successfully connected to scheduler at {self.scheduler_addr}")
+             self._reconnect_count = 0
+             return True
+
+         except grpc.RpcError as e:
+             # Access code() and details() methods for RpcError
+             code = e.code() if hasattr(e, 'code') and callable(e.code) else grpc.StatusCode.UNKNOWN
+             details = e.details() if hasattr(e, 'details') and callable(e.details) else str(e)
+             leader = self._extract_leader_addr(details)
+             if code == grpc.StatusCode.FAILED_PRECONDITION and leader:
+                 logger.warning(f"Scheduler returned not_leader for {self.scheduler_addr}; redirecting to {leader}")
+                 self._leader_hint = leader
+                 self._set_scheduler_addr(leader)
+             else:
+                 logger.error(f"Failed to connect to scheduler: {code} - {details}")
+             self._close_connection()
+             return False
+         except Exception as e:
+             logger.exception(f"Unexpected error connecting to scheduler: {e}")
+             self._close_connection()
+             return False
+
+     def _outgoing_message_iterator(self) -> Iterator[WorkerSchedulerMessage]:
+         """Yields messages from the outgoing queue to send to the scheduler."""
+         while not self._stop_event.is_set():
+             try:
+                 # Block for a short time to allow stopping gracefully
+                 message = self._outgoing_queue.get(timeout=0.1)
+                 yield message
+                 # self._outgoing_queue.task_done()  # Not needed if not joining queue
+             except queue.Empty:
+                 continue
+             except Exception as e:
+                 if not self._stop_event.is_set():
+                     logger.exception(f"Error in outgoing message iterator: {e}")
+                     self._handle_connection_error()
+                 break  # Exit iterator on error
+
+     def _heartbeat_loop(self) -> None:
+         """Periodically sends heartbeat messages."""
+         while not self._stop_event.wait(HEARTBEAT_INTERVAL):
+             try:
+                 self._register_worker(is_heartbeat=True)
+                 logger.debug("Sent heartbeat to scheduler")
+             except Exception as e:
+                 if not self._stop_event.is_set():
+                     logger.error(f"Error sending heartbeat: {e}")
+                     self._handle_connection_error()
+                 break  # Stop heartbeating on error
+
+     def _register_worker(self, is_heartbeat: bool = False) -> None:
+         """Create and send a registration/heartbeat message."""
+         try:
+             mem = psutil.virtual_memory()
+             cpu_cores = os.cpu_count() or 0
+
+             gpu_count = 0
+             gpu_total_mem = 0
+             vram_models = []
+             gpu_used_mem = 0
+             gpu_free_mem = 0
+             gpu_name = ""
+             gpu_driver = ""
+
+             if torch and torch.cuda.is_available():
+                 gpu_count = torch.cuda.device_count()
+                 if gpu_count > 0:
+                     try:
+                         props = torch.cuda.get_device_properties(0)
+                         gpu_total_mem = props.total_memory
+                         gpu_used_mem = torch.cuda.memory_allocated(0)
+                         gpu_name = props.name
+                         gpu_driver = torch.version.cuda or ""
+                         try:
+                             free_mem, total_mem = torch.cuda.mem_get_info(0)
+                             gpu_total_mem = total_mem
+                             gpu_used_mem = total_mem - free_mem
+                             gpu_free_mem = free_mem
+                         except Exception:
+                             pass
+                         logger.debug(f"GPU: {props.name}, VRAM total={gpu_total_mem}, used={gpu_used_mem}, cuda={torch.version.cuda}")
+                     except Exception as gpu_err:
+                         logger.warning(f"Could not get GPU properties: {gpu_err}")
+
+             fake_gpu_count = os.getenv("WORKER_FAKE_GPU_COUNT")
+             if fake_gpu_count:
+                 try:
+                     gpu_count = int(fake_gpu_count)
+                     if gpu_count > 0:
+                         fake_mem = int(os.getenv("WORKER_FAKE_GPU_MEMORY_BYTES", str(24 * 1024 * 1024 * 1024)))
+                         gpu_total_mem = fake_mem
+                         gpu_used_mem = 0
+                         gpu_free_mem = fake_mem
+                         gpu_name = os.getenv("WORKER_FAKE_GPU_NAME", "FakeGPU")
+                         gpu_driver = os.getenv("WORKER_FAKE_GPU_DRIVER", "fake")
+                 except ValueError:
+                     logger.warning("Invalid WORKER_FAKE_GPU_COUNT; ignoring fake GPU override.")
+
+             supports_model_loading_flag = False
+             # current_models = []
+             if self._model_manager:
+                 vram_models = self._model_manager.get_vram_loaded_models()
+                 supports_model_loading_flag = True
+
+             function_concurrency = {}
+             for func_name, req in self._discovered_resources.items():
+                 if req and req.max_concurrency:
+                     function_concurrency[func_name] = int(req.max_concurrency)
+
+             cuda_version = os.getenv("WORKER_CUDA_VERSION", "").strip()
+             torch_version = os.getenv("WORKER_TORCH_VERSION", "").strip()
+             if torch is not None:
+                 if not torch_version:
+                     torch_version = getattr(torch, "__version__", "") or ""
+                 if not cuda_version:
+                     cuda_version = getattr(torch.version, "cuda", "") or ""
+             if not cuda_version:
+                 cuda_version = os.getenv("CUDA_VERSION", "").strip() or os.getenv("NVIDIA_CUDA_VERSION", "").strip()
+
+             function_schemas = []
+             for fname, (in_schema, out_schema) in self._function_schemas.items():
+                 try:
+                     function_schemas.append(
+                         pb.FunctionSchema(
+                             name=fname,
+                             input_schema_json=in_schema,
+                             output_schema_json=out_schema,
+                         )
+                     )
+                 except Exception:
+                     continue
+
+             resources = pb.WorkerResources(
+                 worker_id=self.worker_id,
+                 deployment_id=self.deployment_id,
+                 # tenant_id=self.tenant_id,
+                 runpod_pod_id=self.runpod_pod_id,
+                 gpu_is_busy=self._get_gpu_busy_status(),
+                 cpu_cores=cpu_cores,
+                 memory_bytes=mem.total,
+                 gpu_count=gpu_count,
+                 gpu_memory_bytes=gpu_total_mem,
+                 gpu_memory_used_bytes=gpu_used_mem,
+                 gpu_memory_free_bytes=gpu_free_mem,
+                 gpu_name=gpu_name,
+                 gpu_driver=gpu_driver,
+                 max_concurrency=self.max_concurrency,
+                 function_concurrency=function_concurrency,
+                 cuda_version=cuda_version,
+                 torch_version=torch_version,
+                 available_functions=list(self._actions.keys()),
+                 available_models=vram_models,
+                 supports_model_loading=supports_model_loading_flag,
+                 function_schemas=function_schemas,
+             )
+             registration = pb.WorkerRegistration(
+                 resources=resources,
+                 is_heartbeat=is_heartbeat
+             )
+             message = pb.WorkerSchedulerMessage(worker_registration=registration)
+             # logger.info(f"DEBUG: Preparing to send registration. Resource object: {resources}")
+             # logger.info(f"DEBUG: Value being sent for runpod_pod_id: '{resources.runpod_pod_id}'")
+             self._send_message(message)
+         except Exception as e:
+             logger.error(f"Failed to create or send registration/heartbeat: {e}")
+
+     def run(self) -> None:
+         """Run the worker, connecting to the scheduler and processing tasks."""
+         if self._running:
+             logger.warning("Worker is already running")
+             return
+
+         self._running = True
+         self._stop_event.clear()
+         self._reconnect_count = 0  # Reset reconnect count on new run
+         self._draining = False
+
+         while self._running and not self._stop_event.is_set():
+             self._reconnect_count += 1
+             logger.info(f"Connection attempt {self._reconnect_count}...")
+             if self.connect():
+                 # Successfully connected, wait for stop signal or disconnection
+                 logger.info("Connection successful. Worker running.")
+                 self._stop_event.wait()  # Wait here until stopped or disconnected
+                 logger.info("Worker run loop received stop/disconnect signal.")
+                 # If stopped normally (self.stop() called), _running will be False
+                 # If disconnected, connect() failed, threads stopped, _handle_connection_error called _stop_event.set()
+             else:
+                 # Connection failed
+                 if self.max_reconnect_attempts > 0 and self._reconnect_count >= self.max_reconnect_attempts:
+                     logger.error("Failed to connect after maximum attempts. Stopping worker.")
+                     self._running = False  # Ensure loop terminates
+                     break
+
+                 if self._running and not self._stop_event.is_set():
+                     backoff = self._reconnect_delay_base * (2 ** max(self._reconnect_count - 1, 0))
+                     if self._reconnect_delay_max > 0:
+                         backoff = min(backoff, self._reconnect_delay_max)
+                     jitter = random.uniform(0, self._reconnect_jitter) if self._reconnect_jitter > 0 else 0
+                     delay = backoff + jitter
+                     logger.info(f"Connection attempt {self._reconnect_count} failed. Retrying in {delay:.2f} seconds...")
+                     # Wait for delay, but break if stop event is set during wait
+                     if self._stop_event.wait(delay):
+                         logger.info("Stop requested during reconnect delay.")
+                         break  # Exit if stopped while waiting
+             # After a failed attempt or disconnect, clear stop event for next retry
+             if self._running:
+                 self._stop_event.clear()
+
+         # Cleanup after loop exits (either max attempts reached or manual stop)
+         self.stop()
+
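The retry delay doubles from reconnect_delay on each failed attempt, is capped by RECONNECT_MAX_DELAY, and gains up to RECONNECT_JITTER_SECONDS of random jitter. With the defaults (base 5 s, cap 60 s), the pre-jitter schedule works out to:

    base, cap = 5, 60
    print([min(base * 2 ** n, cap) for n in range(7)])
    # [5, 10, 20, 40, 60, 60, 60]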
+     def _handle_interrupt(self, sig: int, frame: Optional[Any]) -> None:
+         """Handle interrupt signal (Ctrl+C)."""
+         logger.info(f"Received signal {sig}, shutting down gracefully.")
+         self.stop()
+
+     def stop(self) -> None:
+         """Stop the worker and clean up resources."""
+         if not self._running and not self._stop_event.is_set():  # Check if already stopped or stopping
+             # Avoid multiple stop calls piling up
+             # logger.debug("Stop called but worker already stopped or stopping.")
+             return
+
+         logger.info("Stopping worker...")
+         self._draining = True
+         self._running = False  # Signal loops to stop
+         self._stop_event.set()  # Wake up any waiting threads
+
+         # Cancel any active tasks
+         active_task_ids = []
+         if self._drain_timeout_seconds > 0:
+             deadline = time.time() + self._drain_timeout_seconds
+             while time.time() < deadline:
+                 with self._active_tasks_lock:
+                     remaining = len(self._active_tasks)
+                 if remaining == 0:
+                     break
+                 time.sleep(0.2)
+
+         with self._active_tasks_lock:
+             active_task_ids = list(self._active_tasks.keys())
+             for run_id in active_task_ids:
+                 ctx = self._active_tasks.get(run_id)
+                 if ctx:
+                     logger.debug(f"Cancelling active task {run_id} during stop.")
+                     ctx.cancel()
+             # Don't clear here, allow _execute_function to finish and remove
+
+         # Wait for threads (give them a chance to finish)
+         # Stop heartbeat first
+         if self._heartbeat_thread and self._heartbeat_thread.is_alive():
+             logger.debug("Joining heartbeat thread...")
+             self._heartbeat_thread.join(timeout=1.0)
+
+         # The outgoing iterator might be blocked on queue.get, stop_event wakes it
+
+         # Close the gRPC connection (this might interrupt the receive loop)
+         self._close_connection()
+
+         # Wait for receive thread
+         if self._receive_thread and self._receive_thread.is_alive():
+             logger.debug("Joining receive thread...")
+             self._receive_thread.join(timeout=2.0)
+
+         # Clear outgoing queue after threads are stopped
+         logger.debug("Clearing outgoing message queue...")
+         while not self._outgoing_queue.empty():
+             try:
+                 self._outgoing_queue.get_nowait()
+             except queue.Empty:
+                 break
+
+         logger.info("Worker stopped.")
+         # Reset stop event in case run() is called again
+         self._stop_event.clear()
+
+     def _close_connection(self) -> None:
+         """Close the gRPC channel and reset state."""
+         if self._stream:
+             try:
+                 # Attempt to cancel the stream from the client side
+                 # This might help the server side release resources quicker
+                 # Note: Behavior might vary depending on server implementation
+                 if hasattr(self._stream, 'cancel') and callable(self._stream.cancel):
+                     self._stream.cancel()
+                     logger.debug("gRPC stream cancelled.")
+             except Exception as e:
+                 logger.warning(f"Error cancelling gRPC stream: {e}")
+             self._stream = None
+
+         if self._channel:
+             try:
+                 self._channel.close()
+                 logger.debug("gRPC channel closed.")
+             except Exception as e:
+                 logger.error(f"Error closing gRPC channel: {e}")
+             self._channel = None
+         self._stub = None
+
+
+     def _receive_loop(self) -> None:
+         """Loop to receive messages from the scheduler via the stream."""
+         logger.info("Receive loop started.")
+         try:
+             if not self._stream:
+                 logger.error("Receive loop started without a valid stream.")
+                 # Don't call _handle_connection_error here, connect should have failed
+                 return
+
+             for message in self._stream:
+                 # Check stop event *before* processing
+                 if self._stop_event.is_set():
+                     logger.debug("Stop event set during iteration, exiting receive loop.")
+                     break
+                 try:
+                     self._process_message(message)
+                 except Exception as e:
+                     # Log errors processing individual messages but continue loop
+                     logger.exception(f"Error processing message: {e}")
+
+         except grpc.RpcError as e:
+             # RpcError indicates a problem with the gRPC connection itself
+             code = e.code() if hasattr(e, 'code') and callable(e.code) else grpc.StatusCode.UNKNOWN
+             details = e.details() if hasattr(e, 'details') and callable(e.details) else str(e)
+
+             if self._stop_event.is_set():
+                 # If stopping, cancellation is expected
+                 if code == grpc.StatusCode.CANCELLED:
+                     logger.info("gRPC stream cancelled gracefully during shutdown.")
+                 else:
+                     logger.warning(f"gRPC error during shutdown: {code} - {details}")
+             elif code == grpc.StatusCode.FAILED_PRECONDITION:
+                 leader = self._extract_leader_addr(details)
+                 if leader:
+                     logger.warning(f"Scheduler redirect received; reconnecting to leader at {leader}")
+                     self._leader_hint = leader
+                     self._set_scheduler_addr(leader)
+                 self._handle_connection_error()
+             elif code == grpc.StatusCode.CANCELLED:
+                 logger.warning("gRPC stream unexpectedly cancelled by server or network.")
+                 self._handle_connection_error()
+             elif code in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED, grpc.StatusCode.INTERNAL):
+                 logger.warning(f"gRPC connection lost ({code}). Attempting reconnect.")
+                 self._handle_connection_error()
+             else:
+                 logger.error(f"Unhandled gRPC error in receive loop: {code} - {details}")
+                 self._handle_connection_error()  # Attempt reconnect on unknown errors too
+         except Exception as e:
+             # Catch-all for non-gRPC errors in the loop
+             if not self._stop_event.is_set():
+                 logger.exception(f"Unexpected error in receive loop: {e}")
+                 self._handle_connection_error()  # Attempt reconnect
+         finally:
+             logger.info("Receive loop finished.")
+
+    def _handle_connection_error(self) -> None:
+        """Handles actions needed when a connection error occurs during run."""
+        if self._running and not self._stop_event.is_set():
+            logger.warning("Connection error detected. Signaling main loop to reconnect...")
+            self._close_connection()  # Ensure resources are closed before reconnect attempt
+            self._stop_event.set()  # Signal run loop to attempt reconnection
+        # else: # Already stopping or stopped
+        #     logger.debug("Connection error detected but worker is already stopping.")
+
+
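+    # Note that the reconnect itself happens in the run loop (outside this
+    # hunk), which waits on _stop_event and redials. A typical cadence for that
+    # redial, sketched with hypothetical constants, is capped exponential
+    # backoff with jitter so a restarting scheduler is not stampeded:
+    #
+    #     delay = min(30.0, 1.0 * (2 ** attempt))       # cap at 30s
+    #     time.sleep(delay * random.uniform(0.5, 1.5))  # +/-50% jitter
+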
+    def _process_message(self, message: WorkerSchedulerMessage) -> None:
+        """Process a single message received from the scheduler."""
+        msg_type = message.WhichOneof('msg')
+        # logger.debug(f"Received message of type: {msg_type}")
+
+        if msg_type == 'run_request':
+            self._handle_run_request(message.run_request)
+        elif msg_type == 'load_model_cmd':
+            self._handle_load_model_cmd(message.load_model_cmd)
+        elif msg_type == 'unload_model_cmd':
+            # TODO: Implement model unloading logic
+            model_id = message.unload_model_cmd.model_id
+            logger.warning(f"Received unload_model_cmd for {model_id}, but not yet implemented.")
+            result = pb.UnloadModelResult(model_id=model_id, success=False, error_message="Model unloading not implemented")
+            self._send_message(pb.WorkerSchedulerMessage(unload_model_result=result))
+        elif msg_type == 'interrupt_run_cmd':
+            run_id = message.interrupt_run_cmd.run_id
+            self._handle_interrupt_request(run_id)
+        # Add handling for other message types if needed (e.g., config updates)
+        elif msg_type == 'deployment_model_config':
+            if self._model_manager:
+                logger.info(f"Received DeploymentModelConfig: {message.deployment_model_config.supported_model_ids}")
+                self._supported_model_ids_from_scheduler = list(message.deployment_model_config.supported_model_ids)
+                self._model_init_done_event.clear()  # Clear before starting new init
+                model_init_thread = threading.Thread(target=self._process_deployment_config_async_wrapper, daemon=True)
+                model_init_thread.start()
+            else:
+                logger.info("Received DeploymentModelConfig, but no model manager configured. Ignoring.")
+                self._model_init_done_event.set()  # Signal completion as there's nothing to do
+        elif msg_type is None:
+            logger.warning("Received empty message from scheduler.")
+        else:
+            logger.warning(f"Received unhandled message type: {msg_type}")
+
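+    # WhichOneof('msg') returns the name of whichever field of the `msg` oneof
+    # is currently set on the envelope, or None for an empty message, which is
+    # why the dispatch above is a flat string match. For illustration, using
+    # the same field names as above:
+    #
+    #     env = pb.WorkerSchedulerMessage(run_request=pb.TaskExecutionRequest(run_id="r1"))
+    #     assert env.WhichOneof('msg') == 'run_request'
+    #     assert pb.WorkerSchedulerMessage().WhichOneof('msg') is None
+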
+    def _process_deployment_config_async_wrapper(self) -> None:
+        if not self._model_manager or self._supported_model_ids_from_scheduler is None:
+            self._model_init_done_event.set()
+            return
+
+        loop = None
+        try:
+            # Get or create an event loop for this thread (a fresh daemon
+            # thread normally has none, so the except branch creates one)
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+            loop.run_until_complete(
+                self._model_manager.process_supported_models_config(
+                    self._supported_model_ids_from_scheduler,
+                    self._downloader
+                )
+            )
+            logger.info("Model configuration and downloads (if any) processed.")
+        except Exception as e:
+            logger.exception(f"Error during model_manager.process_supported_models_config: {e}")
+        finally:
+            if loop and not loop.is_running() and not loop.is_closed():  # Clean up the loop we created
+                loop.close()
+            self._model_init_done_event.set()  # Signal completion or failure
+
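+    # Since this wrapper always runs on a freshly spawned daemon thread, the
+    # get_running_loop() call is expected to raise and the new-loop branch to
+    # run. For one-shot coroutines this file elsewhere uses the simpler
+    # asyncio.run, which creates and closes a private loop per call:
+    #
+    #     def run_coro_in_worker_thread(coro):  # hypothetical helper
+    #         return asyncio.run(coro)  # only safe when the thread has no loop
+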
+    def _handle_load_model_cmd(self, cmd: LoadModelCommand) -> None:
+        model_id = cmd.model_id
+        logger.info(f"Received LoadModelCommand for: {model_id}")
+        success = False
+        error_msg = ""
+        if not self._model_manager:
+            error_msg = "LoadModelCommand: No model manager configured on worker."
+            logger.error(error_msg)
+        else:
+            try:
+                # Wait for initial model downloads if they haven't finished
+                if not self._model_init_done_event.is_set():
+                    logger.info(f"LoadModelCmd ({model_id}): Waiting for initial model setup...")
+                    # Timeout for this wait, can be adjusted
+                    if not self._model_init_done_event.wait(timeout=300.0):  # 5 minutes
+                        raise TimeoutError("Timeout waiting for model initialization before VRAM load.")
+
+                logger.info(f"Model Memory Manager attempting to load '{model_id}' into VRAM...")
+                # load_model_into_vram is async
+                success = asyncio.run(self._model_manager.load_model_into_vram(model_id))
+                if success:
+                    logger.info(f"Model '{model_id}' loaded to VRAM by Model Memory Manager.")
+                else:
+                    error_msg = f"MMM.load_model_into_vram failed for '{model_id}'."
+                    logger.error(error_msg)
+            except Exception as e:
+                error_msg = f"Exception in mmm.load_model_into_vram for '{model_id}': {e}"
+                logger.exception(error_msg)
+
+        result = pb.LoadModelResult(model_id=model_id, success=success, error_message=error_msg)
+        self._send_message(pb.WorkerSchedulerMessage(load_model_result=result))
+
+
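+    # _model_init_done_event is the readiness gate shared by config processing,
+    # VRAM loads, and task execution: cleared when a new DeploymentModelConfig
+    # arrives, set once downloads finish (or fail). The waiting pattern in
+    # isolation, for reference:
+    #
+    #     ready = threading.Event()
+    #     ...
+    #     if not ready.wait(timeout=300.0):  # returns False on timeout
+    #         raise TimeoutError("model setup did not finish in time")
+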
+    def _handle_run_request(self, request: TaskExecutionRequest) -> None:
+        """Handle a task execution request from the scheduler."""
+        run_id = request.run_id
+        function_name = request.function_name
+        input_payload = request.input_payload
+        required_model_id_for_exec = ""
+        timeout_ms = int(getattr(request, "timeout_ms", 0) or 0)
+        tenant_id = str(getattr(request, "tenant_id", "") or "") or (self.tenant_id or "")
+        user_id = str(getattr(request, "user_id", "") or "")
+
+        if request.required_models and len(request.required_models) > 0:
+            required_model_id_for_exec = request.required_models[0]
+
+        logger.info(f"Received Task request: run_id={run_id}, function={function_name}, model='{required_model_id_for_exec or 'None'}'")
+
+        func_wrapper = self._actions.get(function_name)
+        if not func_wrapper:
+            error_msg = f"Unknown function requested: {function_name}"
+            logger.error(error_msg)
+            self._send_task_result(run_id, False, None, error_msg)
+            return
+        if self.max_input_bytes > 0 and len(input_payload) > self.max_input_bytes:
+            error_msg = f"Input payload too large: {len(input_payload)} bytes (max {self.max_input_bytes})"
+            logger.error(error_msg)
+            self._send_task_result(run_id, False, None, error_msg)
+            return
+        if self._draining:
+            error_msg = "Worker is draining; refusing new tasks"
+            logger.warning(error_msg)
+            self._send_task_result(run_id, False, None, error_msg)
+            return
+
+        ctx = ActionContext(
+            run_id,
+            emitter=self._emit_progress_event,
+            tenant_id=tenant_id or None,
+            user_id=user_id or None,
+            timeout_ms=timeout_ms if timeout_ms > 0 else None,
+        )
+        # Add to active tasks *before* starting thread
+        with self._active_tasks_lock:
+            # Double-check if task is already active (race condition mitigation)
+            if run_id in self._active_tasks:
+                error_msg = f"Task with run_id {run_id} is already active (race condition?)."
+                logger.error(error_msg)
+                return  # Avoid starting duplicate thread
+            if self.max_concurrency > 0 and len(self._active_tasks) >= self.max_concurrency:
+                error_msg = f"Worker concurrency limit reached ({self.max_concurrency})."
+                logger.error(error_msg)
+                self._send_task_result(run_id, False, None, error_msg)
+                return
+            resource_req = self._discovered_resources.get(function_name)
+            func_limit = resource_req.max_concurrency if resource_req and resource_req.max_concurrency else 0
+            if func_limit > 0 and self._active_function_counts.get(function_name, 0) >= func_limit:
+                error_msg = f"Function concurrency limit reached for {function_name} ({func_limit})."
+                logger.error(error_msg)
+                self._send_task_result(run_id, False, None, error_msg)
+                return
+            self._active_tasks[run_id] = ctx
+            if func_limit > 0:
+                self._active_function_counts[function_name] = self._active_function_counts.get(function_name, 0) + 1
+
+        # Execute function in a separate thread to avoid blocking the receive loop
+        thread = threading.Thread(
+            target=self._execute_function,
+            args=(ctx, function_name, func_wrapper, input_payload, required_model_id_for_exec),
+            daemon=True,
+        )
+        thread.start()
+
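+    # Admission control above runs cheapest-first: unknown function, payload
+    # size, draining flag, then, under the lock, duplicate run_id, worker-wide
+    # concurrency, and per-function concurrency; only after every check passes
+    # is the task registered and a thread started. Condensed into a single
+    # predicate (a sketch; this helper does not exist in the package):
+    #
+    #     def _would_admit(self, function_name: str, n_bytes: int) -> bool:
+    #         if function_name not in self._actions or self._draining:
+    #             return False
+    #         if self.max_input_bytes > 0 and n_bytes > self.max_input_bytes:
+    #             return False
+    #         with self._active_tasks_lock:
+    #             return not (self.max_concurrency > 0
+    #                         and len(self._active_tasks) >= self.max_concurrency)
+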
+    def _handle_interrupt_request(self, run_id: str) -> None:
+        """Handle a request to interrupt/cancel a running task."""
+        logger.info(f"Received interrupt request for run_id={run_id}")
+        with self._active_tasks_lock:
+            ctx = self._active_tasks.get(run_id)
+            if ctx:
+                ctx.cancel()  # Set internal flag and event
+            else:
+                logger.warning(f"Could not interrupt task {run_id}: Not found in active tasks.")
+
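+    # Cancellation is cooperative: ctx.cancel() only flips a flag and an event,
+    # and a long-running action is expected to poll ctx.is_canceled() between
+    # units of work. Sketch of a well-behaved action body (the step helpers are
+    # hypothetical):
+    #
+    #     def generate(ctx: ActionContext, pipeline, payload: bytes) -> bytes:
+    #         for step in range(num_steps):           # num_steps hypothetical
+    #             if ctx.is_canceled():
+    #                 raise InterruptedError(f"canceled at step {step}")
+    #             run_one_step(pipeline)              # hypothetical
+    #         return encode_result()                  # hypothetical
+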
+    def _execute_function(
+        self,
+        ctx: ActionContext,
+        function_name: str,
+        func_to_execute: Callable[[ActionContext, Optional[Any], bytes], bytes],
+        input_payload: bytes,
+        required_model_id: str,
+    ) -> None:
+        """Execute the registered function and send the result/error back."""
+        run_id = ctx.run_id
+        output_payload: Optional[bytes] = None
+        error_message: str = ""
+        success = False
+
+        # Determine if this function requires GPU and manage worker's GPU state
+        func_requires_gpu = False
+        func_expects_pipeline = False  # Default when no resource requirements were discovered
+        resource_req = self._discovered_resources.get(function_name)
+        if resource_req:
+            func_requires_gpu = resource_req.requires_gpu
+            func_expects_pipeline = resource_req.expects_pipeline_arg
+
+        # Variable to track if this specific thread execution set the GPU busy
+        this_thread_set_gpu_busy = False
+        if func_requires_gpu:
+            with self._gpu_busy_lock:  # Lock to check and set self._is_gpu_busy atomically
+                if not self._is_gpu_busy:
+                    self._is_gpu_busy = True
+                    this_thread_set_gpu_busy = True
+                    logger.info(f"Worker GPU marked as BUSY by task {run_id} ({function_name}).")
+                else:
+                    logger.warning(f"Task {run_id} ({function_name}) requires GPU, but worker GPU was already marked busy. Proceeding...")
+
+        active_pipeline_instance = None  # To hold the pipeline for the user function
+        try:
+            if ctx.is_canceled():
+                raise InterruptedError("Task cancelled before execution")
+
+            if func_expects_pipeline:
+                if not required_model_id and resource_req and resource_req.model_name:
+                    required_model_id = str(resource_req.model_name)
+                if not required_model_id:
+                    raise ValueError(f"Function '{function_name}' expects a pipeline argument, but no model ID was provided.")
+
+                if not self._model_manager:
+                    raise RuntimeError(f"Function '{function_name}' expects a pipeline argument, but no model manager configured on worker.")
+
+                if not self._model_init_done_event.is_set():
+                    logger.info(f"Task {run_id} ({function_name}) waiting for initial model setup...")
+                    if not self._model_init_done_event.wait(timeout=300.0):  # 5 min timeout
+                        raise TimeoutError(f"Timeout waiting for model initialization for task {run_id}")
+                    logger.info(f"Initial model setup complete. Proceeding for task {run_id}.")
+
+                logger.info(f"Task {run_id} ({function_name}) getting active pipeline for model '{required_model_id}'...")
+                # get_active_pipeline is async
+                active_pipeline_instance = asyncio.run(self._model_manager.get_active_pipeline(required_model_id))
+                if not active_pipeline_instance:
+                    raise RuntimeError(f"ModelManager failed to provide active pipeline for '{required_model_id}' for task {run_id}.")
+
+                logger.info(f"Task {run_id} ({function_name}) obtained pipeline for model '{required_model_id}'.")
+
+            # Execute the function wrapper (which handles deserialization/serialization)
+            output_payload = func_to_execute(ctx, active_pipeline_instance, input_payload)
+            # Check for cancellation *during* execution (func should check ctx.is_canceled)
+            if ctx.is_canceled():
+                raise InterruptedError("Task was cancelled during execution")
+
+            if output_payload is not None and self.max_output_bytes > 0:
+                if len(output_payload) > self.max_output_bytes:
+                    raise ValueError(f"Output payload too large: {len(output_payload)} bytes (max {self.max_output_bytes})")
+
+            success = True
+            logger.info(f"Task {run_id} completed successfully.")
+
+        except InterruptedError as e:
+            error_message = self._format_error(str(e) or "Task was canceled", retryable=False)
+            logger.warning(f"Task {run_id} was canceled: {error_message}")
+            success = False  # Explicitly set success to False on cancellation
+        except RetryableError as e:
+            error_message = self._format_error(f"{type(e).__name__}: {str(e)}", retryable=True)
+            logger.error(f"Task {run_id} ({function_name}) retryable failure: {error_message}")
+            success = False
+        except FatalError as e:
+            error_message = self._format_error(f"{type(e).__name__}: {str(e)}", retryable=False)
+            logger.error(f"Task {run_id} ({function_name}) fatal failure: {error_message}")
+            success = False
+        except (ValueError, RuntimeError, TimeoutError) as ve_rte_to:  # Catch specific errors we raise
+            retryable = isinstance(ve_rte_to, TimeoutError)
+            error_message = self._format_error(f"{type(ve_rte_to).__name__}: {str(ve_rte_to)}", retryable=retryable)
+            logger.error(f"Task {run_id} ({function_name}) failed pre-execution or during model acquisition: {error_message}")
+            success = False
+        except Exception as e:
+            error_message = self._format_error(f"{type(e).__name__}: {str(e)}", retryable=False)
+            logger.exception(f"Error executing function for run_id={run_id}: {error_message}")
+            success = False
+        finally:
+            # Release the GPU if this thread set it busy
+            if this_thread_set_gpu_busy:
+                with self._gpu_busy_lock:  # Lock to set self._is_gpu_busy
+                    self._is_gpu_busy = False
+                    logger.info(f"Worker GPU marked as NOT BUSY by task {run_id} ({function_name}).")
+
+            # Always send a result back, regardless of success, failure, or cancellation
+            self._send_task_result(run_id, success, output_payload, error_message)
+            # Remove from active tasks *after* sending result
+            with self._active_tasks_lock:
+                if run_id in self._active_tasks:
+                    del self._active_tasks[run_id]
+                    resource_req = self._discovered_resources.get(function_name)
+                    func_limit = resource_req.max_concurrency if resource_req and resource_req.max_concurrency else 0
+                    if func_limit > 0:
+                        current = self._active_function_counts.get(function_name, 0) - 1
+                        if current <= 0:
+                            self._active_function_counts.pop(function_name, None)
+                        else:
+                            self._active_function_counts[function_name] = current
+                # else: # Might have been removed by stop() already
+                #     logger.warning(f"Task {run_id} not found in active tasks during cleanup.")
+
+
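+    # The except ladder above encodes the worker's error taxonomy:
+    # RetryableError is reported with retryable=True (the scheduler may
+    # reschedule), FatalError and unexpected exceptions with retryable=False,
+    # and TimeoutError is treated as retryable. An action can opt in
+    # explicitly; assuming these types come from this package's errors module:
+    #
+    #     from gen_worker.errors import RetryableError, FatalError
+    #
+    #     def fetch(ctx, pipeline, payload: bytes) -> bytes:
+    #         try:
+    #             return download(payload)            # hypothetical helper
+    #         except ConnectionError as e:
+    #             raise RetryableError(f"transient network failure: {e}")
+    #         except KeyError as e:
+    #             raise FatalError(f"malformed request: {e}")
+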
+    def _send_task_result(self, run_id: str, success: bool, output_payload: Optional[bytes], error_message: str) -> None:
+        """Send a task execution result back to the scheduler via the queue."""
+        try:
+            result = pb.TaskExecutionResult(
+                run_id=run_id,
+                success=success,
+                output_payload=(output_payload or b'') if success else b'',  # Default to b'' if None
+                error_message=error_message if not success else ""
+            )
+            msg = pb.WorkerSchedulerMessage(run_result=result)
+            self._send_message(msg)
+            logger.debug(f"Queued task result for run_id={run_id}, success={success}")
+        except Exception as e:
+            # This shouldn't generally fail unless message creation has issues
+            logger.error(f"Failed to create or queue task result for run_id={run_id}: {e}")
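+
+    # Invariant of the result envelope built above: exactly one of
+    # output_payload / error_message is meaningful, selected by `success`, so
+    # the scheduler can branch on the flag alone. A receiving side might look
+    # like this sketch (handler and store names hypothetical):
+    #
+    #     def on_run_result(res) -> None:
+    #         if res.success:
+    #             store_output(res.run_id, res.output_payload)  # hypothetical
+    #         else:
+    #             mark_failed(res.run_id, res.error_message)    # hypothetical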