avtomatika-worker 1.0a2-py3-none-any.whl → 1.0b2-py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
@@ -1,12 +1,25 @@
1
1
  from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
2
2
  from asyncio import TimeoutError as AsyncTimeoutError
3
+ from dataclasses import is_dataclass
4
+ from inspect import Parameter, signature
3
5
  from json import JSONDecodeError
4
6
  from logging import getLogger
7
+ from os.path import join
5
8
  from typing import Any, Callable
6
9
 
7
10
  from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
8
11
 
9
12
  from .config import WorkerConfig
13
+ from .s3 import S3Manager
14
+ from .task_files import TaskFiles
15
+ from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
16
+
17
+ try:
18
+ from pydantic import BaseModel, ValidationError
19
+
20
+ _PYDANTIC_INSTALLED = True
21
+ except ImportError:
22
+ _PYDANTIC_INSTALLED = False
10
23
 
11
24
  # Logging setup
12
25
  logger = getLogger(__name__)
@@ -26,9 +39,11 @@ class Worker:
26
39
  task_type_limits: dict[str, int] | None = None,
27
40
  http_session: ClientSession | None = None,
28
41
  skill_dependencies: dict[str, list[str]] | None = None,
42
+ config: WorkerConfig | None = None,
29
43
  ):
30
- self._config = WorkerConfig()
31
- self._config.worker_type = worker_type # Allow overriding worker_type
44
+ self._config = config or WorkerConfig()
45
+ self._s3_manager = S3Manager(self._config)
46
+ self._config.WORKER_TYPE = worker_type # Allow overriding worker_type
32
47
  if max_concurrent_tasks is not None:
33
48
  self._config.max_concurrent_tasks = max_concurrent_tasks
34
49
 
@@ -44,12 +59,19 @@ class Worker:
44
59
  self._http_session = http_session
45
60
  self._session_is_managed_externally = http_session is not None
46
61
  self._ws_connection: ClientWebSocketResponse | None = None
47
- self._headers = {"X-Worker-Token": self._config.worker_token}
62
+ # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
48
63
  self._shutdown_event = Event()
49
64
  self._registered_event = Event()
50
65
  self._round_robin_index = 0
51
66
  self._debounce_task: Task | None = None
52
67
 
68
+ # --- Weighted Round-Robin State ---
69
+ self._total_orchestrator_weight = 0
70
+ if self._config.ORCHESTRATORS:
71
+ for o in self._config.ORCHESTRATORS:
72
+ o["current_weight"] = 0
73
+ self._total_orchestrator_weight += o.get("weight", 1)
74
+
53
75
  def _validate_config(self):
54
76
  """Checks for unused task type limits and warns the user."""
