avtomatika-worker 1.0b3-py3-none-any.whl → 1.0b4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika_worker/__init__.py +1 -1
- avtomatika_worker/config.py +6 -0
- avtomatika_worker/s3.py +76 -48
- avtomatika_worker/task_files.py +60 -2
- avtomatika_worker/types.py +9 -4
- avtomatika_worker/worker.py +333 -155
- {avtomatika_worker-1.0b3.dist-info → avtomatika_worker-1.0b4.dist-info}/METADATA +76 -9
- avtomatika_worker-1.0b4.dist-info/RECORD +12 -0
- {avtomatika_worker-1.0b3.dist-info → avtomatika_worker-1.0b4.dist-info}/WHEEL +1 -1
- {avtomatika_worker-1.0b3.dist-info → avtomatika_worker-1.0b4.dist-info}/licenses/LICENSE +1 -1
- avtomatika_worker/client.py +0 -93
- avtomatika_worker/constants.py +0 -22
- avtomatika_worker-1.0b3.dist-info/RECORD +0 -14
- {avtomatika_worker-1.0b3.dist-info → avtomatika_worker-1.0b4.dist-info}/top_level.txt +0 -0
avtomatika_worker/worker.py
CHANGED
@@ -1,26 +1,46 @@
-from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
+from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep, to_thread
 from dataclasses import is_dataclass
-from inspect import Parameter, signature
-from json import JSONDecodeError
+from inspect import Parameter, iscoroutinefunction, signature
 from logging import getLogger
 from os.path import join
 from typing import Any, Callable

-from aiohttp import ClientSession,
-
-from .
-from .
-from .constants import (
+from aiohttp import ClientSession, TCPConnector, web
+from rxon import Transport, create_transport
+from rxon.blob import calculate_config_hash
+from rxon.constants import (
     COMMAND_CANCEL_TASK,
+    ERROR_CODE_INTEGRITY_MISMATCH,
     ERROR_CODE_INVALID_INPUT,
     ERROR_CODE_PERMANENT,
     ERROR_CODE_TRANSIENT,
+    MSG_TYPE_PROGRESS,
     TASK_STATUS_CANCELLED,
     TASK_STATUS_FAILURE,
+    TASK_STATUS_SUCCESS,
+)
+from rxon.exceptions import RxonError
+from rxon.models import (
+    FileMetadata,
+    GPUInfo,
+    Heartbeat,
+    InstalledModel,
+    ProgressUpdatePayload,
+    Resources,
+    TaskError,
+    TaskPayload,
+    TaskResult,
+    WorkerCapabilities,
+    WorkerRegistration,
 )
+from rxon.security import create_client_ssl_context
+from rxon.utils import to_dict
+from rxon.validators import validate_identifier
+
+from .config import WorkerConfig
 from .s3 import S3Manager
 from .task_files import TaskFiles
-from .types import ParamValidationError
+from .types import CapacityChecker, Middleware, ParamValidationError

 try:
     from pydantic import BaseModel, ValidationError
@@ -37,7 +57,7 @@ class Worker:
     """The main class for creating and running a worker.
     Implements a hybrid interaction model with the Orchestrator:
     - PULL model for fetching tasks.
-    -
+    - Transport layer for real-time commands (cancellation) and sending progress.
     """

     def __init__(
@@ -48,6 +68,8 @@ class Worker:
         http_session: ClientSession | None = None,
         skill_dependencies: dict[str, list[str]] | None = None,
         config: WorkerConfig | None = None,
+        capacity_checker: CapacityChecker | None = None,
+        clients: list[tuple[dict[str, Any], Transport]] | None = None,
     ):
         self._config = config or WorkerConfig()
         self._s3_manager = S3Manager(self._config)
@@ -58,6 +80,8 @@ class Worker:
         self._task_type_limits = task_type_limits or {}
         self._task_handlers: dict[str, dict[str, Any]] = {}
         self._skill_dependencies = skill_dependencies or {}
+        self._middlewares: list[Middleware] = []
+        self._capacity_checker = capacity_checker

         # Worker state
         self._current_load = 0
@@ -66,10 +90,10 @@ class Worker:
         self._active_tasks: dict[str, Task] = {}
         self._http_session = http_session
         self._session_is_managed_externally = http_session is not None
-        self._ws_connection: ClientWebSocketResponse | None = None
         self._shutdown_event = Event()
         self._registered_event = Event()
         self._debounce_task: Task | None = None
+        self._ssl_context = None

         # --- Weighted Round-Robin State ---
         self._total_orchestrator_weight = 0
@@ -77,23 +101,31 @@ class Worker:
         for o in self._config.ORCHESTRATORS:
             o["current_weight"] = 0
             self._total_orchestrator_weight += o.get("weight", 1)
-
-        self._clients
-        if self._http_session:
+        self._clients = clients or []
+        if not self._clients and self._http_session:
             self._init_clients()

+    def add_middleware(self, middleware: Middleware) -> None:
+        """Adds a middleware to the execution chain."""
+        self._middlewares.append(middleware)
+
     def _init_clients(self):
-        """Initializes
-        if
-
+        """Initializes Transport instances for each configured orchestrator."""
+        # Even if we don't have an external session, we might create transports
+        # that will create their own sessions. But if we want to share one, we need it here.
+        session_to_use = self._http_session if self._session_is_managed_externally else None
+
         self._clients = [
             (
                 o,
-
-
-                    base_url=o["url"],
+                create_transport(
+                    url=o["url"],
                     worker_id=self._config.WORKER_ID,
                     token=o.get("token", self._config.WORKER_TOKEN),
+                    ssl_context=self._ssl_context,
+                    session=session_to_use,
+                    result_retries=self._config.RESULT_MAX_RETRIES,
+                    result_retry_delay=self._config.RESULT_RETRY_INITIAL_DELAY,
                 ),
             )
             for o in self._config.ORCHESTRATORS
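Usage sketch (not part of the package diff): the new add_middleware hook wraps task execution; a middleware is an async callable that receives the middleware context dict and the next handler in the chain, matching how _process_task invokes it further down. The logger name, middleware body, and the pre-existing `worker` instance below are illustrative assumptions.

    from logging import getLogger

    log = getLogger("my_worker")  # hypothetical logger name

    async def logging_middleware(context, call_next):
        # context carries task_id, job_id, task_name, params and handler_kwargs
        log.info("starting %s (%s)", context["task_name"], context["task_id"])
        result = await call_next()
        log.info("finished %s", context["task_id"])
        return result

    worker.add_middleware(logging_middleware)  # `worker` is an existing Worker instance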
@@ -114,6 +146,9 @@ class Worker:

     def task(self, name: str, task_type: str | None = None) -> Callable:
         """Decorator to register a function as a task handler."""
+        validate_identifier(name, "task name")
+        if task_type:
+            validate_identifier(task_type, "task type")

         def decorator(func: Callable) -> Callable:
             logger.info(f"Registering task: '{name}' (type: {task_type or 'N/A'})")
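Usage sketch (not part of the diff): registering a handler with the decorator. The task name, handler body, and return payload are illustrative; as _process_task below shows, the worker calls handlers with task_id, job_id, send_progress and other keyword arguments, so a handler should accept them or **kwargs.

    worker = Worker()  # assuming any remaining constructor arguments keep their defaults

    @worker.task("resize_image", task_type="cpu")
    async def resize_image(params: dict, task_id=None, job_id=None, send_progress=None, **kwargs):
        if send_progress:
            await send_progress(task_id, job_id, 0.5, "halfway there")
        return {"data": {"resized": True}}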
@@ -164,10 +199,13 @@ class Worker:
             if is_available:
                 supported_tasks.append(name)

+        if self._capacity_checker:
+            supported_tasks = [task for task in supported_tasks if self._capacity_checker(task)]
+
         status = "idle" if supported_tasks else "busy"
         return {"status": status, "supported_tasks": supported_tasks}

-    def _get_next_client(self) ->
+    def _get_next_client(self) -> Transport | None:
         """
         Selects the next orchestrator client using a smooth weighted round-robin algorithm.
         """
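Usage sketch (not part of the diff): a capacity checker receives a task name and returns whether the worker should currently advertise it; tasks failing the check are filtered out of supported_tasks before the status is reported. The disk-space rule and task name below are illustrative assumptions.

    from shutil import disk_usage

    def has_scratch_space(task_name: str) -> bool:
        # Only advertise the disk-heavy task while more than 10 GB of scratch space is free.
        if task_name == "transcode_video":
            return disk_usage("/tmp").free > 10 * 1024**3
        return True

    worker = Worker(capacity_checker=has_scratch_space)  # other arguments assumed default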
@@ -206,27 +244,29 @@ class Worker:
         # Schedule the new debounced call.
         self._debounce_task = create_task(self._debounced_heartbeat_sender())

-    async def _poll_for_tasks(self, client:
+    async def _poll_for_tasks(self, client: Transport):
         """Polls a specific Orchestrator for new tasks."""
-
-        if
-
-
-                self._current_load += 1
-                if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
-                    task_type_for_limit := task_handler_info.get("type")
-                ):
-                    self._current_load_by_type[task_type_for_limit] += 1
-                self._schedule_heartbeat_debounce()
+        current_state = self._get_current_state()
+        if current_state["status"] == "busy":
+            return

-
-
-
-
-
-
-
-
+        try:
+            task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+            if task_data:
+                task_data_dict = to_dict(task_data)
+                task_data_dict["client"] = client
+
+                self._current_load += 1
+                if (task_handler_info := self._task_handlers.get(task_data.type)) and (
+                    task_type_for_limit := task_handler_info.get("type")
+                ):
+                    self._current_load_by_type[task_type_for_limit] += 1
+                self._schedule_heartbeat_debounce()
+
+                task = create_task(self._process_task(task_data_dict))
+                self._active_tasks[task_data.task_id] = task
+        except RxonError as e:
+            logger.error(f"Error polling tasks: {e}")

     async def _start_polling(self):
         """The main loop for polling tasks."""
@@ -253,7 +293,6 @@ class Worker:
     def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
         """
         Inspects the handler's signature to validate and instantiate params.
-        Supports dict, dataclasses, and optional pydantic models.
         """
         sig = signature(handler)
         params_annotation = sig.parameters.get("params").annotation
@@ -271,11 +310,11 @@ class Worker:
         # Dataclass Instantiation
         if isinstance(params_annotation, type) and is_dataclass(params_annotation):
             try:
-                # Filter unknown fields
+                # Filter unknown fields
                 known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
                 filtered_params = {k: v for k, v in params.items() if k in known_fields}

-                #
+                # Check required fields
                 required_fields = [
                     f.name
                     for f in params_annotation.__dataclass_fields__.values()
@@ -287,7 +326,6 @@ class Worker:

                 return params_annotation(**filtered_params)
             except (TypeError, ValueError) as e:
-                # TypeError for missing/extra args, ValueError from __post_init__
                 raise ParamValidationError(str(e)) from e

         return params
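Usage sketch (not part of the diff): because _prepare_task_params instantiates dataclass-annotated params, a handler can declare a typed payload; unknown keys are dropped, and a missing required field surfaces as ParamValidationError (reported with ERROR_CODE_INVALID_INPUT). The dataclass and task name below are illustrative.

    from dataclasses import dataclass

    @dataclass
    class ExtractParams:
        source_key: str           # required: omitting it raises ParamValidationError
        sample_rate: int = 16000  # optional, has a default

    @worker.task("extract_audio")
    async def extract_audio(params: ExtractParams, **kwargs):
        return {"data": {"sample_rate": params.sample_rate}}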
@@ -296,8 +334,7 @@ class Worker:
         """Injects dependencies based on type hints."""
         deps = {}
         task_dir = join(self._config.TASK_FILES_DIR, task_id)
-
-        task_files = TaskFiles(task_dir)
+        task_files = TaskFiles(task_dir, s3_manager=self._s3_manager)

         sig = signature(handler)
         for name, param in sig.parameters.items():
@@ -306,66 +343,150 @@ class Worker:

         return deps

-    async def _process_task(self,
+    async def _process_task(self, task_data_raw: dict[str, Any]):
         """Executes the task logic."""
-
-
+        client: Transport = task_data_raw.pop("client")
+
+        # Parse incoming task data using protocol model
+        if "params_metadata" in task_data_raw and task_data_raw["params_metadata"]:
+            task_data_raw["params_metadata"] = {
+                k: self._from_dict(FileMetadata, v) for k, v in task_data_raw["params_metadata"].items()
+            }
+
+        task_payload = self._from_dict(TaskPayload, task_data_raw)
+        task_id, job_id, task_name = task_payload.task_id, task_payload.job_id, task_payload.type
+        params = task_payload.params

-        result: dict[str, Any] = {}
         handler_data = self._task_handlers.get(task_name)
         task_type_for_limit = handler_data.get("type") if handler_data else None

-
+        result_obj: TaskResult | None = None
+
+        # Create a progress sender wrapper attached to this specific client
+        async def send_progress_wrapper(task_id_arg, job_id_arg, progress, message=""):
+            payload = ProgressUpdatePayload(
+                event=MSG_TYPE_PROGRESS, task_id=task_id_arg, job_id=job_id_arg, progress=progress, message=message
+            )
+            await client.send_progress(payload)

         try:
             if not handler_data:
                 message = f"Unsupported task: {task_name}"
                 logger.warning(message)
-
-
-
-
+                error = TaskError(code=ERROR_CODE_PERMANENT, message=message)
+                result_obj = TaskResult(
+                    job_id=job_id,
+                    task_id=task_id,
+                    worker_id=self._config.WORKER_ID,
+                    status=TASK_STATUS_FAILURE,
+                    error=error,
                 )
-
-
+            else:
+                # Download files
+                params = await self._s3_manager.process_params(params, task_id, metadata=task_payload.params_metadata)
+                validated_params = self._prepare_task_params(handler_data["func"], params)
+                deps = self._prepare_dependencies(handler_data["func"], task_id)
+
+                handler_kwargs = {
+                    "params": validated_params,
+                    "task_id": task_id,
+                    "job_id": job_id,
+                    "tracing_context": task_payload.tracing_context,
+                    "priority": task_data_raw.get("priority", 0),
+                    "send_progress": send_progress_wrapper,
+                    "add_to_hot_cache": self.add_to_hot_cache,
+                    "remove_from_hot_cache": self.remove_from_hot_cache,
+                    **deps,
+                }

-
-
-
+                middleware_context = {
+                    "task_id": task_id,
+                    "job_id": job_id,
+                    "task_name": task_name,
+                    "params": validated_params,
+                    "handler_kwargs": handler_kwargs,
+                }
+
+                async def _execution_logic():
+                    handler = handler_data["func"]
+                    final_kwargs = middleware_context["handler_kwargs"]
+
+                    if iscoroutinefunction(handler):
+                        return await handler(**final_kwargs)
+                    else:
+                        return await to_thread(handler, **final_kwargs)
+
+                handler_chain = _execution_logic
+                for middleware in reversed(self._middlewares):
+
+                    def make_wrapper(mw: Middleware, next_handler: Callable) -> Callable:
+                        async def wrapper():
+                            return await mw(middleware_context, next_handler)
+
+                        return wrapper
+
+                    handler_chain = make_wrapper(middleware, handler_chain)
+
+                handler_result = await handler_chain()
+
+                updated_data, metadata_map = await self._s3_manager.process_result(handler_result.get("data", {}))
+
+                result_obj = TaskResult(
+                    job_id=job_id,
+                    task_id=task_id,
+                    worker_id=self._config.WORKER_ID,
+                    status=handler_result.get("status", TASK_STATUS_SUCCESS),
+                    data=updated_data,
+                    error=TaskError(**handler_result["error"]) if "error" in handler_result else None,
+                    data_metadata=metadata_map if metadata_map else None,
+                )

-            result = await handler_data["func"](
-                validated_params,
-                task_id=task_id,
-                job_id=job_id,
-                priority=task_data.get("priority", 0),
-                send_progress=self.send_progress,
-                add_to_hot_cache=self.add_to_hot_cache,
-                remove_from_hot_cache=self.remove_from_hot_cache,
-                **deps,
-            )
-            result = await self._s3_manager.process_result(result)
         except ParamValidationError as e:
             logger.error(f"Task {task_id} failed validation: {e}")
-
+            error = TaskError(code=ERROR_CODE_INVALID_INPUT, message=str(e))
+            result_obj = TaskResult(
+                job_id=job_id,
+                task_id=task_id,
+                worker_id=self._config.WORKER_ID,
+                status=TASK_STATUS_FAILURE,
+                error=error,
+            )
         except CancelledError:
             logger.info(f"Task {task_id} was cancelled.")
-
-
+            result_obj = TaskResult(
+                job_id=job_id, task_id=task_id, worker_id=self._config.WORKER_ID, status=TASK_STATUS_CANCELLED
+            )
             raise
+        except ValueError as e:
+            logger.error(f"Data integrity or validation error for task {task_id}: {e}")
+            error = TaskError(code=ERROR_CODE_INTEGRITY_MISMATCH, message=str(e))
+            result_obj = TaskResult(
+                job_id=job_id,
+                task_id=task_id,
+                worker_id=self._config.WORKER_ID,
+                status=TASK_STATUS_FAILURE,
+                error=error,
+            )
         except Exception as e:
             logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-
+            error = TaskError(code=ERROR_CODE_TRANSIENT, message=str(e))
+            result_obj = TaskResult(
+                job_id=job_id,
+                task_id=task_id,
+                worker_id=self._config.WORKER_ID,
+                status=TASK_STATUS_FAILURE,
+                error=error,
+            )
         finally:
-            # Cleanup task workspace
             await self._s3_manager.cleanup(task_id)

-            if
-
-
-
-
-            self._active_tasks.pop(task_id, None)
+            if result_obj:
+                try:
+                    await client.send_result(result_obj)
+                except RxonError as e:
+                    logger.error(f"Failed to send task result: {e}")

+            self._active_tasks.pop(task_id, None)
             self._current_load -= 1
             if task_type_for_limit:
                 self._current_load_by_type[task_type_for_limit] -= 1
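Usage sketch (not part of the diff): the execution path above now runs plain synchronous handlers in a thread via to_thread and reads the handler's return value as a mapping with optional data, status, and error keys; data is passed through S3Manager.process_result before being packed into TaskResult. The blocking handler below is illustrative.

    from time import sleep as blocking_sleep

    @worker.task("render_report", task_type="cpu")
    def render_report(params: dict, **kwargs):
        blocking_sleep(2)  # blocking work is fine: sync handlers run in a worker thread
        # "status" defaults to TASK_STATUS_SUCCESS when omitted
        return {"data": {"pages": 3}}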
@@ -383,62 +504,127 @@ class Worker:
             await self._send_heartbeats_to_all()
             await sleep(self._config.HEARTBEAT_INTERVAL)

+    @staticmethod
+    def _from_dict(cls: type, data: dict[str, Any]) -> Any:
+        """Safely instantiates a NamedTuple from a dict, ignoring unknown fields."""
+        if not data:
+            return None
+        fields = cls._fields
+        filtered_data = {k: v for k, v in data.items() if k in fields}
+        return cls(**filtered_data)
+
     async def _register_with_all_orchestrators(self):
         """Registers the worker with all orchestrators."""
         state = self._get_current_state()
-
-
-
-
-
-
-
-
-
-
-
-
+
+        gpu_info = None
+        if self._config.RESOURCES.get("gpu_info"):
+            gpu_info = GPUInfo(**self._config.RESOURCES["gpu_info"])
+
+        resources = Resources(
+            max_concurrent_tasks=self._config.MAX_CONCURRENT_TASKS,
+            cpu_cores=self._config.RESOURCES["cpu_cores"],
+            gpu_info=gpu_info,
+        )
+
+        s3_hash = calculate_config_hash(
+            self._config.S3_ENDPOINT_URL,
+            self._config.S3_ACCESS_KEY,
+            self._config.S3_DEFAULT_BUCKET,
+        )
+
+        registration = WorkerRegistration(
+            worker_id=self._config.WORKER_ID,
+            worker_type=self._config.WORKER_TYPE,
+            supported_tasks=state["supported_tasks"],
+            resources=resources,
+            installed_software=self._config.INSTALLED_SOFTWARE,
+            installed_models=[InstalledModel(**m) for m in self._config.INSTALLED_MODELS],
+            capabilities=WorkerCapabilities(
+                hostname=self._config.HOSTNAME,
+                ip_address=self._config.IP_ADDRESS,
+                cost_per_skill=self._config.COST_PER_SKILL,
+                s3_config_hash=s3_hash,
+            ),
+        )
+
+        await gather(*[self._safe_register(client, registration) for _, client in self._clients])
+
+    async def _safe_register(self, client: Transport, registration: WorkerRegistration):
+        try:
+            await client.register(registration)
+        except RxonError as e:
+            logger.error(f"Registration failed for {client}: {e}")

     async def _send_heartbeats_to_all(self):
         """Sends heartbeat messages to all orchestrators."""
         state = self._get_current_state()
-        payload = {
-            "load": self._current_load,
-            "status": state["status"],
-            "supported_tasks": state["supported_tasks"],
-            "hot_cache": list(self._hot_cache),
-        }

+        hot_skills = None
         if self._skill_dependencies:
-            payload["skill_dependencies"] = self._skill_dependencies
             hot_skills = [
                 skill for skill, models in self._skill_dependencies.items() if set(models).issubset(self._hot_cache)
             ]
-            if hot_skills:
-                payload["hot_skills"] = hot_skills

-
+        heartbeat = Heartbeat(
+            worker_id=self._config.WORKER_ID,
+            status=state["status"],
+            load=float(self._current_load),
+            current_tasks=list(self._active_tasks.keys()),
+            supported_tasks=state["supported_tasks"],
+            hot_cache=list(self._hot_cache),
+            skill_dependencies=self._skill_dependencies or None,
+            hot_skills=hot_skills or None,
+        )
+
+        await gather(*[self._safe_heartbeat(client, heartbeat) for _, client in self._clients])
+
+    async def _safe_heartbeat(self, client: Transport, heartbeat: Heartbeat):
+        try:
+            await client.send_heartbeat(heartbeat)
+        except RxonError as e:
+            logger.warning(f"Heartbeat failed for {client}: {e}")

     async def main(self):
         """The main asynchronous function."""
         self._config.validate()
-        self._validate_task_types()
+        self._validate_task_types()
         if not self._http_session:
-            self.
-
+            if self._config.TLS_CA_PATH or (self._config.TLS_CERT_PATH and self._config.TLS_KEY_PATH):
+                logger.info("Initializing SSL context for mTLS.")
+                self._ssl_context = create_client_ssl_context(
+                    ca_path=self._config.TLS_CA_PATH,
+                    cert_path=self._config.TLS_CERT_PATH,
+                    key_path=self._config.TLS_KEY_PATH,
+                )
+            connector = TCPConnector(ssl=self._ssl_context) if self._ssl_context else None
+            self._http_session = ClientSession(connector=connector)
+        if not self._clients:
+            self._init_clients()
+
+        # Connect transports
+        await gather(*[client.connect() for _, client in self._clients])

         comm_task = create_task(self._manage_orchestrator_communications())

+        token_rotation_task = None
+        if self._ssl_context:
+            token_rotation_task = create_task(self._manage_token_rotation())
+
         polling_task = create_task(self._start_polling())
         await self._shutdown_event.wait()

         for task in [comm_task, polling_task]:
             task.cancel()
+        if token_rotation_task:
+            token_rotation_task.cancel()
+
         if self._active_tasks:
             await gather(*self._active_tasks.values(), return_exceptions=True)

-
-
+        # Close transports
+        await gather(*[client.close() for _, client in self._clients])
+
         if self._http_session and not self._http_session.closed and not self._session_is_managed_externally:
             await self._http_session.close()
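Configuration sketch (not part of the diff): enabling the new mTLS path, assuming WorkerConfig exposes TLS_CA_PATH, TLS_CERT_PATH, and TLS_KEY_PATH as settable attributes (for example populated from environment variables) and that main() is driven by an asyncio entry point; the file paths are placeholders.

    from asyncio import run

    config = WorkerConfig()
    config.TLS_CA_PATH = "/etc/avtomatika/ca.pem"       # placeholder paths
    config.TLS_CERT_PATH = "/etc/avtomatika/worker.crt"
    config.TLS_KEY_PATH = "/etc/avtomatika/worker.key"

    worker = Worker(config=config)  # other constructor arguments assumed default
    run(worker.main())              # main() is the async entry point shown above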
@@ -449,6 +635,26 @@ class Worker:
         except KeyboardInterrupt:
             self._shutdown_event.set()

+    async def _manage_token_rotation(self):
+        """Periodically refreshes auth tokens for all clients."""
+        await sleep(5)
+
+        while not self._shutdown_event.is_set():
+            min_expires_in = 3600
+
+            for _, client in self._clients:
+                try:
+                    token_resp = await client.refresh_token()
+                    if token_resp:
+                        self._config.WORKER_TOKEN = token_resp.access_token
+                        min_expires_in = min(min_expires_in, token_resp.expires_in)
+                except Exception as e:
+                    logger.error(f"Error in token rotation loop: {e}")
+
+            refresh_delay = max(60, min_expires_in * 0.8)
+            logger.debug(f"Next token refresh scheduled in {refresh_delay:.1f}s")
+            await sleep(refresh_delay)
+
     async def _run_health_check_server(self):
         app = web.Application()
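For example, if the shortest-lived token reports expires_in of 600 seconds, the next refresh happens after max(60, 600 * 0.8) = 480 seconds; since min_expires_in starts at 3600, the delay never exceeds 2880 seconds.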
@@ -473,54 +679,26 @@ class Worker:
         self._shutdown_event.set()

     async def _start_websocket_manager(self):
-        """Manages the
-
-
-
-            try:
-                ws = await client.connect_websocket()
-                if ws:
-                    self._ws_connection = ws
-                    await self._listen_for_commands()
-            finally:
-                self._ws_connection = None
-                await sleep(5) # Reconnection delay
-            if not self._clients:
-                await sleep(5)
+        """Manages the command listeners."""
+        listeners = []
+        for _, client in self._clients:
+            listeners.append(create_task(self._listen_to_single_transport(client)))

-
-        """Listens for and processes commands from the orchestrator via WebSocket."""
-        if not self._ws_connection:
-            return
+        await self._shutdown_event.wait()

-
-
-            if msg.type == WSMsgType.TEXT:
-                try:
-                    command = msg.json()
-                    if command.get("type") == COMMAND_CANCEL_TASK:
-                        task_id = command.get("task_id")
-                        if task_id in self._active_tasks:
-                            self._active_tasks[task_id].cancel()
-                            logger.info(f"Cancelled task {task_id} by orchestrator command.")
-                except JSONDecodeError:
-                    logger.warning(f"Received invalid JSON over WebSocket: {msg.data}")
-            elif msg.type == WSMsgType.ERROR:
-                break
-        except Exception as e:
-            logger.error(f"Error in WebSocket listener: {e}")
+        for listener in listeners:
+            listener.cancel()

-    async def
-
-        if self._ws_connection and not self._ws_connection.closed:
+    async def _listen_to_single_transport(self, client: Transport):
+        while not self._shutdown_event.is_set():
             try:
-
-
-
-
-
-
-
-                await self._ws_connection.send_json(payload)
+                async for command in client.listen_for_commands():
+                    if command.command == COMMAND_CANCEL_TASK:
+                        task_id = command.task_id
+                        job_id = command.job_id
+                        if task_id in self._active_tasks:
+                            self._active_tasks[task_id].cancel()
+                            logger.info(f"Cancelled task {task_id} (Job: {job_id or 'N/A'}) by orchestrator command.")
             except Exception as e:
-                logger.
+                logger.error(f"Error in command listener: {e}")
+                await sleep(5)