avtomatika-worker 1.0b3-py3-none-any.whl → 1.0b4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,26 +1,46 @@
- from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
+ from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep, to_thread
  from dataclasses import is_dataclass
- from inspect import Parameter, signature
- from json import JSONDecodeError
+ from inspect import Parameter, iscoroutinefunction, signature
  from logging import getLogger
  from os.path import join
  from typing import Any, Callable

- from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType, web
-
- from .client import OrchestratorClient
- from .config import WorkerConfig
- from .constants import (
+ from aiohttp import ClientSession, TCPConnector, web
+ from rxon import Transport, create_transport
+ from rxon.blob import calculate_config_hash
+ from rxon.constants import (
      COMMAND_CANCEL_TASK,
+     ERROR_CODE_INTEGRITY_MISMATCH,
      ERROR_CODE_INVALID_INPUT,
      ERROR_CODE_PERMANENT,
      ERROR_CODE_TRANSIENT,
+     MSG_TYPE_PROGRESS,
      TASK_STATUS_CANCELLED,
      TASK_STATUS_FAILURE,
+     TASK_STATUS_SUCCESS,
+ )
+ from rxon.exceptions import RxonError
+ from rxon.models import (
+     FileMetadata,
+     GPUInfo,
+     Heartbeat,
+     InstalledModel,
+     ProgressUpdatePayload,
+     Resources,
+     TaskError,
+     TaskPayload,
+     TaskResult,
+     WorkerCapabilities,
+     WorkerRegistration,
  )
+ from rxon.security import create_client_ssl_context
+ from rxon.utils import to_dict
+ from rxon.validators import validate_identifier
+
+ from .config import WorkerConfig
  from .s3 import S3Manager
  from .task_files import TaskFiles
- from .types import ParamValidationError
+ from .types import CapacityChecker, Middleware, ParamValidationError

  try:
      from pydantic import BaseModel, ValidationError
@@ -37,7 +57,7 @@ class Worker:
      """The main class for creating and running a worker.
      Implements a hybrid interaction model with the Orchestrator:
      - PULL model for fetching tasks.
-     - WebSocket for real-time commands (cancellation) and sending progress.
+     - Transport layer for real-time commands (cancellation) and sending progress.
      """

      def __init__(
@@ -48,6 +68,8 @@ class Worker:
          http_session: ClientSession | None = None,
          skill_dependencies: dict[str, list[str]] | None = None,
          config: WorkerConfig | None = None,
+         capacity_checker: CapacityChecker | None = None,
+         clients: list[tuple[dict[str, Any], Transport]] | None = None,
      ):
          self._config = config or WorkerConfig()
          self._s3_manager = S3Manager(self._config)
@@ -58,6 +80,8 @@ class Worker:
          self._task_type_limits = task_type_limits or {}
          self._task_handlers: dict[str, dict[str, Any]] = {}
          self._skill_dependencies = skill_dependencies or {}
+         self._middlewares: list[Middleware] = []
+         self._capacity_checker = capacity_checker

          # Worker state
          self._current_load = 0
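
Note: the `capacity_checker` hook stored above is consulted in `_get_current_state` (further down in this diff) to filter the tasks the worker advertises. A minimal sketch, assuming only the call shape visible at that call site (a callable taking a task name and returning a bool); the GPU probe is hypothetical:

    def free_gpu_memory_mb() -> int:
        return 4096  # stub; a real checker would query the GPU driver

    def gpu_capacity_checker(task_name: str) -> bool:
        # Hypothetical policy: advertise GPU-heavy tasks only while memory is free.
        if task_name.startswith("gpu_"):
            return free_gpu_memory_mb() > 2048
        return True

    worker = Worker(capacity_checker=gpu_capacity_checker)
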
@@ -66,10 +90,10 @@
          self._active_tasks: dict[str, Task] = {}
          self._http_session = http_session
          self._session_is_managed_externally = http_session is not None
-         self._ws_connection: ClientWebSocketResponse | None = None
          self._shutdown_event = Event()
          self._registered_event = Event()
          self._debounce_task: Task | None = None
+         self._ssl_context = None

          # --- Weighted Round-Robin State ---
          self._total_orchestrator_weight = 0
@@ -77,23 +101,31 @@
          for o in self._config.ORCHESTRATORS:
              o["current_weight"] = 0
              self._total_orchestrator_weight += o.get("weight", 1)
-
-         self._clients: list[tuple[dict[str, Any], OrchestratorClient]] = []
-         if self._http_session:
+         self._clients = clients or []
+         if not self._clients and self._http_session:
              self._init_clients()

+     def add_middleware(self, middleware: Middleware) -> None:
+         """Adds a middleware to the execution chain."""
+         self._middlewares.append(middleware)
+
      def _init_clients(self):
-         """Initializes OrchestratorClient instances for each configured orchestrator."""
-         if not self._http_session:
-             return
+         """Initializes Transport instances for each configured orchestrator."""
+         # Even if we don't have an external session, we might create transports
+         # that will create their own sessions. But if we want to share one, we need it here.
+         session_to_use = self._http_session if self._session_is_managed_externally else None
+
          self._clients = [
              (
                  o,
-                 OrchestratorClient(
-                     session=self._http_session,
-                     base_url=o["url"],
+                 create_transport(
+                     url=o["url"],
                      worker_id=self._config.WORKER_ID,
                      token=o.get("token", self._config.WORKER_TOKEN),
+                     ssl_context=self._ssl_context,
+                     session=session_to_use,
+                     result_retries=self._config.RESULT_MAX_RETRIES,
+                     result_retry_delay=self._config.RESULT_RETRY_INITIAL_DELAY,
                  ),
              )
              for o in self._config.ORCHESTRATORS
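
Note: `add_middleware` feeds the chain built in `_process_task` (below), where each middleware is awaited as `middleware(context, next_handler)` with `next_handler` a zero-argument coroutine function. A minimal conforming sketch; the timing logic is illustrative:

    from time import perf_counter

    async def timing_middleware(context, call_next):
        # context carries task_id, job_id, task_name, params and handler_kwargs.
        start = perf_counter()
        try:
            return await call_next()
        finally:
            print(f"task {context['task_id']} took {perf_counter() - start:.2f}s")

    worker.add_middleware(timing_middleware)  # `worker` is an existing Worker instance
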
@@ -114,6 +146,9 @@ class Worker:

      def task(self, name: str, task_type: str | None = None) -> Callable:
          """Decorator to register a function as a task handler."""
+         validate_identifier(name, "task name")
+         if task_type:
+             validate_identifier(task_type, "task type")

          def decorator(func: Callable) -> Callable:
              logger.info(f"Registering task: '{name}' (type: {task_type or 'N/A'})")
@@ -164,10 +199,13 @@
              if is_available:
                  supported_tasks.append(name)

+         if self._capacity_checker:
+             supported_tasks = [task for task in supported_tasks if self._capacity_checker(task)]
+
          status = "idle" if supported_tasks else "busy"
          return {"status": status, "supported_tasks": supported_tasks}

-     def _get_next_client(self) -> OrchestratorClient | None:
+     def _get_next_client(self) -> Transport | None:
          """
          Selects the next orchestrator client using a smooth weighted round-robin algorithm.
          """
@@ -206,27 +244,29 @@
          # Schedule the new debounced call.
          self._debounce_task = create_task(self._debounced_heartbeat_sender())

-     async def _poll_for_tasks(self, client: OrchestratorClient):
+     async def _poll_for_tasks(self, client: Transport):
          """Polls a specific Orchestrator for new tasks."""
-         task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
-         if task_data:
-             task_data["client"] = client
-
-             self._current_load += 1
-             if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
-                 task_type_for_limit := task_handler_info.get("type")
-             ):
-                 self._current_load_by_type[task_type_for_limit] += 1
-             self._schedule_heartbeat_debounce()
+         current_state = self._get_current_state()
+         if current_state["status"] == "busy":
+             return

-             task = create_task(self._process_task(task_data))
-             self._active_tasks[task_data["task_id"]] = task
-         else:
-             # If no task but it was a 204 or error, the client already handled/logged it.
-             # We might want a short sleep here if it was an error, but client.poll_task
-             # doesn't distinguish between 204 and error currently.
-             # However, the previous logic only slept on status != 204.
-             pass
+         try:
+             task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+             if task_data:
+                 task_data_dict = to_dict(task_data)
+                 task_data_dict["client"] = client
+
+                 self._current_load += 1
+                 if (task_handler_info := self._task_handlers.get(task_data.type)) and (
+                     task_type_for_limit := task_handler_info.get("type")
+                 ):
+                     self._current_load_by_type[task_type_for_limit] += 1
+                 self._schedule_heartbeat_debounce()
+
+                 task = create_task(self._process_task(task_data_dict))
+                 self._active_tasks[task_data.task_id] = task
+         except RxonError as e:
+             logger.error(f"Error polling tasks: {e}")

      async def _start_polling(self):
          """The main loop for polling tasks."""
@@ -253,7 +293,6 @@
      def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
          """
          Inspects the handler's signature to validate and instantiate params.
-         Supports dict, dataclasses, and optional pydantic models.
          """
          sig = signature(handler)
          params_annotation = sig.parameters.get("params").annotation
@@ -271,11 +310,11 @@
          # Dataclass Instantiation
          if isinstance(params_annotation, type) and is_dataclass(params_annotation):
              try:
-                 # Filter unknown fields to prevent TypeError on dataclass instantiation
+                 # Filter unknown fields
                  known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
                  filtered_params = {k: v for k, v in params.items() if k in known_fields}

-                 # Explicitly check for missing required fields
+                 # Check required fields
                  required_fields = [
                      f.name
                      for f in params_annotation.__dataclass_fields__.values()
@@ -287,7 +326,6 @@

                  return params_annotation(**filtered_params)
              except (TypeError, ValueError) as e:
-                 # TypeError for missing/extra args, ValueError from __post_init__
                  raise ParamValidationError(str(e)) from e

          return params
@@ -296,8 +334,7 @@
          """Injects dependencies based on type hints."""
          deps = {}
          task_dir = join(self._config.TASK_FILES_DIR, task_id)
-         # Always create the object, but directory is lazy
-         task_files = TaskFiles(task_dir)
+         task_files = TaskFiles(task_dir, s3_manager=self._s3_manager)

          sig = signature(handler)
          for name, param in sig.parameters.items():
@@ -306,66 +343,150 @@ class Worker:

          return deps

-     async def _process_task(self, task_data: dict[str, Any]):
+     async def _process_task(self, task_data_raw: dict[str, Any]):
          """Executes the task logic."""
-         task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-         params, client = task_data.get("params", {}), task_data["client"]
+         client: Transport = task_data_raw.pop("client")
+
+         # Parse incoming task data using protocol model
+         if "params_metadata" in task_data_raw and task_data_raw["params_metadata"]:
+             task_data_raw["params_metadata"] = {
+                 k: self._from_dict(FileMetadata, v) for k, v in task_data_raw["params_metadata"].items()
+             }
+
+         task_payload = self._from_dict(TaskPayload, task_data_raw)
+         task_id, job_id, task_name = task_payload.task_id, task_payload.job_id, task_payload.type
+         params = task_payload.params

-         result: dict[str, Any] = {}
          handler_data = self._task_handlers.get(task_name)
          task_type_for_limit = handler_data.get("type") if handler_data else None

-         result_sent = False  # Flag to track if result has been sent
+         result_obj: TaskResult | None = None
+
+         # Create a progress sender wrapper attached to this specific client
+         async def send_progress_wrapper(task_id_arg, job_id_arg, progress, message=""):
+             payload = ProgressUpdatePayload(
+                 event=MSG_TYPE_PROGRESS, task_id=task_id_arg, job_id=job_id_arg, progress=progress, message=message
+             )
+             await client.send_progress(payload)

          try:
              if not handler_data:
                  message = f"Unsupported task: {task_name}"
                  logger.warning(message)
-                 result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_PERMANENT, "message": message}}
-                 payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await client.send_result(
-                     payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+                 error = TaskError(code=ERROR_CODE_PERMANENT, message=message)
+                 result_obj = TaskResult(
+                     job_id=job_id,
+                     task_id=task_id,
+                     worker_id=self._config.WORKER_ID,
+                     status=TASK_STATUS_FAILURE,
+                     error=error,
                  )
-                 result_sent = True  # Mark result as sent
-                 return
+             else:
+                 # Download files
+                 params = await self._s3_manager.process_params(params, task_id, metadata=task_payload.params_metadata)
+                 validated_params = self._prepare_task_params(handler_data["func"], params)
+                 deps = self._prepare_dependencies(handler_data["func"], task_id)
+
+                 handler_kwargs = {
+                     "params": validated_params,
+                     "task_id": task_id,
+                     "job_id": job_id,
+                     "tracing_context": task_payload.tracing_context,
+                     "priority": task_data_raw.get("priority", 0),
+                     "send_progress": send_progress_wrapper,
+                     "add_to_hot_cache": self.add_to_hot_cache,
+                     "remove_from_hot_cache": self.remove_from_hot_cache,
+                     **deps,
+                 }

-             params = await self._s3_manager.process_params(params, task_id)
-             validated_params = self._prepare_task_params(handler_data["func"], params)
-             deps = self._prepare_dependencies(handler_data["func"], task_id)
+                 middleware_context = {
+                     "task_id": task_id,
+                     "job_id": job_id,
+                     "task_name": task_name,
+                     "params": validated_params,
+                     "handler_kwargs": handler_kwargs,
+                 }
+
+                 async def _execution_logic():
+                     handler = handler_data["func"]
+                     final_kwargs = middleware_context["handler_kwargs"]
+
+                     if iscoroutinefunction(handler):
+                         return await handler(**final_kwargs)
+                     else:
+                         return await to_thread(handler, **final_kwargs)
+
+                 handler_chain = _execution_logic
+                 for middleware in reversed(self._middlewares):
+
+                     def make_wrapper(mw: Middleware, next_handler: Callable) -> Callable:
+                         async def wrapper():
+                             return await mw(middleware_context, next_handler)
+
+                         return wrapper
+
+                     handler_chain = make_wrapper(middleware, handler_chain)
+
+                 handler_result = await handler_chain()
+
+                 updated_data, metadata_map = await self._s3_manager.process_result(handler_result.get("data", {}))
+
+                 result_obj = TaskResult(
+                     job_id=job_id,
+                     task_id=task_id,
+                     worker_id=self._config.WORKER_ID,
+                     status=handler_result.get("status", TASK_STATUS_SUCCESS),
+                     data=updated_data,
+                     error=TaskError(**handler_result["error"]) if "error" in handler_result else None,
+                     data_metadata=metadata_map if metadata_map else None,
+                 )

-             result = await handler_data["func"](
-                 validated_params,
-                 task_id=task_id,
-                 job_id=job_id,
-                 priority=task_data.get("priority", 0),
-                 send_progress=self.send_progress,
-                 add_to_hot_cache=self.add_to_hot_cache,
-                 remove_from_hot_cache=self.remove_from_hot_cache,
-                 **deps,
-             )
-             result = await self._s3_manager.process_result(result)
          except ParamValidationError as e:
              logger.error(f"Task {task_id} failed validation: {e}")
-             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_INVALID_INPUT, "message": str(e)}}
+             error = TaskError(code=ERROR_CODE_INVALID_INPUT, message=str(e))
+             result_obj = TaskResult(
+                 job_id=job_id,
+                 task_id=task_id,
+                 worker_id=self._config.WORKER_ID,
+                 status=TASK_STATUS_FAILURE,
+                 error=error,
+             )
          except CancelledError:
              logger.info(f"Task {task_id} was cancelled.")
-             result = {"status": TASK_STATUS_CANCELLED}
-             # We must re-raise the exception to be handled by the outer gather
+             result_obj = TaskResult(
+                 job_id=job_id, task_id=task_id, worker_id=self._config.WORKER_ID, status=TASK_STATUS_CANCELLED
+             )
              raise
+         except ValueError as e:
+             logger.error(f"Data integrity or validation error for task {task_id}: {e}")
+             error = TaskError(code=ERROR_CODE_INTEGRITY_MISMATCH, message=str(e))
+             result_obj = TaskResult(
+                 job_id=job_id,
+                 task_id=task_id,
+                 worker_id=self._config.WORKER_ID,
+                 status=TASK_STATUS_FAILURE,
+                 error=error,
+             )
          except Exception as e:
              logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_TRANSIENT, "message": str(e)}}
+             error = TaskError(code=ERROR_CODE_TRANSIENT, message=str(e))
+             result_obj = TaskResult(
+                 job_id=job_id,
+                 task_id=task_id,
+                 worker_id=self._config.WORKER_ID,
+                 status=TASK_STATUS_FAILURE,
+                 error=error,
+             )
          finally:
-             # Cleanup task workspace
              await self._s3_manager.cleanup(task_id)

-             if not result_sent:  # Only send if not already sent
-                 payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await client.send_result(
-                     payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
-                 )
-             self._active_tasks.pop(task_id, None)
+             if result_obj:
+                 try:
+                     await client.send_result(result_obj)
+                 except RxonError as e:
+                     logger.error(f"Failed to send task result: {e}")

+             self._active_tasks.pop(task_id, None)
              self._current_load -= 1
              if task_type_for_limit:
                  self._current_load_by_type[task_type_for_limit] -= 1
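
Note: because the loop above iterates `reversed(self._middlewares)`, the first middleware registered ends up outermost. A self-contained sketch of the same reduction, demonstrating the call order:

    import asyncio

    async def outer(trace, call_next):
        trace.append("outer:before")
        result = await call_next()
        trace.append("outer:after")
        return result

    async def inner(trace, call_next):
        trace.append("inner:before")
        result = await call_next()
        trace.append("inner:after")
        return result

    async def demo():
        trace = []

        async def handler():
            trace.append("handler")
            return "done"

        chain = handler
        for mw in reversed([outer, inner]):  # registration order: outer, then inner
            def make_wrapper(mw, next_handler):
                async def wrapper():
                    return await mw(trace, next_handler)
                return wrapper
            chain = make_wrapper(mw, chain)

        assert await chain() == "done"
        assert trace == ["outer:before", "inner:before", "handler", "inner:after", "outer:after"]

    asyncio.run(demo())
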
@@ -383,62 +504,127 @@ class Worker:
              await self._send_heartbeats_to_all()
              await sleep(self._config.HEARTBEAT_INTERVAL)

+     @staticmethod
+     def _from_dict(cls: type, data: dict[str, Any]) -> Any:
+         """Safely instantiates a NamedTuple from a dict, ignoring unknown fields."""
+         if not data:
+             return None
+         fields = cls._fields
+         filtered_data = {k: v for k, v in data.items() if k in fields}
+         return cls(**filtered_data)
+
      async def _register_with_all_orchestrators(self):
          """Registers the worker with all orchestrators."""
          state = self._get_current_state()
-         payload = {
-             "worker_id": self._config.WORKER_ID,
-             "worker_type": self._config.WORKER_TYPE,
-             "supported_tasks": state["supported_tasks"],
-             "max_concurrent_tasks": self._config.MAX_CONCURRENT_TASKS,
-             "cost_per_skill": self._config.COST_PER_SKILL,
-             "installed_models": self._config.INSTALLED_MODELS,
-             "hostname": self._config.HOSTNAME,
-             "ip_address": self._config.IP_ADDRESS,
-             "resources": self._config.RESOURCES,
-         }
-         await gather(*[client.register(payload) for _, client in self._clients])
+
+         gpu_info = None
+         if self._config.RESOURCES.get("gpu_info"):
+             gpu_info = GPUInfo(**self._config.RESOURCES["gpu_info"])
+
+         resources = Resources(
+             max_concurrent_tasks=self._config.MAX_CONCURRENT_TASKS,
+             cpu_cores=self._config.RESOURCES["cpu_cores"],
+             gpu_info=gpu_info,
+         )
+
+         s3_hash = calculate_config_hash(
+             self._config.S3_ENDPOINT_URL,
+             self._config.S3_ACCESS_KEY,
+             self._config.S3_DEFAULT_BUCKET,
+         )
+
+         registration = WorkerRegistration(
+             worker_id=self._config.WORKER_ID,
+             worker_type=self._config.WORKER_TYPE,
+             supported_tasks=state["supported_tasks"],
+             resources=resources,
+             installed_software=self._config.INSTALLED_SOFTWARE,
+             installed_models=[InstalledModel(**m) for m in self._config.INSTALLED_MODELS],
+             capabilities=WorkerCapabilities(
+                 hostname=self._config.HOSTNAME,
+                 ip_address=self._config.IP_ADDRESS,
+                 cost_per_skill=self._config.COST_PER_SKILL,
+                 s3_config_hash=s3_hash,
+             ),
+         )
+
+         await gather(*[self._safe_register(client, registration) for _, client in self._clients])
+
+     async def _safe_register(self, client: Transport, registration: WorkerRegistration):
+         try:
+             await client.register(registration)
+         except RxonError as e:
+             logger.error(f"Registration failed for {client}: {e}")

      async def _send_heartbeats_to_all(self):
          """Sends heartbeat messages to all orchestrators."""
          state = self._get_current_state()
-         payload = {
-             "load": self._current_load,
-             "status": state["status"],
-             "supported_tasks": state["supported_tasks"],
-             "hot_cache": list(self._hot_cache),
-         }

+         hot_skills = None
          if self._skill_dependencies:
-             payload["skill_dependencies"] = self._skill_dependencies
              hot_skills = [
                  skill for skill, models in self._skill_dependencies.items() if set(models).issubset(self._hot_cache)
              ]
-             if hot_skills:
-                 payload["hot_skills"] = hot_skills

-         await gather(*[client.send_heartbeat(payload) for _, client in self._clients])
+         heartbeat = Heartbeat(
+             worker_id=self._config.WORKER_ID,
+             status=state["status"],
+             load=float(self._current_load),
+             current_tasks=list(self._active_tasks.keys()),
+             supported_tasks=state["supported_tasks"],
+             hot_cache=list(self._hot_cache),
+             skill_dependencies=self._skill_dependencies or None,
+             hot_skills=hot_skills or None,
+         )
+
+         await gather(*[self._safe_heartbeat(client, heartbeat) for _, client in self._clients])
+
+     async def _safe_heartbeat(self, client: Transport, heartbeat: Heartbeat):
+         try:
+             await client.send_heartbeat(heartbeat)
+         except RxonError as e:
+             logger.warning(f"Heartbeat failed for {client}: {e}")

      async def main(self):
          """The main asynchronous function."""
          self._config.validate()
-         self._validate_task_types()  # Validate config now that all tasks are registered
+         self._validate_task_types()
          if not self._http_session:
-             self._http_session = ClientSession()
-             self._init_clients()
+             if self._config.TLS_CA_PATH or (self._config.TLS_CERT_PATH and self._config.TLS_KEY_PATH):
+                 logger.info("Initializing SSL context for mTLS.")
+                 self._ssl_context = create_client_ssl_context(
+                     ca_path=self._config.TLS_CA_PATH,
+                     cert_path=self._config.TLS_CERT_PATH,
+                     key_path=self._config.TLS_KEY_PATH,
+                 )
+             connector = TCPConnector(ssl=self._ssl_context) if self._ssl_context else None
+             self._http_session = ClientSession(connector=connector)
+         if not self._clients:
+             self._init_clients()
+
+         # Connect transports
+         await gather(*[client.connect() for _, client in self._clients])

          comm_task = create_task(self._manage_orchestrator_communications())

+         token_rotation_task = None
+         if self._ssl_context:
+             token_rotation_task = create_task(self._manage_token_rotation())
+
          polling_task = create_task(self._start_polling())
          await self._shutdown_event.wait()

          for task in [comm_task, polling_task]:
              task.cancel()
+         if token_rotation_task:
+             token_rotation_task.cancel()
+
          if self._active_tasks:
              await gather(*self._active_tasks.values(), return_exceptions=True)

-         if self._ws_connection and not self._ws_connection.closed:
-             await self._ws_connection.close()
+         # Close transports
+         await gather(*[client.close() for _, client in self._clients])
+
          if self._http_session and not self._http_session.closed and not self._session_is_managed_externally:
              await self._http_session.close()
 
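Note: `_from_dict` filters on `cls._fields`, which implies the rxon protocol models are NamedTuples; unknown keys sent by a newer orchestrator are dropped instead of raising. A quick illustration with a stand-in model:

    from typing import NamedTuple

    class DemoPayload(NamedTuple):  # stand-in for a rxon.models payload
        task_id: str
        job_id: str
        type: str

    raw = {"task_id": "t1", "job_id": "j1", "type": "resize_image", "added_in_v2": 42}
    payload = DemoPayload(**{k: v for k, v in raw.items() if k in DemoPayload._fields})
    assert payload.task_id == "t1"  # the unknown "added_in_v2" key was ignored
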
@@ -449,6 +635,26 @@ class Worker:
          except KeyboardInterrupt:
              self._shutdown_event.set()

+     async def _manage_token_rotation(self):
+         """Periodically refreshes auth tokens for all clients."""
+         await sleep(5)
+
+         while not self._shutdown_event.is_set():
+             min_expires_in = 3600
+
+             for _, client in self._clients:
+                 try:
+                     token_resp = await client.refresh_token()
+                     if token_resp:
+                         self._config.WORKER_TOKEN = token_resp.access_token
+                         min_expires_in = min(min_expires_in, token_resp.expires_in)
+                 except Exception as e:
+                     logger.error(f"Error in token rotation loop: {e}")
+
+             refresh_delay = max(60, min_expires_in * 0.8)
+             logger.debug(f"Next token refresh scheduled in {refresh_delay:.1f}s")
+             await sleep(refresh_delay)
+
      async def _run_health_check_server(self):
          app = web.Application()

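Note: the rotation loop above refreshes at 80% of the shortest token lifetime it saw, but never more often than once a minute. The arithmetic, worked through:

    def next_refresh_delay(min_expires_in: float) -> float:
        return max(60, min_expires_in * 0.8)

    assert next_refresh_delay(3600) == 2880.0  # 1 h tokens -> refresh after 48 min
    assert next_refresh_delay(30) == 60        # very short tokens still wait 60 s
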
@@ -473,54 +679,26 @@
              self._shutdown_event.set()

      async def _start_websocket_manager(self):
-         """Manages the WebSocket connection to the orchestrator."""
-         while not self._shutdown_event.is_set():
-             # In multi-orchestrator mode, we currently only connect to the first one available
-             for _, client in self._clients:
-                 try:
-                     ws = await client.connect_websocket()
-                     if ws:
-                         self._ws_connection = ws
-                         await self._listen_for_commands()
-                 finally:
-                     self._ws_connection = None
-                     await sleep(5)  # Reconnection delay
-             if not self._clients:
-                 await sleep(5)
+         """Manages the command listeners."""
+         listeners = []
+         for _, client in self._clients:
+             listeners.append(create_task(self._listen_to_single_transport(client)))

-     async def _listen_for_commands(self):
-         """Listens for and processes commands from the orchestrator via WebSocket."""
-         if not self._ws_connection:
-             return
+         await self._shutdown_event.wait()

-         try:
-             async for msg in self._ws_connection:
-                 if msg.type == WSMsgType.TEXT:
-                     try:
-                         command = msg.json()
-                         if command.get("type") == COMMAND_CANCEL_TASK:
-                             task_id = command.get("task_id")
-                             if task_id in self._active_tasks:
-                                 self._active_tasks[task_id].cancel()
-                                 logger.info(f"Cancelled task {task_id} by orchestrator command.")
-                     except JSONDecodeError:
-                         logger.warning(f"Received invalid JSON over WebSocket: {msg.data}")
-                 elif msg.type == WSMsgType.ERROR:
-                     break
-         except Exception as e:
-             logger.error(f"Error in WebSocket listener: {e}")
+         for listener in listeners:
+             listener.cancel()

-     async def send_progress(self, task_id: str, job_id: str, progress: float, message: str = ""):
-         """Sends a progress update to the orchestrator via WebSocket."""
-         if self._ws_connection and not self._ws_connection.closed:
+     async def _listen_to_single_transport(self, client: Transport):
+         while not self._shutdown_event.is_set():
              try:
-                 payload = {
-                     "type": "progress_update",
-                     "task_id": task_id,
-                     "job_id": job_id,
-                     "progress": progress,
-                     "message": message,
-                 }
-                 await self._ws_connection.send_json(payload)
+                 async for command in client.listen_for_commands():
+                     if command.command == COMMAND_CANCEL_TASK:
+                         task_id = command.task_id
+                         job_id = command.job_id
+                         if task_id in self._active_tasks:
+                             self._active_tasks[task_id].cancel()
+                             logger.info(f"Cancelled task {task_id} (Job: {job_id or 'N/A'}) by orchestrator command.")
              except Exception as e:
-                 logger.warning(f"Could not send progress update for task {task_id}: {e}")
+                 logger.error(f"Error in command listener: {e}")
+                 await sleep(5)