55
77
  registered_task_types = {
@@ -98,7 +120,7 @@ class Worker:
98
120
  """
99
121
  Calculates the current worker state including status and available tasks.
100
122
  """
101
- if self._current_load >= self._config.max_concurrent_tasks:
123
+ if self._current_load >= self._config.MAX_CONCURRENT_TASKS:
102
124
  return {"status": "busy", "supported_tasks": []}
103
125
 
104
126
  supported_tasks = []
@@ -118,9 +140,36 @@ class Worker:
118
140
  status = "idle" if supported_tasks else "busy"
119
141
  return {"status": status, "supported_tasks": supported_tasks}
120
142
 
143
+ def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
144
+ """Builds authentication headers for a specific orchestrator."""
145
+ token = orchestrator.get("token", self._config.WORKER_TOKEN)
146
+ return {"X-Worker-Token": token}
147
+
148
+ def _get_next_orchestrator(self) -> dict[str, Any] | None:
149
+ """
150
+ Selects the next orchestrator using a smooth weighted round-robin algorithm.
151
+ """
152
+ if not self._config.ORCHESTRATORS:
153
+ return None
154
+
155
+ # The orchestrator with the highest current_weight is selected.
156
+ selected_orchestrator = None
157
+ highest_weight = -1
158
+
159
+ for o in self._config.ORCHESTRATORS:
160
+ o["current_weight"] += o["weight"]
161
+ if o["current_weight"] > highest_weight:
162
+ highest_weight = o["current_weight"]
163
+ selected_orchestrator = o
164
+
165
+ if selected_orchestrator:
166
+ selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
167
+
168
+ return selected_orchestrator
169
+
121
170
  async def _debounced_heartbeat_sender(self):
122
171
  """Waits for the debounce delay then sends a heartbeat."""
123
- await sleep(self._config.heartbeat_debounce_delay)
172
+ await sleep(self._config.HEARTBEAT_DEBOUNCE_DELAY)
124
173
  await self._send_heartbeats_to_all()
125
174
 
126
175
  def _schedule_heartbeat_debounce(self):
@@ -131,86 +180,166 @@ class Worker:
131
180
  # Schedule the new debounced call.
132
181
  self._debounce_task = create_task(self._debounced_heartbeat_sender())
133
182
 
134
- async def _poll_for_tasks(self, orchestrator_url: str):
183
+ async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
135
184
  """Polls a specific Orchestrator for new tasks."""
136
- url = f"{orchestrator_url}/_worker/workers/{self._config.worker_id}/tasks/next"
185
+ url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
137
186
  try:
138
187
  if not self._http_session:
139
188
  return
140
- timeout = ClientTimeout(total=self._config.task_poll_timeout + 5)
141
- async with self._http_session.get(url, headers=self._headers, timeout=timeout) as resp:
189
+ timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
190
+ headers = self._get_headers(orchestrator)
191
+ async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
142
192
  if resp.status == 200:
143
193
  task_data = await resp.json()
144
- task_data["orchestrator_url"] = orchestrator_url
194
+ task_data["orchestrator"] = orchestrator
145
195
 
146
196
  self._current_load += 1
147
- task_handler_info = self._task_handlers.get(task_data["type"])
148
- if task_handler_info:
149
- task_type_for_limit = task_handler_info.get("type")
150
- if task_type_for_limit:
151
- self._current_load_by_type[task_type_for_limit] += 1
197
+ if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
198
+ task_type_for_limit := task_handler_info.get("type")
199
+ ):
200
+ self._current_load_by_type[task_type_for_limit] += 1
152
201
  self._schedule_heartbeat_debounce()
153
202
 
154
203
  task = create_task(self._process_task(task_data))
155
204
  self._active_tasks[task_data["task_id"]] = task
156
205
  elif resp.status != 204:
157
- await sleep(self._config.task_poll_error_delay)
206
+ await sleep(self._config.TASK_POLL_ERROR_DELAY)
158
207
  except (AsyncTimeoutError, ClientError) as e:
159
208
  logger.error(f"Error polling for tasks: {e}")
160
- await sleep(self._config.task_poll_error_delay)
209
+ await sleep(self._config.TASK_POLL_ERROR_DELAY)
161
210
 
162
211
  async def _start_polling(self):
163
- print("Waiting for registration")
164
212
  """The main loop for polling tasks."""
165
213
  await self._registered_event.wait()
166
- print("Polling started")
214
+
167
215
  while not self._shutdown_event.is_set():
168
216
  if self._get_current_state()["status"] == "busy":
169
- await sleep(self._config.idle_poll_delay)
217
+ await sleep(self._config.IDLE_POLL_DELAY)
170
218
  continue
171
219
 
172
- if self._config.multi_orchestrator_mode == "ROUND_ROBIN":
173
- orchestrator = self._config.orchestrators[self._round_robin_index]
174
- await self._poll_for_tasks(orchestrator["url"])
175
- self._round_robin_index = (self._round_robin_index + 1) % len(self._config.orchestrators)
220
+ if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
221
+ if orchestrator := self._get_next_orchestrator():
222
+ await self._poll_for_tasks(orchestrator)
176
223
  else:
177
- for orchestrator in self._config.orchestrators:
224
+ for orchestrator in self._config.ORCHESTRATORS:
178
225
  if self._get_current_state()["status"] == "busy":
179
226
  break
180
- await self._poll_for_tasks(orchestrator["url"])
227
+ await self._poll_for_tasks(orchestrator)
181
228
 
182
229
  if self._current_load == 0:
