avtomatika-worker 1.0b2-py3-none-any.whl → 1.0b4-py3-none-any.whl

@@ -1,18 +1,46 @@
- from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
- from asyncio import TimeoutError as AsyncTimeoutError
+ from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep, to_thread
  from dataclasses import is_dataclass
- from inspect import Parameter, signature
- from json import JSONDecodeError
+ from inspect import Parameter, iscoroutinefunction, signature
  from logging import getLogger
  from os.path import join
  from typing import Any, Callable

- from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
+ from aiohttp import ClientSession, TCPConnector, web
+ from rxon import Transport, create_transport
+ from rxon.blob import calculate_config_hash
+ from rxon.constants import (
+     COMMAND_CANCEL_TASK,
+     ERROR_CODE_INTEGRITY_MISMATCH,
+     ERROR_CODE_INVALID_INPUT,
+     ERROR_CODE_PERMANENT,
+     ERROR_CODE_TRANSIENT,
+     MSG_TYPE_PROGRESS,
+     TASK_STATUS_CANCELLED,
+     TASK_STATUS_FAILURE,
+     TASK_STATUS_SUCCESS,
+ )
+ from rxon.exceptions import RxonError
+ from rxon.models import (
+     FileMetadata,
+     GPUInfo,
+     Heartbeat,
+     InstalledModel,
+     ProgressUpdatePayload,
+     Resources,
+     TaskError,
+     TaskPayload,
+     TaskResult,
+     WorkerCapabilities,
+     WorkerRegistration,
+ )
+ from rxon.security import create_client_ssl_context
+ from rxon.utils import to_dict
+ from rxon.validators import validate_identifier

  from .config import WorkerConfig
  from .s3 import S3Manager
  from .task_files import TaskFiles
- from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
+ from .types import CapacityChecker, Middleware, ParamValidationError

  try:
      from pydantic import BaseModel, ValidationError
@@ -29,7 +57,7 @@ class Worker:
      """The main class for creating and running a worker.
      Implements a hybrid interaction model with the Orchestrator:
      - PULL model for fetching tasks.
-     - WebSocket for real-time commands (cancellation) and sending progress.
+     - Transport layer for real-time commands (cancellation) and sending progress.
      """

      def __init__(
@@ -40,16 +68,20 @@ class Worker:
          http_session: ClientSession | None = None,
          skill_dependencies: dict[str, list[str]] | None = None,
          config: WorkerConfig | None = None,
+         capacity_checker: CapacityChecker | None = None,
+         clients: list[tuple[dict[str, Any], Transport]] | None = None,
      ):
          self._config = config or WorkerConfig()
          self._s3_manager = S3Manager(self._config)
          self._config.WORKER_TYPE = worker_type  # Allow overriding worker_type
          if max_concurrent_tasks is not None:
-             self._config.max_concurrent_tasks = max_concurrent_tasks
+             self._config.MAX_CONCURRENT_TASKS = max_concurrent_tasks

          self._task_type_limits = task_type_limits or {}
          self._task_handlers: dict[str, dict[str, Any]] = {}
          self._skill_dependencies = skill_dependencies or {}
+         self._middlewares: list[Middleware] = []
+         self._capacity_checker = capacity_checker

          # Worker state
          self._current_load = 0
@@ -58,12 +90,10 @@ class Worker:
          self._active_tasks: dict[str, Task] = {}
          self._http_session = http_session
          self._session_is_managed_externally = http_session is not None
-         self._ws_connection: ClientWebSocketResponse | None = None
-         # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
          self._shutdown_event = Event()
          self._registered_event = Event()
-         self._round_robin_index = 0
          self._debounce_task: Task | None = None
+         self._ssl_context = None

          # --- Weighted Round-Robin State ---
          self._total_orchestrator_weight = 0
@@ -71,8 +101,37 @@ class Worker:
          for o in self._config.ORCHESTRATORS:
              o["current_weight"] = 0
              self._total_orchestrator_weight += o.get("weight", 1)
+         self._clients = clients or []
+         if not self._clients and self._http_session:
+             self._init_clients()
+
+     def add_middleware(self, middleware: Middleware) -> None:
+         """Adds a middleware to the execution chain."""
+         self._middlewares.append(middleware)
+
+     def _init_clients(self):
+         """Initializes Transport instances for each configured orchestrator."""
+         # Even if we don't have an external session, we might create transports
+         # that will create their own sessions. But if we want to share one, we need it here.
+         session_to_use = self._http_session if self._session_is_managed_externally else None
+
+         self._clients = [
+             (
+                 o,
+                 create_transport(
+                     url=o["url"],
+                     worker_id=self._config.WORKER_ID,
+                     token=o.get("token", self._config.WORKER_TOKEN),
+                     ssl_context=self._ssl_context,
+                     session=session_to_use,
+                     result_retries=self._config.RESULT_MAX_RETRIES,
+                     result_retry_delay=self._config.RESULT_RETRY_INITIAL_DELAY,
+                 ),
+             )
+             for o in self._config.ORCHESTRATORS
+         ]

-     def _validate_config(self):
+     def _validate_task_types(self):
          """Checks for unused task type limits and warns the user."""
          registered_task_types = {
              handler_data["type"] for handler_data in self._task_handlers.values() if handler_data["type"]
@@ -87,6 +146,9 @@ class Worker:

      def task(self, name: str, task_type: str | None = None) -> Callable:
          """Decorator to register a function as a task handler."""
+         validate_identifier(name, "task name")
+         if task_type:
+             validate_identifier(task_type, "task type")

          def decorator(func: Callable) -> Callable:
              logger.info(f"Registering task: '{name}' (type: {task_type or 'N/A'})")
@@ -137,35 +199,37 @@ class Worker:
                  if is_available:
                      supported_tasks.append(name)

+         if self._capacity_checker:
+             supported_tasks = [task for task in supported_tasks if self._capacity_checker(task)]
+
          status = "idle" if supported_tasks else "busy"
          return {"status": status, "supported_tasks": supported_tasks}

-     def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
-         """Builds authentication headers for a specific orchestrator."""
-         token = orchestrator.get("token", self._config.WORKER_TOKEN)
-         return {"X-Worker-Token": token}
-
-     def _get_next_orchestrator(self) -> dict[str, Any] | None:
+     def _get_next_client(self) -> Transport | None:
          """
-         Selects the next orchestrator using a smooth weighted round-robin algorithm.
+         Selects the next orchestrator client using a smooth weighted round-robin algorithm.
          """
-         if not self._config.ORCHESTRATORS:
+         if not self._clients:
              return None

          # The orchestrator with the highest current_weight is selected.
-         selected_orchestrator = None
+         selected_client = None
          highest_weight = -1

-         for o in self._config.ORCHESTRATORS:
+         for o, client in self._clients:
              o["current_weight"] += o["weight"]
              if o["current_weight"] > highest_weight:
                  highest_weight = o["current_weight"]
-                 selected_orchestrator = o
+                 selected_client = client

-         if selected_orchestrator:
-             selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
+         if selected_client:
+             # Find the config for the selected client to decrement its weight
+             for o, client in self._clients:
+                 if client == selected_client:
+                     o["current_weight"] -= self._total_orchestrator_weight
+                     break

-         return selected_orchestrator
+         return selected_client

      async def _debounced_heartbeat_sender(self):
          """Waits for the debounce delay then sends a heartbeat."""
@@ -180,33 +244,29 @@ class Worker:
          # Schedule the new debounced call.
          self._debounce_task = create_task(self._debounced_heartbeat_sender())

-     async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
+     async def _poll_for_tasks(self, client: Transport):
          """Polls a specific Orchestrator for new tasks."""
-         url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
+         current_state = self._get_current_state()
+         if current_state["status"] == "busy":
+             return
+
          try:
-             if not self._http_session:
-                 return
-             timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
-             headers = self._get_headers(orchestrator)
-             async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
-                 if resp.status == 200:
-                     task_data = await resp.json()
-                     task_data["orchestrator"] = orchestrator
-
-                     self._current_load += 1
-                     if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
-                         task_type_for_limit := task_handler_info.get("type")
-                     ):
-                         self._current_load_by_type[task_type_for_limit] += 1
-                     self._schedule_heartbeat_debounce()
-
-                     task = create_task(self._process_task(task_data))
-                     self._active_tasks[task_data["task_id"]] = task
-                 elif resp.status != 204:
-                     await sleep(self._config.TASK_POLL_ERROR_DELAY)
-         except (AsyncTimeoutError, ClientError) as e:
-             logger.error(f"Error polling for tasks: {e}")
-             await sleep(self._config.TASK_POLL_ERROR_DELAY)
+             task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+             if task_data:
+                 task_data_dict = to_dict(task_data)
+                 task_data_dict["client"] = client
+
+                 self._current_load += 1
+                 if (task_handler_info := self._task_handlers.get(task_data.type)) and (
+                     task_type_for_limit := task_handler_info.get("type")
+                 ):
+                     self._current_load_by_type[task_type_for_limit] += 1
+                 self._schedule_heartbeat_debounce()
+
+                 task = create_task(self._process_task(task_data_dict))
+                 self._active_tasks[task_data.task_id] = task
+         except RxonError as e:
+             logger.error(f"Error polling tasks: {e}")

      async def _start_polling(self):
          """The main loop for polling tasks."""
@@ -218,13 +278,13 @@ class Worker:
                  continue

              if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
-                 if orchestrator := self._get_next_orchestrator():
-                     await self._poll_for_tasks(orchestrator)
+                 if client := self._get_next_client():
+                     await self._poll_for_tasks(client)
              else:
-                 for orchestrator in self._config.ORCHESTRATORS:
+                 for _, client in self._clients:
                      if self._get_current_state()["status"] == "busy":
                          break
-                     await self._poll_for_tasks(orchestrator)
+                     await self._poll_for_tasks(client)

              if self._current_load == 0:
                  await sleep(self._config.IDLE_POLL_DELAY)
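
Note: a hypothetical configuration shape for the two polling modes above. The field names mirror the WorkerConfig attributes used in this file ("weight" defaults to 1, "token" falls back to WORKER_TOKEN); the URLs are examples, and how ORCHESTRATORS is normally populated (e.g. from the environment) is not shown in this diff.

    config = WorkerConfig()
    config.MULTI_ORCHESTRATOR_MODE = "ROUND_ROBIN"  # any other value polls every client in turn
    config.ORCHESTRATORS = [
        {"url": "https://orch-a.example", "weight": 5},
        {"url": "https://orch-b.example", "weight": 1, "token": "per-orchestrator-token"},
    ]
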
@@ -233,7 +293,6 @@ class Worker:
      def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
          """
          Inspects the handler's signature to validate and instantiate params.
-         Supports dict, dataclasses, and optional pydantic models.
          """
          sig = signature(handler)
          params_annotation = sig.parameters.get("params").annotation
@@ -251,11 +310,11 @@ class Worker:
          # Dataclass Instantiation
          if isinstance(params_annotation, type) and is_dataclass(params_annotation):
              try:
-                 # Filter unknown fields to prevent TypeError on dataclass instantiation
+                 # Filter unknown fields
                  known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
                  filtered_params = {k: v for k, v in params.items() if k in known_fields}

-                 # Explicitly check for missing required fields
+                 # Check required fields
                  required_fields = [
                      f.name
                      for f in params_annotation.__dataclass_fields__.values()
@@ -267,7 +326,6 @@ class Worker:

                  return params_annotation(**filtered_params)
              except (TypeError, ValueError) as e:
-                 # TypeError for missing/extra args, ValueError from __post_init__
                  raise ParamValidationError(str(e)) from e

          return params
@@ -276,8 +334,7 @@ class Worker:
          """Injects dependencies based on type hints."""
          deps = {}
          task_dir = join(self._config.TASK_FILES_DIR, task_id)
-         # Always create the object, but directory is lazy
-         task_files = TaskFiles(task_dir)
+         task_files = TaskFiles(task_dir, s3_manager=self._s3_manager)

          sig = signature(handler)
          for name, param in sig.parameters.items():
@@ -286,82 +343,155 @@ class Worker:

          return deps

-     async def _process_task(self, task_data: dict[str, Any]):
+     async def _process_task(self, task_data_raw: dict[str, Any]):
          """Executes the task logic."""
-         task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-         params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
+         client: Transport = task_data_raw.pop("client")
+
+         # Parse incoming task data using protocol model
+         if "params_metadata" in task_data_raw and task_data_raw["params_metadata"]:
+             task_data_raw["params_metadata"] = {
+                 k: self._from_dict(FileMetadata, v) for k, v in task_data_raw["params_metadata"].items()
+             }
+
+         task_payload = self._from_dict(TaskPayload, task_data_raw)
+         task_id, job_id, task_name = task_payload.task_id, task_payload.job_id, task_payload.type
+         params = task_payload.params

-         result: dict[str, Any] = {}
          handler_data = self._task_handlers.get(task_name)
          task_type_for_limit = handler_data.get("type") if handler_data else None

-         result_sent = False  # Flag to track if result has been sent
+         result_obj: TaskResult | None = None
+
+         # Create a progress sender wrapper attached to this specific client
+         async def send_progress_wrapper(task_id_arg, job_id_arg, progress, message=""):
+             payload = ProgressUpdatePayload(
+                 event=MSG_TYPE_PROGRESS, task_id=task_id_arg, job_id=job_id_arg, progress=progress, message=message
+             )
+             await client.send_progress(payload)

          try:
              if not handler_data:
                  message = f"Unsupported task: {task_name}"
                  logger.warning(message)
-                 result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
-                 payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await self._send_result(payload, orchestrator)
-                 result_sent = True  # Mark result as sent
-                 return
-
-             params = await self._s3_manager.process_params(params, task_id)
-             validated_params = self._prepare_task_params(handler_data["func"], params)
-             deps = self._prepare_dependencies(handler_data["func"], task_id)
-
-             result = await handler_data["func"](
-                 validated_params,
-                 task_id=task_id,
-                 job_id=job_id,
-                 priority=task_data.get("priority", 0),
-                 send_progress=self.send_progress,
-                 add_to_hot_cache=self.add_to_hot_cache,
-                 remove_from_hot_cache=self.remove_from_hot_cache,
-                 **deps,
-             )
-             result = await self._s3_manager.process_result(result)
+                 error = TaskError(code=ERROR_CODE_PERMANENT, message=message)
+                 result_obj = TaskResult(
+                     job_id=job_id,
+                     task_id=task_id,
+                     worker_id=self._config.WORKER_ID,
+                     status=TASK_STATUS_FAILURE,
+                     error=error,
+                 )
+             else:
+                 # Download files
+                 params = await self._s3_manager.process_params(params, task_id, metadata=task_payload.params_metadata)
+                 validated_params = self._prepare_task_params(handler_data["func"], params)
+                 deps = self._prepare_dependencies(handler_data["func"], task_id)
+
+                 handler_kwargs = {
+                     "params": validated_params,
+                     "task_id": task_id,
+                     "job_id": job_id,
+                     "tracing_context": task_payload.tracing_context,
+                     "priority": task_data_raw.get("priority", 0),
+                     "send_progress": send_progress_wrapper,
+                     "add_to_hot_cache": self.add_to_hot_cache,
+                     "remove_from_hot_cache": self.remove_from_hot_cache,
+                     **deps,
+                 }
+
+                 middleware_context = {
+                     "task_id": task_id,
+                     "job_id": job_id,
+                     "task_name": task_name,
+                     "params": validated_params,
+                     "handler_kwargs": handler_kwargs,
+                 }
+
+                 async def _execution_logic():
+                     handler = handler_data["func"]
+                     final_kwargs = middleware_context["handler_kwargs"]
+
+                     if iscoroutinefunction(handler):
+                         return await handler(**final_kwargs)
+                     else:
+                         return await to_thread(handler, **final_kwargs)
+
+                 handler_chain = _execution_logic
+                 for middleware in reversed(self._middlewares):
+
+                     def make_wrapper(mw: Middleware, next_handler: Callable) -> Callable:
+                         async def wrapper():
+                             return await mw(middleware_context, next_handler)
+
+                         return wrapper
+
+                     handler_chain = make_wrapper(middleware, handler_chain)
+
+                 handler_result = await handler_chain()
+
+                 updated_data, metadata_map = await self._s3_manager.process_result(handler_result.get("data", {}))
+
+                 result_obj = TaskResult(
+                     job_id=job_id,
+                     task_id=task_id,
+                     worker_id=self._config.WORKER_ID,
+                     status=handler_result.get("status", TASK_STATUS_SUCCESS),
+                     data=updated_data,
+                     error=TaskError(**handler_result["error"]) if "error" in handler_result else None,
+                     data_metadata=metadata_map if metadata_map else None,
+                 )
+
          except ParamValidationError as e:
              logger.error(f"Task {task_id} failed validation: {e}")
-             result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
+             error = TaskError(code=ERROR_CODE_INVALID_INPUT, message=str(e))
+             result_obj = TaskResult(
+                 job_id=job_id,
+                 task_id=task_id,
+                 worker_id=self._config.WORKER_ID,
+                 status=TASK_STATUS_FAILURE,
+                 error=error,
+             )
          except CancelledError:
              logger.info(f"Task {task_id} was cancelled.")
-             result = {"status": "cancelled"}
-             # We must re-raise the exception to be handled by the outer gather
+             result_obj = TaskResult(
+                 job_id=job_id, task_id=task_id, worker_id=self._config.WORKER_ID, status=TASK_STATUS_CANCELLED
+             )
              raise
+         except ValueError as e:
+             logger.error(f"Data integrity or validation error for task {task_id}: {e}")
+             error = TaskError(code=ERROR_CODE_INTEGRITY_MISMATCH, message=str(e))
+             result_obj = TaskResult(
+                 job_id=job_id,
+                 task_id=task_id,
+                 worker_id=self._config.WORKER_ID,
+                 status=TASK_STATUS_FAILURE,
+                 error=error,
+             )
          except Exception as e:
              logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-             result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
+             error = TaskError(code=ERROR_CODE_TRANSIENT, message=str(e))
+             result_obj = TaskResult(
+                 job_id=job_id,
+                 task_id=task_id,
+                 worker_id=self._config.WORKER_ID,
+                 status=TASK_STATUS_FAILURE,
+                 error=error,
+             )
          finally:
-             # Cleanup task workspace
              await self._s3_manager.cleanup(task_id)

-             if not result_sent:  # Only send if not already sent
-                 payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await self._send_result(payload, orchestrator)
-             self._active_tasks.pop(task_id, None)
+             if result_obj:
+                 try:
+                     await client.send_result(result_obj)
+                 except RxonError as e:
+                     logger.error(f"Failed to send task result: {e}")

+             self._active_tasks.pop(task_id, None)
              self._current_load -= 1
              if task_type_for_limit:
                  self._current_load_by_type[task_type_for_limit] -= 1
              self._schedule_heartbeat_debounce()

-     async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
-         """Sends the result to a specific orchestrator."""
-         url = f"{orchestrator['url']}/_worker/tasks/result"
-         delay = self._config.RESULT_RETRY_INITIAL_DELAY
-         headers = self._get_headers(orchestrator)
-         for i in range(self._config.RESULT_MAX_RETRIES):
-             try:
-                 if self._http_session and not self._http_session.closed:
-                     async with self._http_session.post(url, json=payload, headers=headers) as resp:
-                         if resp.status == 200:
-                             return
-             except ClientError as e:
-                 logger.error(f"Error sending result: {e}")
-             await sleep(delay * (2**i))
-
      async def _manage_orchestrator_communications(self):
          """Registers the worker and sends heartbeats."""
          await self._register_with_all_orchestrators()
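
Note: the middleware chain above is built onion-style: middlewares are applied in reverse registration order, so the first one added runs outermost, and each receives the shared context dict plus a zero-argument awaitable for the rest of the chain. A hypothetical timing middleware matching that (context, next_handler) shape:

    from time import perf_counter

    async def timing_middleware(context, next_handler):
        start = perf_counter()
        try:
            return await next_handler()
        finally:
            print(f"task {context['task_id']} took {perf_counter() - start:.3f}s")

    # worker.add_middleware(timing_middleware)  # runs around every task handler
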
@@ -374,81 +504,127 @@ class Worker:
              await self._send_heartbeats_to_all()
              await sleep(self._config.HEARTBEAT_INTERVAL)

+     @staticmethod
+     def _from_dict(cls: type, data: dict[str, Any]) -> Any:
+         """Safely instantiates a NamedTuple from a dict, ignoring unknown fields."""
+         if not data:
+             return None
+         fields = cls._fields
+         filtered_data = {k: v for k, v in data.items() if k in fields}
+         return cls(**filtered_data)
+
      async def _register_with_all_orchestrators(self):
          """Registers the worker with all orchestrators."""
          state = self._get_current_state()
-         payload = {
-             "worker_id": self._config.WORKER_ID,
-             "worker_type": self._config.WORKER_TYPE,
-             "supported_tasks": state["supported_tasks"],
-             "max_concurrent_tasks": self._config.MAX_CONCURRENT_TASKS,
-             "cost_per_skill": self._config.COST_PER_SKILL,
-             "installed_models": self._config.INSTALLED_MODELS,
-             "hostname": self._config.HOSTNAME,
-             "ip_address": self._config.IP_ADDRESS,
-             "resources": self._config.RESOURCES,
-         }
-         for orchestrator in self._config.ORCHESTRATORS:
-             url = f"{orchestrator['url']}/_worker/workers/register"
-             try:
-                 if self._http_session:
-                     async with self._http_session.post(
-                         url, json=payload, headers=self._get_headers(orchestrator)
-                     ) as resp:
-                         if resp.status >= 400:
-                             logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
-             except ClientError as e:
-                 logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
+
+         gpu_info = None
+         if self._config.RESOURCES.get("gpu_info"):
+             gpu_info = GPUInfo(**self._config.RESOURCES["gpu_info"])
+
+         resources = Resources(
+             max_concurrent_tasks=self._config.MAX_CONCURRENT_TASKS,
+             cpu_cores=self._config.RESOURCES["cpu_cores"],
+             gpu_info=gpu_info,
+         )
+
+         s3_hash = calculate_config_hash(
+             self._config.S3_ENDPOINT_URL,
+             self._config.S3_ACCESS_KEY,
+             self._config.S3_DEFAULT_BUCKET,
+         )
+
+         registration = WorkerRegistration(
+             worker_id=self._config.WORKER_ID,
+             worker_type=self._config.WORKER_TYPE,
+             supported_tasks=state["supported_tasks"],
+             resources=resources,
+             installed_software=self._config.INSTALLED_SOFTWARE,
+             installed_models=[InstalledModel(**m) for m in self._config.INSTALLED_MODELS],
+             capabilities=WorkerCapabilities(
+                 hostname=self._config.HOSTNAME,
+                 ip_address=self._config.IP_ADDRESS,
+                 cost_per_skill=self._config.COST_PER_SKILL,
+                 s3_config_hash=s3_hash,
+             ),
+         )
+
+         await gather(*[self._safe_register(client, registration) for _, client in self._clients])
+
+     async def _safe_register(self, client: Transport, registration: WorkerRegistration):
+         try:
+             await client.register(registration)
+         except RxonError as e:
+             logger.error(f"Registration failed for {client}: {e}")

      async def _send_heartbeats_to_all(self):
          """Sends heartbeat messages to all orchestrators."""
          state = self._get_current_state()
-         payload = {
-             "load": self._current_load,
-             "status": state["status"],
-             "supported_tasks": state["supported_tasks"],
-             "hot_cache": list(self._hot_cache),
-         }

+         hot_skills = None
          if self._skill_dependencies:
-             payload["skill_dependencies"] = self._skill_dependencies
              hot_skills = [
                  skill for skill, models in self._skill_dependencies.items() if set(models).issubset(self._hot_cache)
              ]
-             if hot_skills:
-                 payload["hot_skills"] = hot_skills

-         async def _send_single(orchestrator: dict[str, Any]):
-             url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
-             headers = self._get_headers(orchestrator)
-             try:
-                 if self._http_session and not self._http_session.closed:
-                     async with self._http_session.patch(url, json=payload, headers=headers) as resp:
-                         if resp.status >= 400:
-                             logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
-             except ClientError as e:
-                 logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
+         heartbeat = Heartbeat(
+             worker_id=self._config.WORKER_ID,
+             status=state["status"],
+             load=float(self._current_load),
+             current_tasks=list(self._active_tasks.keys()),
+             supported_tasks=state["supported_tasks"],
+             hot_cache=list(self._hot_cache),
+             skill_dependencies=self._skill_dependencies or None,
+             hot_skills=hot_skills or None,
+         )

-         await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
+         await gather(*[self._safe_heartbeat(client, heartbeat) for _, client in self._clients])
+
+     async def _safe_heartbeat(self, client: Transport, heartbeat: Heartbeat):
+         try:
+             await client.send_heartbeat(heartbeat)
+         except RxonError as e:
+             logger.warning(f"Heartbeat failed for {client}: {e}")

      async def main(self):
          """The main asynchronous function."""
-         self._validate_config()  # Validate config now that all tasks are registered
+         self._config.validate()
+         self._validate_task_types()
          if not self._http_session:
-             self._http_session = ClientSession()
+             if self._config.TLS_CA_PATH or (self._config.TLS_CERT_PATH and self._config.TLS_KEY_PATH):
+                 logger.info("Initializing SSL context for mTLS.")
+                 self._ssl_context = create_client_ssl_context(
+                     ca_path=self._config.TLS_CA_PATH,
+                     cert_path=self._config.TLS_CERT_PATH,
+                     key_path=self._config.TLS_KEY_PATH,
+                 )
+             connector = TCPConnector(ssl=self._ssl_context) if self._ssl_context else None
+             self._http_session = ClientSession(connector=connector)
+         if not self._clients:
+             self._init_clients()
+
+         # Connect transports
+         await gather(*[client.connect() for _, client in self._clients])

          comm_task = create_task(self._manage_orchestrator_communications())

+         token_rotation_task = None
+         if self._ssl_context:
+             token_rotation_task = create_task(self._manage_token_rotation())
+
          polling_task = create_task(self._start_polling())
          await self._shutdown_event.wait()

          for task in [comm_task, polling_task]:
              task.cancel()
+         if token_rotation_task:
+             token_rotation_task.cancel()
+
          if self._active_tasks:
              await gather(*self._active_tasks.values(), return_exceptions=True)

-         if self._ws_connection and not self._ws_connection.closed:
-             await self._ws_connection.close()
+         # Close transports
+         await gather(*[client.close() for _, client in self._clients])
+
          if self._http_session and not self._http_session.closed and not self._session_is_managed_externally:
              await self._http_session.close()

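
Note: the `_from_dict` helper added above leans on the NamedTuple `_fields` attribute so that unknown keys from a newer protocol version are dropped instead of raising TypeError. A standalone illustration with a hypothetical type:

    from typing import NamedTuple

    class Point(NamedTuple):
        x: int
        y: int

    data = {"x": 1, "y": 2, "z": 3}  # "z" is unknown to this version
    print(Point(**{k: v for k, v in data.items() if k in Point._fields}))  # Point(x=1, y=2)
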
@@ -459,6 +635,26 @@ class Worker:
          except KeyboardInterrupt:
              self._shutdown_event.set()

+     async def _manage_token_rotation(self):
+         """Periodically refreshes auth tokens for all clients."""
+         await sleep(5)
+
+         while not self._shutdown_event.is_set():
+             min_expires_in = 3600
+
+             for _, client in self._clients:
+                 try:
+                     token_resp = await client.refresh_token()
+                     if token_resp:
+                         self._config.WORKER_TOKEN = token_resp.access_token
+                         min_expires_in = min(min_expires_in, token_resp.expires_in)
+                 except Exception as e:
+                     logger.error(f"Error in token rotation loop: {e}")
+
+             refresh_delay = max(60, min_expires_in * 0.8)
+             logger.debug(f"Next token refresh scheduled in {refresh_delay:.1f}s")
+             await sleep(refresh_delay)
+
      async def _run_health_check_server(self):
          app = web.Application()

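
Note: the rotation loop above schedules the next refresh at 80% of the shortest token lifetime reported by any client, floored at 60 seconds:

    def refresh_delay(min_expires_in: float) -> float:
        return max(60, min_expires_in * 0.8)

    assert refresh_delay(3600) == 2880.0  # hour-long token: refresh after 48 minutes
    assert refresh_delay(45) == 60        # never spin faster than once a minute
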
@@ -482,60 +678,27 @@ class Worker:
          except KeyboardInterrupt:
              self._shutdown_event.set()

-     # WebSocket methods omitted for brevity as they are not relevant to the changes
      async def _start_websocket_manager(self):
-         """Manages the WebSocket connection to the orchestrator."""
-         while not self._shutdown_event.is_set():
-             for orchestrator in self._config.ORCHESTRATORS:
-                 ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
-                 try:
-                     if self._http_session:
-                         async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
-                             self._ws_connection = ws
-                             logger.info(f"WebSocket connection established to {ws_url}")
-                             await self._listen_for_commands()
-                 except (ClientError, AsyncTimeoutError) as e:
-                     logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
-                 finally:
-                     self._ws_connection = None
-                     logger.info(f"WebSocket connection to {ws_url} closed.")
-                     await sleep(5)  # Reconnection delay
-             if not self._config.ORCHESTRATORS:
-                 await sleep(5)
-
-     async def _listen_for_commands(self):
-         """Listens for and processes commands from the orchestrator via WebSocket."""
-         if not self._ws_connection:
-             return
+         """Manages the command listeners."""
+         listeners = []
+         for _, client in self._clients:
+             listeners.append(create_task(self._listen_to_single_transport(client)))

-         try:
-             async for msg in self._ws_connection:
-                 if msg.type == WSMsgType.TEXT:
-                     try:
-                         command = msg.json()
-                         if command.get("type") == "cancel_task":
-                             task_id = command.get("task_id")
-                             if task_id in self._active_tasks:
-                                 self._active_tasks[task_id].cancel()
-                                 logger.info(f"Cancelled task {task_id} by orchestrator command.")
-                     except JSONDecodeError:
-                         logger.warning(f"Received invalid JSON over WebSocket: {msg.data}")
-                 elif msg.type == WSMsgType.ERROR:
-                     break
-         except Exception as e:
-             logger.error(f"Error in WebSocket listener: {e}")
+         await self._shutdown_event.wait()

-     async def send_progress(self, task_id: str, job_id: str, progress: float, message: str = ""):
-         """Sends a progress update to the orchestrator via WebSocket."""
-         if self._ws_connection and not self._ws_connection.closed:
+         for listener in listeners:
+             listener.cancel()
+
+     async def _listen_to_single_transport(self, client: Transport):
+         while not self._shutdown_event.is_set():
              try:
-                 payload = {
-                     "type": "progress_update",
-                     "task_id": task_id,
-                     "job_id": job_id,
-                     "progress": progress,
-                     "message": message,
-                 }
-                 await self._ws_connection.send_json(payload)
+                 async for command in client.listen_for_commands():
+                     if command.command == COMMAND_CANCEL_TASK:
+                         task_id = command.task_id
+                         job_id = command.job_id
+                         if task_id in self._active_tasks:
+                             self._active_tasks[task_id].cancel()
+                             logger.info(f"Cancelled task {task_id} (Job: {job_id or 'N/A'}) by orchestrator command.")
              except Exception as e:
-                 logger.warning(f"Could not send progress update for task {task_id}: {e}")
+                 logger.error(f"Error in command listener: {e}")
+                 await sleep(5)
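
Note: cancellation here is cooperative: the listener above cancels the asyncio Task, which raises CancelledError at the handler's next await point; `_process_task` records a cancelled result and re-raises so the Task ends in the cancelled state. A minimal standalone sketch of that flow:

    from asyncio import CancelledError, create_task, run, sleep

    async def handler():
        try:
            await sleep(60)  # long-running work
        except CancelledError:
            print("recording cancelled result")  # _process_task builds a TaskResult here
            raise  # re-raise so the Task is marked cancelled

    async def main():
        task = create_task(handler())
        await sleep(0)   # let the handler reach its await
        task.cancel()    # what COMMAND_CANCEL_TASK triggers
        try:
            await task
        except CancelledError:
            pass

    run(main())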