avtomatika-worker 1.0b2-py3-none-any.whl → 1.0b4-py3-none-any.whl
- avtomatika_worker/__init__.py +1 -1
- avtomatika_worker/config.py +20 -1
- avtomatika_worker/py.typed +0 -0
- avtomatika_worker/s3.py +171 -74
- avtomatika_worker/task_files.py +60 -2
- avtomatika_worker/types.py +23 -5
- avtomatika_worker/worker.py +372 -209
- {avtomatika_worker-1.0b2.dist-info → avtomatika_worker-1.0b4.dist-info}/METADATA +80 -11
- avtomatika_worker-1.0b4.dist-info/RECORD +12 -0
- {avtomatika_worker-1.0b2.dist-info → avtomatika_worker-1.0b4.dist-info}/WHEEL +1 -1
- {avtomatika_worker-1.0b2.dist-info → avtomatika_worker-1.0b4.dist-info}/licenses/LICENSE +1 -1
- avtomatika_worker-1.0b2.dist-info/RECORD +0 -11
- {avtomatika_worker-1.0b2.dist-info → avtomatika_worker-1.0b4.dist-info}/top_level.txt +0 -0
avtomatika_worker/worker.py
CHANGED
@@ -1,18 +1,46 @@
-from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
-from asyncio import TimeoutError as AsyncTimeoutError
+from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep, to_thread
 from dataclasses import is_dataclass
-from inspect import Parameter, signature
-from json import JSONDecodeError
+from inspect import Parameter, iscoroutinefunction, signature
 from logging import getLogger
 from os.path import join
 from typing import Any, Callable

-from aiohttp import
+from aiohttp import ClientSession, TCPConnector, web
+from rxon import Transport, create_transport
+from rxon.blob import calculate_config_hash
+from rxon.constants import (
+    COMMAND_CANCEL_TASK,
+    ERROR_CODE_INTEGRITY_MISMATCH,
+    ERROR_CODE_INVALID_INPUT,
+    ERROR_CODE_PERMANENT,
+    ERROR_CODE_TRANSIENT,
+    MSG_TYPE_PROGRESS,
+    TASK_STATUS_CANCELLED,
+    TASK_STATUS_FAILURE,
+    TASK_STATUS_SUCCESS,
+)
+from rxon.exceptions import RxonError
+from rxon.models import (
+    FileMetadata,
+    GPUInfo,
+    Heartbeat,
+    InstalledModel,
+    ProgressUpdatePayload,
+    Resources,
+    TaskError,
+    TaskPayload,
+    TaskResult,
+    WorkerCapabilities,
+    WorkerRegistration,
+)
+from rxon.security import create_client_ssl_context
+from rxon.utils import to_dict
+from rxon.validators import validate_identifier

 from .config import WorkerConfig
 from .s3 import S3Manager
 from .task_files import TaskFiles
-from .types import
+from .types import CapacityChecker, Middleware, ParamValidationError

 try:
     from pydantic import BaseModel, ValidationError
@@ -29,7 +57,7 @@ class Worker:
     """The main class for creating and running a worker.
     Implements a hybrid interaction model with the Orchestrator:
     - PULL model for fetching tasks.
-    -
+    - Transport layer for real-time commands (cancellation) and sending progress.
     """

     def __init__(
@@ -40,16 +68,20 @@ class Worker:
         http_session: ClientSession | None = None,
         skill_dependencies: dict[str, list[str]] | None = None,
         config: WorkerConfig | None = None,
+        capacity_checker: CapacityChecker | None = None,
+        clients: list[tuple[dict[str, Any], Transport]] | None = None,
     ):
         self._config = config or WorkerConfig()
         self._s3_manager = S3Manager(self._config)
         self._config.WORKER_TYPE = worker_type  # Allow overriding worker_type
         if max_concurrent_tasks is not None:
-            self._config.
+            self._config.MAX_CONCURRENT_TASKS = max_concurrent_tasks

         self._task_type_limits = task_type_limits or {}
         self._task_handlers: dict[str, dict[str, Any]] = {}
         self._skill_dependencies = skill_dependencies or {}
+        self._middlewares: list[Middleware] = []
+        self._capacity_checker = capacity_checker

         # Worker state
         self._current_load = 0
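The new `capacity_checker` hook takes a task name and returns a bool; `_get_current_state` (later in this diff) drops any task for which it returns False before advertising `supported_tasks`. A minimal sketch of wiring one up; the disk-space policy, the task names, and the exact `Worker(...)` call are illustrative assumptions, only the `capacity_checker` keyword comes from the diff:

```python
import shutil

from avtomatika_worker.worker import Worker  # module path as shown in this diff


def has_disk_headroom(task_name: str) -> bool:
    # Illustrative policy: stop advertising disk-heavy tasks when free space is low.
    free_gb = shutil.disk_usage("/").free / 1e9
    return task_name != "render_video" or free_gb > 5


worker = Worker(worker_type="gpu", capacity_checker=has_disk_headroom)
```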
@@ -58,12 +90,10 @@ class Worker:
         self._active_tasks: dict[str, Task] = {}
         self._http_session = http_session
         self._session_is_managed_externally = http_session is not None
-        self._ws_connection: ClientWebSocketResponse | None = None
-        # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
         self._shutdown_event = Event()
         self._registered_event = Event()
-        self._round_robin_index = 0
         self._debounce_task: Task | None = None
+        self._ssl_context = None

         # --- Weighted Round-Robin State ---
         self._total_orchestrator_weight = 0
@@ -71,8 +101,37 @@ class Worker:
         for o in self._config.ORCHESTRATORS:
             o["current_weight"] = 0
             self._total_orchestrator_weight += o.get("weight", 1)
+        self._clients = clients or []
+        if not self._clients and self._http_session:
+            self._init_clients()
+
+    def add_middleware(self, middleware: Middleware) -> None:
+        """Adds a middleware to the execution chain."""
+        self._middlewares.append(middleware)
+
+    def _init_clients(self):
+        """Initializes Transport instances for each configured orchestrator."""
+        # Even if we don't have an external session, we might create transports
+        # that will create their own sessions. But if we want to share one, we need it here.
+        session_to_use = self._http_session if self._session_is_managed_externally else None
+
+        self._clients = [
+            (
+                o,
+                create_transport(
+                    url=o["url"],
+                    worker_id=self._config.WORKER_ID,
+                    token=o.get("token", self._config.WORKER_TOKEN),
+                    ssl_context=self._ssl_context,
+                    session=session_to_use,
+                    result_retries=self._config.RESULT_MAX_RETRIES,
+                    result_retry_delay=self._config.RESULT_RETRY_INITIAL_DELAY,
+                ),
+            )
+            for o in self._config.ORCHESTRATORS
+        ]

-    def
+    def _validate_task_types(self):
         """Checks for unused task type limits and warns the user."""
         registered_task_types = {
             handler_data["type"] for handler_data in self._task_handlers.values() if handler_data["type"]
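`add_middleware` chains callables around the task handler; in `_process_task` (later in this diff) each middleware is invoked as `await mw(context, next_handler)`, where `context` is a dict carrying `task_id`, `job_id`, `task_name`, `params` and `handler_kwargs`, and `next_handler` is an argument-less awaitable. A hedged sketch of a timing middleware for the `worker` instance from the previous example; the logging detail is illustrative:

```python
import time
from typing import Any, Awaitable, Callable


async def timing_middleware(context: dict[str, Any], call_next: Callable[[], Awaitable[Any]]) -> Any:
    # A middleware must await the next link itself and return its result.
    start = time.monotonic()
    try:
        return await call_next()
    finally:
        elapsed = time.monotonic() - start
        print(f"task {context['task_id']} ({context['task_name']}) took {elapsed:.2f}s")


worker.add_middleware(timing_middleware)
```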
@@ -87,6 +146,9 @@ class Worker:

     def task(self, name: str, task_type: str | None = None) -> Callable:
         """Decorator to register a function as a task handler."""
+        validate_identifier(name, "task name")
+        if task_type:
+            validate_identifier(task_type, "task type")

         def decorator(func: Callable) -> Callable:
             logger.info(f"Registering task: '{name}' (type: {task_type or 'N/A'})")
@@ -137,35 +199,37 @@ class Worker:
             if is_available:
                 supported_tasks.append(name)

+        if self._capacity_checker:
+            supported_tasks = [task for task in supported_tasks if self._capacity_checker(task)]
+
         status = "idle" if supported_tasks else "busy"
         return {"status": status, "supported_tasks": supported_tasks}

-    def
-        """Builds authentication headers for a specific orchestrator."""
-        token = orchestrator.get("token", self._config.WORKER_TOKEN)
-        return {"X-Worker-Token": token}
-
-    def _get_next_orchestrator(self) -> dict[str, Any] | None:
+    def _get_next_client(self) -> Transport | None:
         """
-        Selects the next orchestrator using a smooth weighted round-robin algorithm.
+        Selects the next orchestrator client using a smooth weighted round-robin algorithm.
         """
-        if not self.
+        if not self._clients:
             return None

         # The orchestrator with the highest current_weight is selected.
-
+        selected_client = None
         highest_weight = -1

-        for o in self.
+        for o, client in self._clients:
             o["current_weight"] += o["weight"]
             if o["current_weight"] > highest_weight:
                 highest_weight = o["current_weight"]
-
+                selected_client = client

-        if
-
+        if selected_client:
+            # Find the config for the selected client to decrement its weight
+            for o, client in self._clients:
+                if client == selected_client:
+                    o["current_weight"] -= self._total_orchestrator_weight
+                    break

-        return
+        return selected_client

     async def _debounced_heartbeat_sender(self):
         """Waits for the debounce delay then sends a heartbeat."""
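`_get_next_client` keeps the smooth weighted round-robin scheme: each pick adds every orchestrator's static weight to its `current_weight`, selects the highest, then subtracts the total weight from the winner, so heavier backends are picked more often without long bursts. A standalone sketch of the same selection rule with made-up backends:

```python
def pick(backends: list[dict]) -> dict:
    # One round of smooth weighted round-robin over mutable backend state.
    total = sum(b["weight"] for b in backends)
    best = None
    for b in backends:
        b["current_weight"] += b["weight"]
        if best is None or b["current_weight"] > best["current_weight"]:
            best = b
    best["current_weight"] -= total
    return best


backends = [
    {"name": "a", "weight": 3, "current_weight": 0},
    {"name": "b", "weight": 1, "current_weight": 0},
]
print([pick(backends)["name"] for _ in range(4)])  # ['a', 'a', 'b', 'a']
```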
@@ -180,33 +244,29 @@ class Worker:
         # Schedule the new debounced call.
         self._debounce_task = create_task(self._debounced_heartbeat_sender())

-    async def _poll_for_tasks(self,
+    async def _poll_for_tasks(self, client: Transport):
         """Polls a specific Orchestrator for new tasks."""
-
+        current_state = self._get_current_state()
+        if current_state["status"] == "busy":
+            return
+
         try:
-
-
-
-
-
-
-
-
-
-            self.
-
-
-
-
-
-                        task = create_task(self._process_task(task_data))
-                        self._active_tasks[task_data["task_id"]] = task
-                    elif resp.status != 204:
-                        await sleep(self._config.TASK_POLL_ERROR_DELAY)
-        except (AsyncTimeoutError, ClientError) as e:
-            logger.error(f"Error polling for tasks: {e}")
-            await sleep(self._config.TASK_POLL_ERROR_DELAY)
+            task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+            if task_data:
+                task_data_dict = to_dict(task_data)
+                task_data_dict["client"] = client
+
+                self._current_load += 1
+                if (task_handler_info := self._task_handlers.get(task_data.type)) and (
+                    task_type_for_limit := task_handler_info.get("type")
+                ):
+                    self._current_load_by_type[task_type_for_limit] += 1
+                self._schedule_heartbeat_debounce()
+
+                task = create_task(self._process_task(task_data_dict))
+                self._active_tasks[task_data.task_id] = task
+        except RxonError as e:
+            logger.error(f"Error polling tasks: {e}")

     async def _start_polling(self):
         """The main loop for polling tasks."""
@@ -218,13 +278,13 @@ class Worker:
                 continue

             if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
-                if
-                    await self._poll_for_tasks(
+                if client := self._get_next_client():
+                    await self._poll_for_tasks(client)
             else:
-                for
+                for _, client in self._clients:
                     if self._get_current_state()["status"] == "busy":
                         break
-                    await self._poll_for_tasks(
+                    await self._poll_for_tasks(client)

             if self._current_load == 0:
                 await sleep(self._config.IDLE_POLL_DELAY)
@@ -233,7 +293,6 @@ class Worker:
     def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
         """
         Inspects the handler's signature to validate and instantiate params.
-        Supports dict, dataclasses, and optional pydantic models.
         """
         sig = signature(handler)
         params_annotation = sig.parameters.get("params").annotation
@@ -251,11 +310,11 @@ class Worker:
         # Dataclass Instantiation
         if isinstance(params_annotation, type) and is_dataclass(params_annotation):
             try:
-                # Filter unknown fields
+                # Filter unknown fields
                 known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
                 filtered_params = {k: v for k, v in params.items() if k in known_fields}

-                #
+                # Check required fields
                 required_fields = [
                     f.name
                     for f in params_annotation.__dataclass_fields__.values()
@@ -267,7 +326,6 @@ class Worker:

                 return params_annotation(**filtered_params)
             except (TypeError, ValueError) as e:
-                # TypeError for missing/extra args, ValueError from __post_init__
                 raise ParamValidationError(str(e)) from e

         return params
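`_prepare_task_params` instantiates the handler's `params` annotation when it is a dataclass, dropping unknown keys and raising `ParamValidationError` if required fields are missing. A hedged usage sketch against the `worker` instance from the earlier examples; the task name, fields, and return payload are illustrative, and the `{"data": ...}` return shape follows how `_process_task` (below) reads handler results:

```python
from dataclasses import dataclass


@dataclass
class ResizeParams:
    width: int
    height: int
    keep_aspect: bool = True  # optional; a payload missing width or height fails validation


@worker.task("resize_image", task_type="cpu")
async def resize_image(params: ResizeParams, task_id: str, **kwargs):
    # params arrives already filtered and instantiated as ResizeParams.
    return {"data": {"width": params.width, "height": params.height, "kept_aspect": params.keep_aspect}}
```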
@@ -276,8 +334,7 @@ class Worker:
         """Injects dependencies based on type hints."""
         deps = {}
         task_dir = join(self._config.TASK_FILES_DIR, task_id)
-
-        task_files = TaskFiles(task_dir)
+        task_files = TaskFiles(task_dir, s3_manager=self._s3_manager)

         sig = signature(handler)
         for name, param in sig.parameters.items():
@@ -286,82 +343,155 @@ class Worker:

         return deps

-    async def _process_task(self,
+    async def _process_task(self, task_data_raw: dict[str, Any]):
         """Executes the task logic."""
-
-
+        client: Transport = task_data_raw.pop("client")
+
+        # Parse incoming task data using protocol model
+        if "params_metadata" in task_data_raw and task_data_raw["params_metadata"]:
+            task_data_raw["params_metadata"] = {
+                k: self._from_dict(FileMetadata, v) for k, v in task_data_raw["params_metadata"].items()
+            }
+
+        task_payload = self._from_dict(TaskPayload, task_data_raw)
+        task_id, job_id, task_name = task_payload.task_id, task_payload.job_id, task_payload.type
+        params = task_payload.params

-        result: dict[str, Any] = {}
         handler_data = self._task_handlers.get(task_name)
         task_type_for_limit = handler_data.get("type") if handler_data else None

-
+        result_obj: TaskResult | None = None
+
+        # Create a progress sender wrapper attached to this specific client
+        async def send_progress_wrapper(task_id_arg, job_id_arg, progress, message=""):
+            payload = ProgressUpdatePayload(
+                event=MSG_TYPE_PROGRESS, task_id=task_id_arg, job_id=job_id_arg, progress=progress, message=message
+            )
+            await client.send_progress(payload)

         try:
             if not handler_data:
                 message = f"Unsupported task: {task_name}"
                 logger.warning(message)
-
-
-
-
-
-
-
-
-
-
-
-                    validated_params,
-
-
-
-
-
-
-
-
-
+                error = TaskError(code=ERROR_CODE_PERMANENT, message=message)
+                result_obj = TaskResult(
+                    job_id=job_id,
+                    task_id=task_id,
+                    worker_id=self._config.WORKER_ID,
+                    status=TASK_STATUS_FAILURE,
+                    error=error,
+                )
+            else:
+                # Download files
+                params = await self._s3_manager.process_params(params, task_id, metadata=task_payload.params_metadata)
+                validated_params = self._prepare_task_params(handler_data["func"], params)
+                deps = self._prepare_dependencies(handler_data["func"], task_id)
+
+                handler_kwargs = {
+                    "params": validated_params,
+                    "task_id": task_id,
+                    "job_id": job_id,
+                    "tracing_context": task_payload.tracing_context,
+                    "priority": task_data_raw.get("priority", 0),
+                    "send_progress": send_progress_wrapper,
+                    "add_to_hot_cache": self.add_to_hot_cache,
+                    "remove_from_hot_cache": self.remove_from_hot_cache,
+                    **deps,
+                }
+
+                middleware_context = {
+                    "task_id": task_id,
+                    "job_id": job_id,
+                    "task_name": task_name,
+                    "params": validated_params,
+                    "handler_kwargs": handler_kwargs,
+                }
+
+                async def _execution_logic():
+                    handler = handler_data["func"]
+                    final_kwargs = middleware_context["handler_kwargs"]
+
+                    if iscoroutinefunction(handler):
+                        return await handler(**final_kwargs)
+                    else:
+                        return await to_thread(handler, **final_kwargs)
+
+                handler_chain = _execution_logic
+                for middleware in reversed(self._middlewares):
+
+                    def make_wrapper(mw: Middleware, next_handler: Callable) -> Callable:
+                        async def wrapper():
+                            return await mw(middleware_context, next_handler)
+
+                        return wrapper
+
+                    handler_chain = make_wrapper(middleware, handler_chain)
+
+                handler_result = await handler_chain()
+
+                updated_data, metadata_map = await self._s3_manager.process_result(handler_result.get("data", {}))
+
+                result_obj = TaskResult(
+                    job_id=job_id,
+                    task_id=task_id,
+                    worker_id=self._config.WORKER_ID,
+                    status=handler_result.get("status", TASK_STATUS_SUCCESS),
+                    data=updated_data,
+                    error=TaskError(**handler_result["error"]) if "error" in handler_result else None,
+                    data_metadata=metadata_map if metadata_map else None,
+                )
+
         except ParamValidationError as e:
             logger.error(f"Task {task_id} failed validation: {e}")
-
+            error = TaskError(code=ERROR_CODE_INVALID_INPUT, message=str(e))
+            result_obj = TaskResult(
+                job_id=job_id,
+                task_id=task_id,
+                worker_id=self._config.WORKER_ID,
+                status=TASK_STATUS_FAILURE,
+                error=error,
+            )
         except CancelledError:
             logger.info(f"Task {task_id} was cancelled.")
-
-
+            result_obj = TaskResult(
+                job_id=job_id, task_id=task_id, worker_id=self._config.WORKER_ID, status=TASK_STATUS_CANCELLED
+            )
             raise
+        except ValueError as e:
+            logger.error(f"Data integrity or validation error for task {task_id}: {e}")
+            error = TaskError(code=ERROR_CODE_INTEGRITY_MISMATCH, message=str(e))
+            result_obj = TaskResult(
+                job_id=job_id,
+                task_id=task_id,
+                worker_id=self._config.WORKER_ID,
+                status=TASK_STATUS_FAILURE,
+                error=error,
+            )
         except Exception as e:
             logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-
+            error = TaskError(code=ERROR_CODE_TRANSIENT, message=str(e))
+            result_obj = TaskResult(
+                job_id=job_id,
+                task_id=task_id,
+                worker_id=self._config.WORKER_ID,
+                status=TASK_STATUS_FAILURE,
+                error=error,
+            )
         finally:
-            # Cleanup task workspace
             await self._s3_manager.cleanup(task_id)

-            if
-
-
-
+            if result_obj:
+                try:
+                    await client.send_result(result_obj)
+                except RxonError as e:
+                    logger.error(f"Failed to send task result: {e}")

+            self._active_tasks.pop(task_id, None)
             self._current_load -= 1
             if task_type_for_limit:
                 self._current_load_by_type[task_type_for_limit] -= 1
             self._schedule_heartbeat_debounce()

-    async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
-        """Sends the result to a specific orchestrator."""
-        url = f"{orchestrator['url']}/_worker/tasks/result"
-        delay = self._config.RESULT_RETRY_INITIAL_DELAY
-        headers = self._get_headers(orchestrator)
-        for i in range(self._config.RESULT_MAX_RETRIES):
-            try:
-                if self._http_session and not self._http_session.closed:
-                    async with self._http_session.post(url, json=payload, headers=headers) as resp:
-                        if resp.status == 200:
-                            return
-            except ClientError as e:
-                logger.error(f"Error sending result: {e}")
-            await sleep(delay * (2**i))
-
     async def _manage_orchestrator_communications(self):
         """Registers the worker and sends heartbeats."""
         await self._register_with_all_orchestrators()
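Handlers get `send_progress` (the `send_progress_wrapper` above) injected among their keyword arguments; it is called positionally as `(task_id, job_id, progress, message="")`. A hedged sketch of a progress-reporting handler on the same assumed `worker` instance; `do_step` is a hypothetical unit of work and the 0..1 progress scale is an assumption:

```python
@worker.task("long_job")
async def long_job(params: dict, task_id: str, job_id: str, send_progress, **kwargs):
    steps = 4
    for i in range(1, steps + 1):
        await do_step(i)  # hypothetical coroutine doing one slice of the work
        await send_progress(task_id, job_id, i / steps, f"step {i}/{steps} done")
    return {"data": {"steps": steps}}
```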
@@ -374,81 +504,127 @@ class Worker:
             await self._send_heartbeats_to_all()
             await sleep(self._config.HEARTBEAT_INTERVAL)

+    @staticmethod
+    def _from_dict(cls: type, data: dict[str, Any]) -> Any:
+        """Safely instantiates a NamedTuple from a dict, ignoring unknown fields."""
+        if not data:
+            return None
+        fields = cls._fields
+        filtered_data = {k: v for k, v in data.items() if k in fields}
+        return cls(**filtered_data)
+
     async def _register_with_all_orchestrators(self):
         """Registers the worker with all orchestrators."""
         state = self._get_current_state()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        gpu_info = None
+        if self._config.RESOURCES.get("gpu_info"):
+            gpu_info = GPUInfo(**self._config.RESOURCES["gpu_info"])
+
+        resources = Resources(
+            max_concurrent_tasks=self._config.MAX_CONCURRENT_TASKS,
+            cpu_cores=self._config.RESOURCES["cpu_cores"],
+            gpu_info=gpu_info,
+        )
+
+        s3_hash = calculate_config_hash(
+            self._config.S3_ENDPOINT_URL,
+            self._config.S3_ACCESS_KEY,
+            self._config.S3_DEFAULT_BUCKET,
+        )
+
+        registration = WorkerRegistration(
+            worker_id=self._config.WORKER_ID,
+            worker_type=self._config.WORKER_TYPE,
+            supported_tasks=state["supported_tasks"],
+            resources=resources,
+            installed_software=self._config.INSTALLED_SOFTWARE,
+            installed_models=[InstalledModel(**m) for m in self._config.INSTALLED_MODELS],
+            capabilities=WorkerCapabilities(
+                hostname=self._config.HOSTNAME,
+                ip_address=self._config.IP_ADDRESS,
+                cost_per_skill=self._config.COST_PER_SKILL,
+                s3_config_hash=s3_hash,
+            ),
+        )
+
+        await gather(*[self._safe_register(client, registration) for _, client in self._clients])
+
+    async def _safe_register(self, client: Transport, registration: WorkerRegistration):
+        try:
+            await client.register(registration)
+        except RxonError as e:
+            logger.error(f"Registration failed for {client}: {e}")

     async def _send_heartbeats_to_all(self):
         """Sends heartbeat messages to all orchestrators."""
         state = self._get_current_state()
-        payload = {
-            "load": self._current_load,
-            "status": state["status"],
-            "supported_tasks": state["supported_tasks"],
-            "hot_cache": list(self._hot_cache),
-        }

+        hot_skills = None
         if self._skill_dependencies:
-            payload["skill_dependencies"] = self._skill_dependencies
             hot_skills = [
                 skill for skill, models in self._skill_dependencies.items() if set(models).issubset(self._hot_cache)
             ]
-            if hot_skills:
-                payload["hot_skills"] = hot_skills

-
-
-
-
-
-
-
-
-
-
+        heartbeat = Heartbeat(
+            worker_id=self._config.WORKER_ID,
+            status=state["status"],
+            load=float(self._current_load),
+            current_tasks=list(self._active_tasks.keys()),
+            supported_tasks=state["supported_tasks"],
+            hot_cache=list(self._hot_cache),
+            skill_dependencies=self._skill_dependencies or None,
+            hot_skills=hot_skills or None,
+        )

-        await gather(*[
+        await gather(*[self._safe_heartbeat(client, heartbeat) for _, client in self._clients])
+
+    async def _safe_heartbeat(self, client: Transport, heartbeat: Heartbeat):
+        try:
+            await client.send_heartbeat(heartbeat)
+        except RxonError as e:
+            logger.warning(f"Heartbeat failed for {client}: {e}")

     async def main(self):
         """The main asynchronous function."""
-        self.
+        self._config.validate()
+        self._validate_task_types()
         if not self._http_session:
-            self.
+            if self._config.TLS_CA_PATH or (self._config.TLS_CERT_PATH and self._config.TLS_KEY_PATH):
+                logger.info("Initializing SSL context for mTLS.")
+                self._ssl_context = create_client_ssl_context(
+                    ca_path=self._config.TLS_CA_PATH,
+                    cert_path=self._config.TLS_CERT_PATH,
+                    key_path=self._config.TLS_KEY_PATH,
+                )
+            connector = TCPConnector(ssl=self._ssl_context) if self._ssl_context else None
+            self._http_session = ClientSession(connector=connector)
+        if not self._clients:
+            self._init_clients()
+
+        # Connect transports
+        await gather(*[client.connect() for _, client in self._clients])

         comm_task = create_task(self._manage_orchestrator_communications())

+        token_rotation_task = None
+        if self._ssl_context:
+            token_rotation_task = create_task(self._manage_token_rotation())
+
         polling_task = create_task(self._start_polling())
         await self._shutdown_event.wait()

         for task in [comm_task, polling_task]:
             task.cancel()
+        if token_rotation_task:
+            token_rotation_task.cancel()
+
         if self._active_tasks:
             await gather(*self._active_tasks.values(), return_exceptions=True)

-
-
+        # Close transports
+        await gather(*[client.close() for _, client in self._clients])
+
         if self._http_session and not self._http_session.closed and not self._session_is_managed_externally:
             await self._http_session.close()

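`main()` now builds a client SSL context only when TLS paths are configured, and token rotation (next hunk) is started only in that case. A sketch of launching a worker with mTLS; the `TLS_*` field names come from this diff, but setting them as plain attributes (rather than via environment variables or constructor arguments) and the file paths are assumptions:

```python
from asyncio import run

from avtomatika_worker.config import WorkerConfig
from avtomatika_worker.worker import Worker

config = WorkerConfig()
config.TLS_CA_PATH = "/etc/worker/ca.pem"        # illustrative paths
config.TLS_CERT_PATH = "/etc/worker/client.pem"
config.TLS_KEY_PATH = "/etc/worker/client.key"

worker = Worker(worker_type="gpu", config=config)
run(worker.main())  # main() is the coroutine shown in the hunk above
```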
@@ -459,6 +635,26 @@ class Worker:
         except KeyboardInterrupt:
             self._shutdown_event.set()

+    async def _manage_token_rotation(self):
+        """Periodically refreshes auth tokens for all clients."""
+        await sleep(5)
+
+        while not self._shutdown_event.is_set():
+            min_expires_in = 3600
+
+            for _, client in self._clients:
+                try:
+                    token_resp = await client.refresh_token()
+                    if token_resp:
+                        self._config.WORKER_TOKEN = token_resp.access_token
+                        min_expires_in = min(min_expires_in, token_resp.expires_in)
+                except Exception as e:
+                    logger.error(f"Error in token rotation loop: {e}")
+
+            refresh_delay = max(60, min_expires_in * 0.8)
+            logger.debug(f"Next token refresh scheduled in {refresh_delay:.1f}s")
+            await sleep(refresh_delay)
+
     async def _run_health_check_server(self):
         app = web.Application()

@@ -482,60 +678,27 @@ class Worker:
         except KeyboardInterrupt:
             self._shutdown_event.set()

-    # WebSocket methods omitted for brevity as they are not relevant to the changes
     async def _start_websocket_manager(self):
-        """Manages the
-
-
-
-        try:
-            if self._http_session:
-                async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
-                    self._ws_connection = ws
-                    logger.info(f"WebSocket connection established to {ws_url}")
-                    await self._listen_for_commands()
-        except (ClientError, AsyncTimeoutError) as e:
-            logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
-        finally:
-            self._ws_connection = None
-            logger.info(f"WebSocket connection to {ws_url} closed.")
-            await sleep(5)  # Reconnection delay
-        if not self._config.ORCHESTRATORS:
-            await sleep(5)
-
-    async def _listen_for_commands(self):
-        """Listens for and processes commands from the orchestrator via WebSocket."""
-        if not self._ws_connection:
-            return
+        """Manages the command listeners."""
+        listeners = []
+        for _, client in self._clients:
+            listeners.append(create_task(self._listen_to_single_transport(client)))

-
-            async for msg in self._ws_connection:
-                if msg.type == WSMsgType.TEXT:
-                    try:
-                        command = msg.json()
-                        if command.get("type") == "cancel_task":
-                            task_id = command.get("task_id")
-                            if task_id in self._active_tasks:
-                                self._active_tasks[task_id].cancel()
-                                logger.info(f"Cancelled task {task_id} by orchestrator command.")
-                    except JSONDecodeError:
-                        logger.warning(f"Received invalid JSON over WebSocket: {msg.data}")
-                elif msg.type == WSMsgType.ERROR:
-                    break
-        except Exception as e:
-            logger.error(f"Error in WebSocket listener: {e}")
+        await self._shutdown_event.wait()

-
-
-
+        for listener in listeners:
+            listener.cancel()
+
+    async def _listen_to_single_transport(self, client: Transport):
+        while not self._shutdown_event.is_set():
             try:
-
-
-
-
-
-
-
-                await self._ws_connection.send_json(payload)
+                async for command in client.listen_for_commands():
+                    if command.command == COMMAND_CANCEL_TASK:
+                        task_id = command.task_id
+                        job_id = command.job_id
+                        if task_id in self._active_tasks:
+                            self._active_tasks[task_id].cancel()
+                            logger.info(f"Cancelled task {task_id} (Job: {job_id or 'N/A'}) by orchestrator command.")
             except Exception as e:
-                logger.
+                logger.error(f"Error in command listener: {e}")
+                await sleep(5)