183
- await sleep(self._config.idle_poll_delay)
230
+ await sleep(self._config.IDLE_POLL_DELAY)
231
+
232
+ @staticmethod
233
+ def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
234
+ """
235
+ Inspects the handler's signature to validate and instantiate params.
236
+ Supports dict, dataclasses, and optional pydantic models.
237
+ """
238
+ sig = signature(handler)
239
+ params_annotation = sig.parameters.get("params").annotation
240
+
241
+ if params_annotation is sig.empty or params_annotation is dict:
242
+ return params
243
+
244
+ # Pydantic Model Validation
245
+ if _PYDANTIC_INSTALLED and isinstance(params_annotation, type) and issubclass(params_annotation, BaseModel):
246
+ try:
247
+ return params_annotation.model_validate(params)
248
+ except ValidationError as e:
249
+ raise ParamValidationError(str(e)) from e
250
+
251
+ # Dataclass Instantiation
252
+ if isinstance(params_annotation, type) and is_dataclass(params_annotation):
253
+ try:
254
+ # Filter unknown fields to prevent TypeError on dataclass instantiation
255
+ known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
256
+ filtered_params = {k: v for k, v in params.items() if k in known_fields}
257
+
258
+ # Explicitly check for missing required fields
259
+ required_fields = [
260
+ f.name
261
+ for f in params_annotation.__dataclass_fields__.values()
262
+ if f.default is Parameter.empty and f.default_factory is Parameter.empty
263
+ ]
264
+
265
+ if missing_fields := [f for f in required_fields if f not in filtered_params]:
266
+ raise ParamValidationError(f"Missing required fields for dataclass: {', '.join(missing_fields)}")
267
+
268
+ return params_annotation(**filtered_params)
269
+ except (TypeError, ValueError) as e:
270
+ # TypeError for missing/extra args, ValueError from __post_init__
271
+ raise ParamValidationError(str(e)) from e
272
+
273
+ return params
274
+
275
+ def _prepare_dependencies(self, handler: Callable, task_id: str) -> dict[str, Any]:
276
+ """Injects dependencies based on type hints."""
277
+ deps = {}
278
+ task_dir = join(self._config.TASK_FILES_DIR, task_id)
279
+ # Always create the object, but directory is lazy
280
+ task_files = TaskFiles(task_dir)
281
+
282
+ sig = signature(handler)
283
+ for name, param in sig.parameters.items():
284
+ if param.annotation is TaskFiles:
285
+ deps[name] = task_files
286
+
287
+ return deps
184
288
 
185
289
  async def _process_task(self, task_data: dict[str, Any]):
186
290
  """Executes the task logic."""
187
291
  task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
188
- params, orchestrator_url = task_data.get("params", {}), task_data["orchestrator_url"]
292
+ params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
189
293
 
190
294
  result: dict[str, Any] = {}
191
295
  handler_data = self._task_handlers.get(task_name)
192
296
  task_type_for_limit = handler_data.get("type") if handler_data else None
193
297
 
298
+ result_sent = False # Flag to track if result has been sent
299
+
194
300
  try:
195
- if handler_data:
196
- result = await handler_data["func"](
197
- params,
198
- task_id=task_id,
199
- job_id=job_id,
200
- priority=task_data.get("priority", 0),
201
- send_progress=self.send_progress,
202
- add_to_hot_cache=self.add_to_hot_cache,
203
- remove_from_hot_cache=self.remove_from_hot_cache,
204
- )
205
- else:
206
- result = {"status": "failure", "error_message": f"Unsupported task: {task_name}"}
301
+ if not handler_data:
302
+ message = f"Unsupported task: {task_name}"
303
+ logger.warning(message)
304
+ result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
305
+ payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
306
+ await self._send_result(payload, orchestrator)
307
+ result_sent = True # Mark result as sent
308
+ return
309
+
310
+ params = await self._s3_manager.process_params(params, task_id)
311
+ validated_params = self._prepare_task_params(handler_data["func"], params)
312
+ deps = self._prepare_dependencies(handler_data["func"], task_id)
313
+
314
+ result = await handler_data["func"](
315
+ validated_params,
316
+ task_id=task_id,
317
+ job_id=job_id,
318
+ priority=task_data.get("priority", 0),
319
+ send_progress=self.send_progress,
320
+ add_to_hot_cache=self.add_to_hot_cache,
321
+ remove_from_hot_cache=self.remove_from_hot_cache,
322
+ **deps,
323
+ )
324
+ result = await self._s3_manager.process_result(result)
325
+ except ParamValidationError as e:
326
+ logger.error(f"Task {task_id} failed validation: {e}")
327
+ result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
207
328
  except CancelledError:
329
+ logger.info(f"Task {task_id} was cancelled.")
208
330
  result = {"status": "cancelled"}
