avtomatika-worker 1.0a2__py3-none-any.whl → 1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,56 +12,63 @@ class WorkerConfig:
 
     def __init__(self):
         # --- Basic worker information ---
-        self.worker_id: str = getenv("WORKER_ID", f"worker-{uuid4()}")
-        self.worker_type: str = getenv("WORKER_TYPE", "generic-cpu-worker")
-        self.worker_port: int = int(getenv("WORKER_PORT", "8083"))
-        self.hostname: str = gethostname()
+        self.WORKER_ID: str = getenv("WORKER_ID", f"worker-{uuid4()}")
+        self.WORKER_TYPE: str = getenv("WORKER_TYPE", "generic-cpu-worker")
+        self.WORKER_PORT: int = int(getenv("WORKER_PORT", "8083"))
+        self.HOSTNAME: str = gethostname()
         try:
-            self.ip_address: str = gethostbyname(self.hostname)
+            self.IP_ADDRESS: str = gethostbyname(self.HOSTNAME)
         except gaierror:
-            self.ip_address: str = "127.0.0.1"
+            self.IP_ADDRESS: str = "127.0.0.1"
 
         # --- Orchestrator settings ---
-        self.orchestrators: list[dict[str, Any]] = self._get_orchestrators_config()
+        self.ORCHESTRATORS: list[dict[str, Any]] = self._get_orchestrators_config()
 
         # --- Security ---
-        self.worker_token: str = getenv(
+        self.WORKER_TOKEN: str = getenv(
             "WORKER_INDIVIDUAL_TOKEN",
             getenv("WORKER_TOKEN", "your-secret-worker-token"),
         )
 
         # --- Resources and performance ---
-        self.cost_per_second: float = float(getenv("WORKER_COST_PER_SECOND", "0.01"))
-        self.max_concurrent_tasks: int = int(getenv("MAX_CONCURRENT_TASKS", "10"))
-        self.resources: dict[str, Any] = {
+        self.COST_PER_SKILL: dict[str, float] = self._load_json_from_env("COST_PER_SKILL", default={})
+        self.MAX_CONCURRENT_TASKS: int = int(getenv("MAX_CONCURRENT_TASKS", "10"))
+        self.RESOURCES: dict[str, Any] = {
             "cpu_cores": int(getenv("CPU_CORES", "4")),
             "gpu_info": self._get_gpu_info(),
         }
 
         # --- Installed software and models (read as JSON strings) ---
-        self.installed_software: dict[str, str] = self._load_json_from_env(
+        self.INSTALLED_SOFTWARE: dict[str, str] = self._load_json_from_env(
             "INSTALLED_SOFTWARE",
             default={"python": "3.9"},
         )
-        self.installed_models: list[dict[str, str]] = self._load_json_from_env(
+        self.INSTALLED_MODELS: list[dict[str, str]] = self._load_json_from_env(
             "INSTALLED_MODELS",
             default=[],
         )
 
+        # --- S3 Settings for payload offloading ---
+        self.WORKER_PAYLOAD_DIR: str = getenv("WORKER_PAYLOAD_DIR", "/tmp/payloads")
+        self.S3_ENDPOINT_URL: str | None = getenv("S3_ENDPOINT_URL")
+        self.S3_ACCESS_KEY: str | None = getenv("S3_ACCESS_KEY")
+        self.S3_SECRET_KEY: str | None = getenv("S3_SECRET_KEY")
+        self.S3_DEFAULT_BUCKET: str = getenv("S3_DEFAULT_BUCKET", "avtomatika-payloads")
+
         # --- Tuning parameters ---
-        self.heartbeat_interval: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
-        self.result_max_retries: int = int(getenv("RESULT_MAX_RETRIES", "5"))
-        self.result_retry_initial_delay: float = float(
+        self.HEARTBEAT_INTERVAL: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
+        self.RESULT_MAX_RETRIES: int = int(getenv("RESULT_MAX_RETRIES", "5"))
+        self.RESULT_RETRY_INITIAL_DELAY: float = float(
             getenv("RESULT_RETRY_INITIAL_DELAY", "1.0"),
         )
-        self.heartbeat_debounce_delay: float = float(getenv("WORKER_HEARTBEAT_DEBOUNCE_DELAY", 0.1))
-        self.task_poll_timeout: float = float(getenv("TASK_POLL_TIMEOUT", "30"))
-        self.task_poll_error_delay: float = float(
+        self.HEARTBEAT_DEBOUNCE_DELAY: float = float(getenv("WORKER_HEARTBEAT_DEBOUNCE_DELAY", 0.1))
+        self.TASK_POLL_TIMEOUT: float = float(getenv("TASK_POLL_TIMEOUT", "30"))
+        self.TASK_POLL_ERROR_DELAY: float = float(
             getenv("TASK_POLL_ERROR_DELAY", "5.0"),
         )
-        self.idle_poll_delay: float = float(getenv("IDLE_POLL_DELAY", "0.01"))
-        self.enable_websockets: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
-        self.multi_orchestrator_mode: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
+        self.IDLE_POLL_DELAY: float = float(getenv("IDLE_POLL_DELAY", "0.01"))
+        self.ENABLE_WEBSOCKETS: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
+        self.MULTI_ORCHESTRATOR_MODE: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
 
     def _get_orchestrators_config(self) -> list[dict[str, Any]]:
         """
@@ -72,16 +79,20 @@ class WorkerConfig:
         if orchestrators_json:
             try:
                 orchestrators = loads(orchestrators_json)
+                if getenv("ORCHESTRATOR_URL"):
+                    print("Info: Both ORCHESTRATORS_CONFIG and ORCHESTRATOR_URL are set. Using ORCHESTRATORS_CONFIG.")
                 for o in orchestrators:
                     if "priority" not in o:
                         o["priority"] = 10
+                    if "weight" not in o:
+                        o["weight"] = 1
                 orchestrators.sort(key=lambda x: (x.get("priority", 10), x.get("url")))
                 return orchestrators
             except JSONDecodeError:
                 print("Warning: Could not decode JSON from ORCHESTRATORS_CONFIG. Falling back to default.")
 
         orchestrator_url = getenv("ORCHESTRATOR_URL", "http://localhost:8080")
-        return [{"url": orchestrator_url, "priority": 1}]
+        return [{"url": orchestrator_url, "priority": 1, "weight": 1}]
 
     def _get_gpu_info(self) -> dict[str, Any] | None:
         """Collects GPU information from environment variables.
@@ -0,0 +1,75 @@
+import asyncio
+import os
+from typing import Any
+from urllib.parse import urlparse
+
+import boto3
+from botocore.client import Config
+
+from .config import WorkerConfig
+
+
+class S3Manager:
+    """Handles S3 payload offloading."""
+
+    def __init__(self, config: WorkerConfig):
+        self._config = config
+        self._s3 = boto3.client(
+            "s3",
+            endpoint_url=self._config.S3_ENDPOINT_URL,
+            aws_access_key_id=self._config.S3_ACCESS_KEY,
+            aws_secret_access_key=self._config.S3_SECRET_KEY,
+            config=Config(signature_version="s3v4"),
+        )
+
+    async def _process_s3_uri(self, uri: str) -> str:
+        """Downloads a file from S3 and returns the local path."""
+        parsed_url = urlparse(uri)
+        bucket_name = parsed_url.netloc
+        object_key = parsed_url.path.lstrip("/")
+        local_dir = self._config.WORKER_PAYLOAD_DIR
+        os.makedirs(local_dir, exist_ok=True)
+        local_path = os.path.join(local_dir, os.path.basename(object_key))
+
+        await asyncio.to_thread(self._s3.download_file, bucket_name, object_key, local_path)
+        return local_path
+
+    async def _upload_to_s3(self, local_path: str) -> str:
+        """Uploads a file to S3 and returns the S3 URI."""
+        bucket_name = self._config.S3_DEFAULT_BUCKET
+        object_key = os.path.basename(local_path)
+
+        await asyncio.to_thread(self._s3.upload_file, local_path, bucket_name, object_key)
+        return f"s3://{bucket_name}/{object_key}"
+
+    async def process_params(self, params: dict[str, Any]) -> dict[str, Any]:
+        """Recursively searches for S3 URIs in params and downloads the files."""
+        if not self._config.S3_ENDPOINT_URL:
+            return params
+
+        async def _process(item: Any) -> Any:
+            if isinstance(item, str) and item.startswith("s3://"):
+                return await self._process_s3_uri(item)
+            if isinstance(item, dict):
+                return {k: await _process(v) for k, v in item.items()}
+            if isinstance(item, list):
+                return [await _process(i) for i in item]
+            return item
+
+        return await _process(params)
+
+    async def process_result(self, result: dict[str, Any]) -> dict[str, Any]:
+        """Recursively searches for local file paths in the result and uploads them to S3."""
+        if not self._config.S3_ENDPOINT_URL:
+            return result
+
+        async def _process(item: Any) -> Any:
+            if isinstance(item, str) and os.path.exists(item) and item.startswith(self._config.WORKER_PAYLOAD_DIR):
+                return await self._upload_to_s3(item)
+            if isinstance(item, dict):
+                return {k: await _process(v) for k, v in item.items()}
+            if isinstance(item, list):
+                return [await _process(i) for i in item]
+            return item
+
+        return await _process(result)
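
Together, `process_params` and `process_result` give handlers a transparent S3 round-trip: every `s3://bucket/key` string in incoming params is swapped for a local file under `WORKER_PAYLOAD_DIR` before the handler runs, and every result string that exists on disk inside that directory is swapped back for an `s3://` URI. A usage sketch, assuming the `S3_*` environment variables point at a reachable S3-compatible endpoint, with hypothetical object names and assumed module paths:

```python
import asyncio

from avtomatika_worker.config import WorkerConfig  # assumed module paths
from avtomatika_worker.s3 import S3Manager


async def demo() -> None:
    manager = S3Manager(WorkerConfig())

    # "s3://..." strings are downloaded and replaced by local paths.
    params = await manager.process_params(
        {"input": "s3://avtomatika-payloads/frame.png", "quality": 80}
    )
    print(params["input"])  # e.g. /tmp/payloads/frame.png

    # Existing local paths under WORKER_PAYLOAD_DIR are uploaded back.
    result = await manager.process_result({"output": "/tmp/payloads/out.png"})
    print(result["output"])  # s3://avtomatika-payloads/out.png


asyncio.run(demo())
```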
@@ -2,3 +2,7 @@
 TRANSIENT_ERROR = "TRANSIENT_ERROR"
 PERMANENT_ERROR = "PERMANENT_ERROR"
 INVALID_INPUT_ERROR = "INVALID_INPUT_ERROR"
+
+
+class ParamValidationError(Exception):
+    """Custom exception for parameter validation errors."""
@@ -1,5 +1,7 @@
 from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
 from asyncio import TimeoutError as AsyncTimeoutError
+from dataclasses import MISSING, is_dataclass
+from inspect import signature
 from json import JSONDecodeError
 from logging import getLogger
 from typing import Any, Callable
@@ -7,6 +9,15 @@ from typing import Any, Callable
 from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
 
 from .config import WorkerConfig
+from .s3 import S3Manager
+from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
+
+try:
+    from pydantic import BaseModel, ValidationError
+
+    _PYDANTIC_INSTALLED = True
+except ImportError:
+    _PYDANTIC_INSTALLED = False
 
 # Logging setup
 logger = getLogger(__name__)
@@ -26,9 +37,11 @@ class Worker:
         task_type_limits: dict[str, int] | None = None,
         http_session: ClientSession | None = None,
         skill_dependencies: dict[str, list[str]] | None = None,
+        config: WorkerConfig | None = None,
     ):
-        self._config = WorkerConfig()
-        self._config.worker_type = worker_type  # Allow overriding worker_type
+        self._config = config or WorkerConfig()
+        self._s3_manager = S3Manager(self._config)
+        self._config.WORKER_TYPE = worker_type  # Allow overriding worker_type
         if max_concurrent_tasks is not None:
             self._config.max_concurrent_tasks = max_concurrent_tasks
 
@@ -44,12 +57,19 @@ class Worker:
         self._http_session = http_session
         self._session_is_managed_externally = http_session is not None
         self._ws_connection: ClientWebSocketResponse | None = None
-        self._headers = {"X-Worker-Token": self._config.worker_token}
+        # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
         self._shutdown_event = Event()
         self._registered_event = Event()
         self._round_robin_index = 0
         self._debounce_task: Task | None = None
 
+        # --- Weighted Round-Robin State ---
+        self._total_orchestrator_weight = 0
+        if self._config.ORCHESTRATORS:
+            for o in self._config.ORCHESTRATORS:
+                o["current_weight"] = 0
+                self._total_orchestrator_weight += o.get("weight", 1)
+
     def _validate_config(self):
         """Checks for unused task type limits and warns the user."""
         registered_task_types = {
@@ -98,7 +118,7 @@ class Worker:
         """
         Calculates the current worker state including status and available tasks.
         """
-        if self._current_load >= self._config.max_concurrent_tasks:
+        if self._current_load >= self._config.MAX_CONCURRENT_TASKS:
             return {"status": "busy", "supported_tasks": []}
 
         supported_tasks = []
@@ -118,9 +138,36 @@ class Worker:
         status = "idle" if supported_tasks else "busy"
         return {"status": status, "supported_tasks": supported_tasks}
 
+    def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
+        """Builds authentication headers for a specific orchestrator."""
+        token = orchestrator.get("token", self._config.WORKER_TOKEN)
+        return {"X-Worker-Token": token}
+
+    def _get_next_orchestrator(self) -> dict[str, Any] | None:
+        """
+        Selects the next orchestrator using a smooth weighted round-robin algorithm.
+        """
+        if not self._config.ORCHESTRATORS:
+            return None
+
+        # The orchestrator with the highest current_weight is selected.
+        selected_orchestrator = None
+        highest_weight = -1
+
+        for o in self._config.ORCHESTRATORS:
+            # Default the weight to 1, matching the o.get("weight", 1) used in __init__.
+            o["current_weight"] += o.get("weight", 1)
+            if o["current_weight"] > highest_weight:
+                highest_weight = o["current_weight"]
+                selected_orchestrator = o
+
+        if selected_orchestrator:
+            selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
+
+        return selected_orchestrator
+
     async def _debounced_heartbeat_sender(self):
         """Waits for the debounce delay then sends a heartbeat."""
-        await sleep(self._config.heartbeat_debounce_delay)
+        await sleep(self._config.HEARTBEAT_DEBOUNCE_DELAY)
         await self._send_heartbeats_to_all()
 
     def _schedule_heartbeat_debounce(self):
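
The selection loop above is the classic smooth weighted round-robin: each pass, every orchestrator's `current_weight` grows by its `weight`, the largest value wins, and the winner is docked the total weight, so picks interleave instead of bursting. A standalone sketch with hypothetical weights:

```python
# orch-a (weight 3) should get 3 of every 4 polls, spread out: a, a, b, a.
orchestrators = [
    {"url": "http://orch-a:8080", "weight": 3, "current_weight": 0},
    {"url": "http://orch-b:8080", "weight": 1, "current_weight": 0},
]
total_weight = sum(o["weight"] for o in orchestrators)  # 4

picks = []
for _ in range(4):
    for o in orchestrators:
        o["current_weight"] += o["weight"]
    # max() keeps the first of equal values, matching the strict ">" above.
    selected = max(orchestrators, key=lambda o: o["current_weight"])
    selected["current_weight"] -= total_weight
    picks.append(selected["url"])

print(picks)  # orch-a, orch-a, orch-b, orch-a
```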
@@ -131,17 +178,18 @@ class Worker:
         # Schedule the new debounced call.
         self._debounce_task = create_task(self._debounced_heartbeat_sender())
 
-    async def _poll_for_tasks(self, orchestrator_url: str):
+    async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
         """Polls a specific Orchestrator for new tasks."""
-        url = f"{orchestrator_url}/_worker/workers/{self._config.worker_id}/tasks/next"
+        url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
         try:
             if not self._http_session:
                 return
-            timeout = ClientTimeout(total=self._config.task_poll_timeout + 5)
-            async with self._http_session.get(url, headers=self._headers, timeout=timeout) as resp:
+            timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
+            headers = self._get_headers(orchestrator)
+            async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
                 if resp.status == 200:
                     task_data = await resp.json()
-                    task_data["orchestrator_url"] = orchestrator_url
+                    task_data["orchestrator"] = orchestrator
 
                     self._current_load += 1
                     task_handler_info = self._task_handlers.get(task_data["type"])
@@ -154,63 +202,125 @@ class Worker:
                     task = create_task(self._process_task(task_data))
                     self._active_tasks[task_data["task_id"]] = task
                 elif resp.status != 204:
-                    await sleep(self._config.task_poll_error_delay)
+                    await sleep(self._config.TASK_POLL_ERROR_DELAY)
         except (AsyncTimeoutError, ClientError) as e:
             logger.error(f"Error polling for tasks: {e}")
-            await sleep(self._config.task_poll_error_delay)
+            await sleep(self._config.TASK_POLL_ERROR_DELAY)
 
     async def _start_polling(self):
-        print("Waiting for registration")
         """The main loop for polling tasks."""
         await self._registered_event.wait()
-        print("Polling started")
+
         while not self._shutdown_event.is_set():
             if self._get_current_state()["status"] == "busy":
-                await sleep(self._config.idle_poll_delay)
+                await sleep(self._config.IDLE_POLL_DELAY)
                 continue
 
-            if self._config.multi_orchestrator_mode == "ROUND_ROBIN":
-                orchestrator = self._config.orchestrators[self._round_robin_index]
-                await self._poll_for_tasks(orchestrator["url"])
-                self._round_robin_index = (self._round_robin_index + 1) % len(self._config.orchestrators)
+            if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
+                orchestrator = self._get_next_orchestrator()
+                if orchestrator:
+                    await self._poll_for_tasks(orchestrator)
             else:
-                for orchestrator in self._config.orchestrators:
+                for orchestrator in self._config.ORCHESTRATORS:
                     if self._get_current_state()["status"] == "busy":
                         break
-                    await self._poll_for_tasks(orchestrator["url"])
+                    await self._poll_for_tasks(orchestrator)
 
             if self._current_load == 0:
-                await sleep(self._config.idle_poll_delay)
+                await sleep(self._config.IDLE_POLL_DELAY)
+
+    def _prepare_task_params(self, handler: Callable, params: dict[str, Any]) -> Any:
+        """
+        Inspects the handler's signature to validate and instantiate params.
+        Supports dict, dataclasses, and optional pydantic models.
+        """
+        sig = signature(handler)
+        params_param = sig.parameters.get("params")
+        if params_param is None:
+            # Handler does not declare a "params" argument by name; skip validation.
+            return params
+        params_annotation = params_param.annotation
+
+        if params_annotation is sig.empty or params_annotation is dict:
+            return params
+
+        # Pydantic Model Validation
+        if _PYDANTIC_INSTALLED and isinstance(params_annotation, type) and issubclass(params_annotation, BaseModel):
+            try:
+                return params_annotation.model_validate(params)
+            except ValidationError as e:
+                raise ParamValidationError(str(e)) from e
+
+        # Dataclass Instantiation
+        if isinstance(params_annotation, type) and is_dataclass(params_annotation):
+            try:
+                # Filter unknown fields to prevent TypeError on dataclass instantiation
+                known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
+                filtered_params = {k: v for k, v in params.items() if k in known_fields}
+
+                # Explicitly check for missing required fields (dataclass fields use
+                # the dataclasses.MISSING sentinel, not inspect.Parameter.empty)
+                required_fields = [
+                    f.name
+                    for f in params_annotation.__dataclass_fields__.values()
+                    if f.default is MISSING and f.default_factory is MISSING
+                ]
+
+                missing_fields = [f for f in required_fields if f not in filtered_params]
+                if missing_fields:
+                    raise ParamValidationError(f"Missing required fields for dataclass: {', '.join(missing_fields)}")
+
+                return params_annotation(**filtered_params)
+            except (TypeError, ValueError) as e:
+                # TypeError for missing/extra args, ValueError from __post_init__
+                raise ParamValidationError(str(e)) from e
+
+        return params
 
     async def _process_task(self, task_data: dict[str, Any]):
         """Executes the task logic."""
         task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-        params, orchestrator_url = task_data.get("params", {}), task_data["orchestrator_url"]
+        params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
 
         result: dict[str, Any] = {}
         handler_data = self._task_handlers.get(task_name)
         task_type_for_limit = handler_data.get("type") if handler_data else None
 
+        result_sent = False  # Flag to track if result has been sent
+
         try:
-            if handler_data:
-                result = await handler_data["func"](
-                    params,
-                    task_id=task_id,
-                    job_id=job_id,
-                    priority=task_data.get("priority", 0),
-                    send_progress=self.send_progress,
-                    add_to_hot_cache=self.add_to_hot_cache,
-                    remove_from_hot_cache=self.remove_from_hot_cache,
-                )
-            else:
-                result = {"status": "failure", "error_message": f"Unsupported task: {task_name}"}
+            if not handler_data:
+                message = f"Unsupported task: {task_name}"
+                logger.warning(message)
+                result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
+                payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
+                await self._send_result(payload, orchestrator)
+                result_sent = True  # Mark result as sent
+                return
+
+            params = await self._s3_manager.process_params(params)
+            validated_params = self._prepare_task_params(handler_data["func"], params)
+
+            result = await handler_data["func"](
+                validated_params,
+                task_id=task_id,
+                job_id=job_id,
+                priority=task_data.get("priority", 0),
+                send_progress=self.send_progress,
+                add_to_hot_cache=self.add_to_hot_cache,
+                remove_from_hot_cache=self.remove_from_hot_cache,
+            )
+            result = await self._s3_manager.process_result(result)
+        except ParamValidationError as e:
+            logger.error(f"Task {task_id} failed validation: {e}")
+            result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
         except CancelledError:
+            logger.info(f"Task {task_id} was cancelled.")
             result = {"status": "cancelled"}
+            # We must re-raise the exception to be handled by the outer gather
+            raise
         except Exception as e:
-            result = {"status": "failure", "error": {"code": "TRANSIENT_ERROR", "message": str(e)}}
+            logger.exception(f"An unexpected error occurred while processing task {task_id}:")
+            result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
         finally:
-            payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.worker_id, "result": result}
-            await self._send_result(payload, orchestrator_url)
+            if not result_sent:  # Only send if not already sent
+                payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
+                await self._send_result(payload, orchestrator)
             self._active_tasks.pop(task_id, None)
 
         self._current_load -= 1
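
From a handler author's perspective, annotating the `params` argument with a dataclass (or a pydantic model) opts into this validation, and the S3 download has already happened by the time the handler runs. A sketch with hypothetical names (`ResizeParams`, `handle_resize`); the registration mechanism that populates `_task_handlers` is outside this diff:

```python
from dataclasses import dataclass


@dataclass
class ResizeParams:
    source: str       # required: omitting it raises ParamValidationError
    width: int = 512  # optional: defaulted when absent


async def handle_resize(params: ResizeParams, **kwargs):
    # By the time this runs, _prepare_task_params has dropped unknown keys and
    # instantiated ResizeParams, and any "s3://..." value in the raw params was
    # already replaced by a local file path via S3Manager.process_params.
    return {"output": f"/tmp/payloads/resized-{params.width}.png"}
```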
@@ -218,14 +328,15 @@ class Worker:
             self._current_load_by_type[task_type_for_limit] -= 1
         self._schedule_heartbeat_debounce()
 
-    async def _send_result(self, payload: dict[str, Any], orchestrator_url: str):
+    async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
         """Sends the result to a specific orchestrator."""
-        url = f"{orchestrator_url}/_worker/tasks/result"
-        delay = self._config.result_retry_initial_delay
-        for i in range(self._config.result_max_retries):
+        url = f"{orchestrator['url']}/_worker/tasks/result"
+        delay = self._config.RESULT_RETRY_INITIAL_DELAY
+        headers = self._get_headers(orchestrator)
+        for i in range(self._config.RESULT_MAX_RETRIES):
             try:
                 if self._http_session and not self._http_session.closed:
-                    async with self._http_session.post(url, json=payload, headers=self._headers) as resp:
+                    async with self._http_session.post(url, json=payload, headers=headers) as resp:
                         if resp.status == 200:
                             return
             except ClientError as e:
@@ -233,43 +344,44 @@ class Worker:
             await sleep(delay * (2**i))
 
     async def _manage_orchestrator_communications(self):
-        print("Registering worker")
         """Registers the worker and sends heartbeats."""
         await self._register_with_all_orchestrators()
-        print("Worker registered")
+
         self._registered_event.set()
-        if self._config.enable_websockets:
+        if self._config.ENABLE_WEBSOCKETS:
             create_task(self._start_websocket_manager())
 
         while not self._shutdown_event.is_set():
             await self._send_heartbeats_to_all()
-            await sleep(self._config.heartbeat_interval)
+            await sleep(self._config.HEARTBEAT_INTERVAL)
 
     async def _register_with_all_orchestrators(self):
         """Registers the worker with all orchestrators."""
         state = self._get_current_state()
         payload = {
-            "worker_id": self._config.worker_id,
-            "worker_type": self._config.worker_type,
+            "worker_id": self._config.WORKER_ID,
+            "worker_type": self._config.WORKER_TYPE,
             "supported_tasks": state["supported_tasks"],
-            "max_concurrent_tasks": self._config.max_concurrent_tasks,
-            "installed_models": self._config.installed_models,
-            "hostname": self._config.hostname,
-            "ip_address": self._config.ip_address,
-            "resources": self._config.resources,
+            "max_concurrent_tasks": self._config.MAX_CONCURRENT_TASKS,
+            "cost_per_skill": self._config.COST_PER_SKILL,
+            "installed_models": self._config.INSTALLED_MODELS,
+            "hostname": self._config.HOSTNAME,
+            "ip_address": self._config.IP_ADDRESS,
+            "resources": self._config.RESOURCES,
         }
-        for orchestrator in self._config.orchestrators:
+        for orchestrator in self._config.ORCHESTRATORS:
             url = f"{orchestrator['url']}/_worker/workers/register"
             try:
                 if self._http_session:
-                    async with self._http_session.post(url, json=payload, headers=self._headers) as resp:
+                    async with self._http_session.post(
+                        url, json=payload, headers=self._get_headers(orchestrator)
+                    ) as resp:
                         if resp.status >= 400:
                             logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
             except ClientError as e:
                 logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
 
     async def _send_heartbeats_to_all(self):
-        print("Sending heartbeats")
         """Sends heartbeat messages to all orchestrators."""
         state = self._get_current_state()
         payload = {
@@ -287,27 +399,27 @@ class Worker:
         if hot_skills:
             payload["hot_skills"] = hot_skills
 
-        async def _send_single(orchestrator_url: str):
-            url = f"{orchestrator_url}/_worker/workers/{self._config.worker_id}"
+        async def _send_single(orchestrator: dict[str, Any]):
+            url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
+            headers = self._get_headers(orchestrator)
             try:
                 if self._http_session and not self._http_session.closed:
-                    async with self._http_session.patch(url, json=payload, headers=self._headers) as resp:
+                    async with self._http_session.patch(url, json=payload, headers=headers) as resp:
                         if resp.status >= 400:
-                            logger.warning(f"Heartbeat to {orchestrator_url} failed with status: {resp.status}")
+                            logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
             except ClientError as e:
-                logger.error(f"Error sending heartbeat to orchestrator {orchestrator_url}: {e}")
+                logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
 
-        await gather(*[_send_single(o["url"]) for o in self._config.orchestrators])
+        await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
 
     async def main(self):
-        print("Main started")
         """The main asynchronous function."""
         self._validate_config()  # Validate config now that all tasks are registered
         if not self._http_session:
             self._http_session = ClientSession()
-        print("Starting comm task")
+
         comm_task = create_task(self._manage_orchestrator_communications())
-        print("Starting polling task")
+
         polling_task = create_task(self._start_polling())
         await self._shutdown_event.wait()
@@ -327,14 +439,13 @@ class Worker:
             run(self.main())
         except KeyboardInterrupt:
             self._shutdown_event.set()
-            run(sleep(1.5))
 
     async def _run_health_check_server(self):
         app = web.Application()
         app.router.add_get("/health", lambda r: web.Response(text="OK"))
         runner = web.AppRunner(app)
         await runner.setup()
-        site = web.TCPSite(runner, "0.0.0.0", self._config.worker_port)
+        site = web.TCPSite(runner, "0.0.0.0", self._config.WORKER_PORT)
         await site.start()
         await self._shutdown_event.wait()
         await runner.cleanup()
@@ -347,17 +458,16 @@ class Worker:
             run(_main_wrapper())
         except KeyboardInterrupt:
             self._shutdown_event.set()
-            run(sleep(1.5))
 
     # WebSocket methods omitted for brevity as they are not relevant to the changes
     async def _start_websocket_manager(self):
         """Manages the WebSocket connection to the orchestrator."""
         while not self._shutdown_event.is_set():
-            for orchestrator in self._config.orchestrators:
+            for orchestrator in self._config.ORCHESTRATORS:
                 ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
                 try:
                     if self._http_session:
-                        async with self._http_session.ws_connect(ws_url, headers=self._headers) as ws:
+                        async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
                             self._ws_connection = ws
                             logger.info(f"WebSocket connection established to {ws_url}")
                             await self._listen_for_commands()
@@ -367,7 +477,7 @@ class Worker:
                     self._ws_connection = None
                     logger.info(f"WebSocket connection to {ws_url} closed.")
                     await sleep(5)  # Reconnection delay
-            if not self._config.orchestrators:
+            if not self._config.ORCHESTRATORS:
                 await sleep(5)
 
     async def _listen_for_commands(self):