331
+ # We must re-raise the exception to be handled by the outer gather
332
+ raise
209
333
  except Exception as e:
210
- result = {"status": "failure", "error": {"code": "TRANSIENT_ERROR", "message": str(e)}}
334
+ logger.exception(f"An unexpected error occurred while processing task {task_id}:")
335
+ result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
211
336
  finally:
212
- payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.worker_id, "result": result}
213
- await self._send_result(payload, orchestrator_url)
337
+ # Cleanup task workspace
338
+ await self._s3_manager.cleanup(task_id)
339
+
340
+ if not result_sent: # Only send if not already sent
341
+ payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
342
+ await self._send_result(payload, orchestrator)
214
343
  self._active_tasks.pop(task_id, None)
215
344
 
216
345
  self._current_load -= 1
@@ -218,14 +347,15 @@ class Worker:
218
347
  self._current_load_by_type[task_type_for_limit] -= 1
219
348
  self._schedule_heartbeat_debounce()
220
349
 
221
- async def _send_result(self, payload: dict[str, Any], orchestrator_url: str):
350
+ async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
222
351
  """Sends the result to a specific orchestrator."""
223
- url = f"{orchestrator_url}/_worker/tasks/result"
224
- delay = self._config.result_retry_initial_delay
225
- for i in range(self._config.result_max_retries):
352
+ url = f"{orchestrator['url']}/_worker/tasks/result"
353
+ delay = self._config.RESULT_RETRY_INITIAL_DELAY
354
+ headers = self._get_headers(orchestrator)
355
+ for i in range(self._config.RESULT_MAX_RETRIES):
226
356
  try:
227
357
  if self._http_session and not self._http_session.closed:
228
- async with self._http_session.post(url, json=payload, headers=self._headers) as resp:
358
+ async with self._http_session.post(url, json=payload, headers=headers) as resp:
229
359
  if resp.status == 200:
230
360
  return
231
361
  except ClientError as e:
@@ -233,43 +363,44 @@ class Worker:
233
363
  await sleep(delay * (2**i))
234
364
 
235
365
  async def _manage_orchestrator_communications(self):
236
- print("Registering worker")
237
366
  """Registers the worker and sends heartbeats."""
238
367
  await self._register_with_all_orchestrators()
239
- print("Worker registered")
368
+
240
369
  self._registered_event.set()
241
- if self._config.enable_websockets:
370
+ if self._config.ENABLE_WEBSOCKETS:
242
371
  create_task(self._start_websocket_manager())
243
372
 
244
373
  while not self._shutdown_event.is_set():
245
374
  await self._send_heartbeats_to_all()
246
- await sleep(self._config.heartbeat_interval)
375
+ await sleep(self._config.HEARTBEAT_INTERVAL)
247
376
 
248
377
  async def _register_with_all_orchestrators(self):
249
378
  """Registers the worker with all orchestrators."""
250
379
  state = self._get_current_state()
251
380
  payload = {
252
- "worker_id": self._config.worker_id,
253
- "worker_type": self._config.worker_type,
381
+ "worker_id": self._config.WORKER_ID,
382
+ "worker_type": self._config.WORKER_TYPE,
254
383
  "supported_tasks": state["supported_tasks"],
255
- "max_concurrent_tasks": self._config.max_concurrent_tasks,
256
- "installed_models": self._config.installed_models,
257
- "hostname": self._config.hostname,
258
- "ip_address": self._config.ip_address,
259
- "resources": self._config.resources,
384
+ "max_concurrent_tasks": self._config.MAX_CONCURRENT_TASKS,
385
+ "cost_per_skill": self._config.COST_PER_SKILL,
386
+ "installed_models": self._config.INSTALLED_MODELS,
387
+ "hostname": self._config.HOSTNAME,
388
+ "ip_address": self._config.IP_ADDRESS,
389
+ "resources": self._config.RESOURCES,
260
390
  }
261
- for orchestrator in self._config.orchestrators:
391
+ for orchestrator in self._config.ORCHESTRATORS:
262
392
  url = f"{orchestrator['url']}/_worker/workers/register"
263
393
  try:
264
394
  if self._http_session:
265
- async with self._http_session.post(url, json=payload, headers=self._headers) as resp:
395
+ async with self._http_session.post(
396
+ url, json=payload, headers=self._get_headers(orchestrator)
397
+ ) as resp:
266
398
  if resp.status >= 400:
267
399
  logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
268
400
  except ClientError as e:
269
401
  logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
270
402
 
271
403
  async def _send_heartbeats_to_all(self):
272
- print("Sending heartbeats")
273
404
  """Sends heartbeat messages to all orchestrators."""
274
405
  state = self._get_current_state()
275
406
  payload = {
@@ -287,27 +418,27 @@ class Worker:
287
418
  if hot_skills:
288
419
  payload["hot_skills"] = hot_skills
289
420
 
290
- async def _send_single(orchestrator_url: str):
291
- url = f"{orchestrator_url}/_worker/workers/{self._config.worker_id}"
421
+ async def _send_single(orchestrator: dict[str, Any]):
422
+ url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
423
+ headers = self._get_headers(orchestrator)
292
424
  try:
293
425
  if self._http_session and not self._http_session.closed:
294
- async with self._http_session.patch(url, json=payload, headers=self._headers) as resp:
426
+ async with self._http_session.patch(url, json=payload, headers=headers) as resp:
295
427
  if resp.status >= 400:
296
- logger.warning(f"Heartbeat to {orchestrator_url} failed with status: {resp.status}")
428
+ logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
297
429
  except ClientError as e:
298
- logger.error(f"Error sending heartbeat to orchestrator {orchestrator_url}: {e}")
430
+ logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
299
431
 
300
- await gather(*[_send_single(o["url"]) for o in self._config.orchestrators])
432
+ await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
301
433
 
302
434
  async def main(self):
303
- print("Main started")
304
435
  """The main asynchronous function."""
305
436
  self._validate_config() # Validate config now that all tasks are registered
306
437
  if not self._http_session:
307
438
  self._http_session = ClientSession()
308
- print("Starting comm task")
439
+
309
440
  comm_task = create_task(self._manage_orchestrator_communications())
310
- print("Starting polling task")
441
+
311
442
  polling_task = create_task(self._start_polling())
312
443
  await self._shutdown_event.wait()
313
444
 
@@ -327,14 +458,17 @@ class Worker:
327
458
  run(self.main())
328
459
  except KeyboardInterrupt:
329
460
  self._shutdown_event.set()
330
- run(sleep(1.5))
331
461
 
332
462
  async def _run_health_check_server(self):
333
463
  app = web.Application()
334
- app.router.add_get("/health", lambda r: web.Response(text="OK"))
464
+
465
+ async def health_handler(_):
466
+ return web.Response(text="OK")
467
+
468
+ app.router.add_get("/health", health_handler)
335
469
  runner = web.AppRunner(app)
336
470
  await runner.setup()
337
- site = web.TCPSite(runner, "0.0.0.0", self._config.worker_port)
471
+ site = web.TCPSite(runner, "0.0.0.0", self._config.WORKER_PORT)
338
472
  await site.start()
339
473
  await self._shutdown_event.wait()
340
474
  await runner.cleanup()
@@ -347,17 +481,16 @@ class Worker:
347
481
  run(_main_wrapper())
348
482
  except KeyboardInterrupt:
349
483
  self._shutdown_event.set()
350
- run(sleep(1.5))
351
484
 
352
485
  # WebSocket methods omitted for brevity as they are not relevant to the changes
353
486
  async def _start_websocket_manager(self):
354
487
  """Manages the WebSocket connection to the orchestrator."""
355
488
  while not self._shutdown_event.is_set():
356
- for orchestrator in self._config.orchestrators:
489
+ for orchestrator in self._config.ORCHESTRATORS:
357
490
  ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
358
491
  try:
359
492
  if self._http_session:
360
- async with self._http_session.ws_connect(ws_url, headers=self._headers) as ws:
493
+ async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
361
494
  self._ws_connection = ws
362
495
  logger.info(f"WebSocket connection established to {ws_url}")
363
496
  await self._listen_for_commands()
@@ -367,7 +500,7 @@ class Worker:
367
500
  self._ws_connection = None
368
501
  logger.info(f"WebSocket connection to {ws_url} closed.")
369
502
  await sleep(5) # Reconnection delay
370
- if not self._config.orchestrators:
503
+ if not self._config.ORCHESTRATORS:
371
504
  await sleep(5)
372
505
 
373
506
  async def _listen_for_commands(self):