avtomatika-worker 1.0a2-py3-none-any.whl → 1.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika_worker/config.py +34 -23
- avtomatika_worker/s3.py +75 -0
- avtomatika_worker/types.py +4 -0
- avtomatika_worker/worker.py +182 -72
- avtomatika_worker-1.0b1.dist-info/METADATA +537 -0
- avtomatika_worker-1.0b1.dist-info/RECORD +10 -0
- avtomatika_worker-1.0a2.dist-info/METADATA +0 -307
- avtomatika_worker-1.0a2.dist-info/RECORD +0 -9
- {avtomatika_worker-1.0a2.dist-info → avtomatika_worker-1.0b1.dist-info}/WHEEL +0 -0
- {avtomatika_worker-1.0a2.dist-info → avtomatika_worker-1.0b1.dist-info}/licenses/LICENSE +0 -0
- {avtomatika_worker-1.0a2.dist-info → avtomatika_worker-1.0b1.dist-info}/top_level.txt +0 -0
avtomatika_worker/config.py
CHANGED
```diff
@@ -12,56 +12,63 @@ class WorkerConfig:
 
     def __init__(self):
         # --- Basic worker information ---
-        self.
-        self.
-        self.
-        self.
+        self.WORKER_ID: str = getenv("WORKER_ID", f"worker-{uuid4()}")
+        self.WORKER_TYPE: str = getenv("WORKER_TYPE", "generic-cpu-worker")
+        self.WORKER_PORT: int = int(getenv("WORKER_PORT", "8083"))
+        self.HOSTNAME: str = gethostname()
         try:
-            self.
+            self.IP_ADDRESS: str = gethostbyname(self.HOSTNAME)
         except gaierror:
-            self.
+            self.IP_ADDRESS: str = "127.0.0.1"
 
         # --- Orchestrator settings ---
-        self.
+        self.ORCHESTRATORS: list[dict[str, Any]] = self._get_orchestrators_config()
 
         # --- Security ---
-        self.
+        self.WORKER_TOKEN: str = getenv(
             "WORKER_INDIVIDUAL_TOKEN",
             getenv("WORKER_TOKEN", "your-secret-worker-token"),
         )
 
         # --- Resources and performance ---
-        self.
-        self.
-        self.
+        self.COST_PER_SKILL: dict[str, float] = self._load_json_from_env("COST_PER_SKILL", default={})
+        self.MAX_CONCURRENT_TASKS: int = int(getenv("MAX_CONCURRENT_TASKS", "10"))
+        self.RESOURCES: dict[str, Any] = {
             "cpu_cores": int(getenv("CPU_CORES", "4")),
             "gpu_info": self._get_gpu_info(),
         }
 
         # --- Installed software and models (read as JSON strings) ---
-        self.
+        self.INSTALLED_SOFTWARE: dict[str, str] = self._load_json_from_env(
             "INSTALLED_SOFTWARE",
             default={"python": "3.9"},
         )
-        self.
+        self.INSTALLED_MODELS: list[dict[str, str]] = self._load_json_from_env(
             "INSTALLED_MODELS",
             default=[],
         )
 
+        # --- S3 Settings for payload offloading ---
+        self.WORKER_PAYLOAD_DIR: str = getenv("WORKER_PAYLOAD_DIR", "/tmp/payloads")
+        self.S3_ENDPOINT_URL: str | None = getenv("S3_ENDPOINT_URL")
+        self.S3_ACCESS_KEY: str | None = getenv("S3_ACCESS_KEY")
+        self.S3_SECRET_KEY: str | None = getenv("S3_SECRET_KEY")
+        self.S3_DEFAULT_BUCKET: str = getenv("S3_DEFAULT_BUCKET", "avtomatika-payloads")
+
         # --- Tuning parameters ---
-        self.
-        self.
-        self.
+        self.HEARTBEAT_INTERVAL: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
+        self.RESULT_MAX_RETRIES: int = int(getenv("RESULT_MAX_RETRIES", "5"))
+        self.RESULT_RETRY_INITIAL_DELAY: float = float(
             getenv("RESULT_RETRY_INITIAL_DELAY", "1.0"),
         )
-        self.
-        self.
-        self.
+        self.HEARTBEAT_DEBOUNCE_DELAY: float = float(getenv("WORKER_HEARTBEAT_DEBOUNCE_DELAY", 0.1))
+        self.TASK_POLL_TIMEOUT: float = float(getenv("TASK_POLL_TIMEOUT", "30"))
+        self.TASK_POLL_ERROR_DELAY: float = float(
             getenv("TASK_POLL_ERROR_DELAY", "5.0"),
         )
-        self.
-        self.
-        self.
+        self.IDLE_POLL_DELAY: float = float(getenv("IDLE_POLL_DELAY", "0.01"))
+        self.ENABLE_WEBSOCKETS: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
+        self.MULTI_ORCHESTRATOR_MODE: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
 
     def _get_orchestrators_config(self) -> list[dict[str, Any]]:
         """
@@ -72,16 +79,20 @@ class WorkerConfig:
         if orchestrators_json:
             try:
                 orchestrators = loads(orchestrators_json)
+                if getenv("ORCHESTRATOR_URL"):
+                    print("Info: Both ORCHESTRATORS_CONFIG and ORCHESTRATOR_URL are set. Using ORCHESTRATORS_CONFIG.")
                 for o in orchestrators:
                     if "priority" not in o:
                         o["priority"] = 10
+                    if "weight" not in o:
+                        o["weight"] = 1
                 orchestrators.sort(key=lambda x: (x.get("priority", 10), x.get("url")))
                 return orchestrators
             except JSONDecodeError:
                 print("Warning: Could not decode JSON from ORCHESTRATORS_CONFIG. Falling back to default.")
 
         orchestrator_url = getenv("ORCHESTRATOR_URL", "http://localhost:8080")
-        return [{"url": orchestrator_url, "priority": 1}]
+        return [{"url": orchestrator_url, "priority": 1, "weight": 1}]
 
     def _get_gpu_info(self) -> dict[str, Any] | None:
         """Collects GPU information from environment variables.
```
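For orientation, the new `weight` field and S3 settings above can be exercised together. This is a hedged sketch, and every URL, key, and weight in it is invented:

```python
import json
from os import environ

from avtomatika_worker.config import WorkerConfig

# Two orchestrators at equal priority. "weight" (new in 1.0b1) biases the
# ROUND_ROBIN mode; entries without a weight are defaulted to 1.
environ["ORCHESTRATORS_CONFIG"] = json.dumps([
    {"url": "http://orch-a:8080", "priority": 1, "weight": 5},
    {"url": "http://orch-b:8080", "priority": 1},
])
environ["MULTI_ORCHESTRATOR_MODE"] = "ROUND_ROBIN"

# Payload offloading stays inactive unless S3_ENDPOINT_URL is set.
environ["S3_ENDPOINT_URL"] = "http://minio.local:9000"
environ["S3_ACCESS_KEY"] = "example-access-key"
environ["S3_SECRET_KEY"] = "example-secret-key"

config = WorkerConfig()
assert config.ORCHESTRATORS[0]["weight"] == 5
assert config.ORCHESTRATORS[1]["weight"] == 1  # filled in by _get_orchestrators_config
```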
avtomatika_worker/s3.py
ADDED
```diff
@@ -0,0 +1,75 @@
+import asyncio
+import os
+from typing import Any
+from urllib.parse import urlparse
+
+import boto3
+from botocore.client import Config
+
+from .config import WorkerConfig
+
+
+class S3Manager:
+    """Handles S3 payload offloading."""
+
+    def __init__(self, config: WorkerConfig):
+        self._config = config
+        self._s3 = boto3.client(
+            "s3",
+            endpoint_url=self._config.S3_ENDPOINT_URL,
+            aws_access_key_id=self._config.S3_ACCESS_KEY,
+            aws_secret_access_key=self._config.S3_SECRET_KEY,
+            config=Config(signature_version="s3v4"),
+        )
+
+    async def _process_s3_uri(self, uri: str) -> str:
+        """Downloads a file from S3 and returns the local path."""
+        parsed_url = urlparse(uri)
+        bucket_name = parsed_url.netloc
+        object_key = parsed_url.path.lstrip("/")
+        local_dir = self._config.WORKER_PAYLOAD_DIR
+        os.makedirs(local_dir, exist_ok=True)
+        local_path = os.path.join(local_dir, os.path.basename(object_key))
+
+        await asyncio.to_thread(self._s3.download_file, bucket_name, object_key, local_path)
+        return local_path
+
+    async def _upload_to_s3(self, local_path: str) -> str:
+        """Uploads a file to S3 and returns the S3 URI."""
+        bucket_name = self._config.S3_DEFAULT_BUCKET
+        object_key = os.path.basename(local_path)
+
+        await asyncio.to_thread(self._s3.upload_file, local_path, bucket_name, object_key)
+        return f"s3://{bucket_name}/{object_key}"
+
+    async def process_params(self, params: dict[str, Any]) -> dict[str, Any]:
+        """Recursively searches for S3 URIs in params and downloads the files."""
+        if not self._config.S3_ENDPOINT_URL:
+            return params
+
+        async def _process(item: Any) -> Any:
+            if isinstance(item, str) and item.startswith("s3://"):
+                return await self._process_s3_uri(item)
+            if isinstance(item, dict):
+                return {k: await _process(v) for k, v in item.items()}
+            if isinstance(item, list):
+                return [await _process(i) for i in item]
+            return item
+
+        return await _process(params)
+
+    async def process_result(self, result: dict[str, Any]) -> dict[str, Any]:
+        """Recursively searches for local file paths in the result and uploads them to S3."""
+        if not self._config.S3_ENDPOINT_URL:
+            return result
+
+        async def _process(item: Any) -> Any:
+            if isinstance(item, str) and os.path.exists(item) and item.startswith(self._config.WORKER_PAYLOAD_DIR):
+                return await self._upload_to_s3(item)
+            if isinstance(item, dict):
+                return {k: await _process(v) for k, v in item.items()}
+            if isinstance(item, list):
+                return [await _process(i) for i in item]
+            return item
+
+        return await _process(result)
```
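A minimal usage sketch of the new manager, assuming `S3_ENDPOINT_URL` and credentials are set in the environment (with no endpoint configured, both calls are pass-throughs); bucket names, keys, and paths below are invented:

```python
import asyncio

from avtomatika_worker.config import WorkerConfig
from avtomatika_worker.s3 import S3Manager


async def demo() -> None:
    s3 = S3Manager(WorkerConfig())  # picks up the S3_* settings

    # Inbound: every "s3://..." string in params, however deeply nested,
    # is downloaded and replaced with a local path under WORKER_PAYLOAD_DIR.
    params = await s3.process_params({"video": "s3://my-bucket/in.mp4"})
    print(params)  # e.g. {'video': '/tmp/payloads/in.mp4'}

    # Outbound: every existing file path under WORKER_PAYLOAD_DIR in the
    # result is uploaded to S3_DEFAULT_BUCKET and replaced with its URI.
    result = await s3.process_result({"output": "/tmp/payloads/out.mp4"})
    print(result)  # e.g. {'output': 's3://avtomatika-payloads/out.mp4'}


asyncio.run(demo())
```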
avtomatika_worker/types.py
CHANGED
avtomatika_worker/worker.py
CHANGED
```diff
@@ -1,5 +1,7 @@
 from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
 from asyncio import TimeoutError as AsyncTimeoutError
+from dataclasses import is_dataclass
+from inspect import Parameter, signature
 from json import JSONDecodeError
 from logging import getLogger
 from typing import Any, Callable
@@ -7,6 +9,15 @@ from typing import Any, Callable
 from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
 
 from .config import WorkerConfig
+from .s3 import S3Manager
+from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
+
+try:
+    from pydantic import BaseModel, ValidationError
+
+    _PYDANTIC_INSTALLED = True
+except ImportError:
+    _PYDANTIC_INSTALLED = False
 
 # Logging setup
 logger = getLogger(__name__)
@@ -26,9 +37,11 @@ class Worker:
         task_type_limits: dict[str, int] | None = None,
         http_session: ClientSession | None = None,
         skill_dependencies: dict[str, list[str]] | None = None,
+        config: WorkerConfig | None = None,
     ):
-        self._config = WorkerConfig()
-        self.
+        self._config = config or WorkerConfig()
+        self._s3_manager = S3Manager(self._config)
+        self._config.WORKER_TYPE = worker_type  # Allow overriding worker_type
         if max_concurrent_tasks is not None:
             self._config.max_concurrent_tasks = max_concurrent_tasks
 
```
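The new `config` parameter makes the configuration injectable, which is mainly useful for tests and embedders. A minimal sketch, assuming the `Worker(worker_type=..., config=...)` keyword arguments visible above (other constructor details are not part of this diff):

```python
from avtomatika_worker.config import WorkerConfig
from avtomatika_worker.worker import Worker

# Build the config up front instead of relying on the process environment.
config = WorkerConfig()
config.MAX_CONCURRENT_TASKS = 2  # illustrative override

worker = Worker(worker_type="transcoder", config=config)
```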
```diff
@@ -44,12 +57,19 @@ class Worker:
         self._http_session = http_session
         self._session_is_managed_externally = http_session is not None
         self._ws_connection: ClientWebSocketResponse | None = None
-        self._headers = {"X-Worker-Token": self._config.
+        # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
         self._shutdown_event = Event()
         self._registered_event = Event()
         self._round_robin_index = 0
         self._debounce_task: Task | None = None
 
+        # --- Weighted Round-Robin State ---
+        self._total_orchestrator_weight = 0
+        if self._config.ORCHESTRATORS:
+            for o in self._config.ORCHESTRATORS:
+                o["current_weight"] = 0
+                self._total_orchestrator_weight += o.get("weight", 1)
+
     def _validate_config(self):
         """Checks for unused task type limits and warns the user."""
         registered_task_types = {
@@ -98,7 +118,7 @@ class Worker:
         """
         Calculates the current worker state including status and available tasks.
         """
-        if self._current_load >= self._config.
+        if self._current_load >= self._config.MAX_CONCURRENT_TASKS:
             return {"status": "busy", "supported_tasks": []}
 
         supported_tasks = []
@@ -118,9 +138,36 @@ class Worker:
         status = "idle" if supported_tasks else "busy"
         return {"status": status, "supported_tasks": supported_tasks}
 
+    def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
+        """Builds authentication headers for a specific orchestrator."""
+        token = orchestrator.get("token", self._config.WORKER_TOKEN)
+        return {"X-Worker-Token": token}
+
+    def _get_next_orchestrator(self) -> dict[str, Any] | None:
+        """
+        Selects the next orchestrator using a smooth weighted round-robin algorithm.
+        """
+        if not self._config.ORCHESTRATORS:
+            return None
+
+        # The orchestrator with the highest current_weight is selected.
+        selected_orchestrator = None
+        highest_weight = -1
+
+        for o in self._config.ORCHESTRATORS:
+            o["current_weight"] += o["weight"]
+            if o["current_weight"] > highest_weight:
+                highest_weight = o["current_weight"]
+                selected_orchestrator = o
+
+        if selected_orchestrator:
+            selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
+
+        return selected_orchestrator
+
     async def _debounced_heartbeat_sender(self):
         """Waits for the debounce delay then sends a heartbeat."""
-        await sleep(self._config.
+        await sleep(self._config.HEARTBEAT_DEBOUNCE_DELAY)
         await self._send_heartbeats_to_all()
 
     def _schedule_heartbeat_debounce(self):
```
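`_get_next_orchestrator` is the smooth weighted round-robin scheme (the variant popularized by nginx): each round every candidate gains its weight, the current leader is picked, and the winner is penalized by the total weight, which interleaves selections rather than sending bursts to the heaviest node. A standalone sketch with hypothetical weights:

```python
# Two hypothetical orchestrators, weighted 5:1.
orchestrators = [
    {"url": "orch-a", "weight": 5, "current_weight": 0},
    {"url": "orch-b", "weight": 1, "current_weight": 0},
]
total_weight = sum(o["weight"] for o in orchestrators)


def pick() -> dict:
    # Raise every candidate by its weight, take the leader,
    # then penalize the winner by the total weight.
    for o in orchestrators:
        o["current_weight"] += o["weight"]
    selected = max(orchestrators, key=lambda o: o["current_weight"])
    selected["current_weight"] -= total_weight
    return selected


print([pick()["url"] for _ in range(6)])
# ['orch-a', 'orch-a', 'orch-a', 'orch-b', 'orch-a', 'orch-a']
# i.e. a 5:1 mix with the single orch-b visit spread into the cycle.
```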
```diff
@@ -131,17 +178,18 @@ class Worker:
         # Schedule the new debounced call.
         self._debounce_task = create_task(self._debounced_heartbeat_sender())
 
-    async def _poll_for_tasks(self,
+    async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
         """Polls a specific Orchestrator for new tasks."""
-        url = f"{
+        url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
         try:
             if not self._http_session:
                 return
-            timeout = ClientTimeout(total=self._config.
-
+            timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
+            headers = self._get_headers(orchestrator)
+            async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
                 if resp.status == 200:
                     task_data = await resp.json()
-                    task_data["
+                    task_data["orchestrator"] = orchestrator
 
                     self._current_load += 1
                     task_handler_info = self._task_handlers.get(task_data["type"])
```
```diff
@@ -154,63 +202,125 @@ class Worker:
                     task = create_task(self._process_task(task_data))
                     self._active_tasks[task_data["task_id"]] = task
                 elif resp.status != 204:
-                    await sleep(self._config.
+                    await sleep(self._config.TASK_POLL_ERROR_DELAY)
         except (AsyncTimeoutError, ClientError) as e:
             logger.error(f"Error polling for tasks: {e}")
-            await sleep(self._config.
+            await sleep(self._config.TASK_POLL_ERROR_DELAY)
 
     async def _start_polling(self):
-        print("Waiting for registration")
         """The main loop for polling tasks."""
         await self._registered_event.wait()
-
+
         while not self._shutdown_event.is_set():
             if self._get_current_state()["status"] == "busy":
-                await sleep(self._config.
+                await sleep(self._config.IDLE_POLL_DELAY)
                 continue
 
-            if self._config.
-                orchestrator = self.
-
-
+            if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
+                orchestrator = self._get_next_orchestrator()
+                if orchestrator:
+                    await self._poll_for_tasks(orchestrator)
             else:
-                for orchestrator in self._config.
+                for orchestrator in self._config.ORCHESTRATORS:
                     if self._get_current_state()["status"] == "busy":
                         break
-                    await self._poll_for_tasks(orchestrator
+                    await self._poll_for_tasks(orchestrator)
 
             if self._current_load == 0:
-                await sleep(self._config.
+                await sleep(self._config.IDLE_POLL_DELAY)
+
+    def _prepare_task_params(self, handler: Callable, params: dict[str, Any]) -> Any:
+        """
+        Inspects the handler's signature to validate and instantiate params.
+        Supports dict, dataclasses, and optional pydantic models.
+        """
+        sig = signature(handler)
+        params_annotation = sig.parameters.get("params").annotation
+
+        if params_annotation is sig.empty or params_annotation is dict:
+            return params
+
+        # Pydantic Model Validation
+        if _PYDANTIC_INSTALLED and isinstance(params_annotation, type) and issubclass(params_annotation, BaseModel):
+            try:
+                return params_annotation.model_validate(params)
+            except ValidationError as e:
+                raise ParamValidationError(str(e)) from e
+
+        # Dataclass Instantiation
+        if isinstance(params_annotation, type) and is_dataclass(params_annotation):
+            try:
+                # Filter unknown fields to prevent TypeError on dataclass instantiation
+                known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
+                filtered_params = {k: v for k, v in params.items() if k in known_fields}
+
+                # Explicitly check for missing required fields
+                required_fields = [
+                    f.name
+                    for f in params_annotation.__dataclass_fields__.values()
+                    if f.default is Parameter.empty and f.default_factory is Parameter.empty
+                ]
+
+                missing_fields = [f for f in required_fields if f not in filtered_params]
+                if missing_fields:
+                    raise ParamValidationError(f"Missing required fields for dataclass: {', '.join(missing_fields)}")
+
+                return params_annotation(**filtered_params)
+            except (TypeError, ValueError) as e:
+                # TypeError for missing/extra args, ValueError from __post_init__
+                raise ParamValidationError(str(e)) from e
+
+        return params
 
     async def _process_task(self, task_data: dict[str, Any]):
         """Executes the task logic."""
         task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-        params,
+        params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
 
         result: dict[str, Any] = {}
         handler_data = self._task_handlers.get(task_name)
         task_type_for_limit = handler_data.get("type") if handler_data else None
 
+        result_sent = False  # Flag to track if result has been sent
+
         try:
-            if handler_data:
-
-
-
-
-
-
-
-
-
-
-
+            if not handler_data:
+                message = f"Unsupported task: {task_name}"
+                logger.warning(message)
+                result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
+                payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
+                await self._send_result(payload, orchestrator)
+                result_sent = True  # Mark result as sent
+                return
+
+            params = await self._s3_manager.process_params(params)
+            validated_params = self._prepare_task_params(handler_data["func"], params)
+
+            result = await handler_data["func"](
+                validated_params,
+                task_id=task_id,
+                job_id=job_id,
+                priority=task_data.get("priority", 0),
+                send_progress=self.send_progress,
+                add_to_hot_cache=self.add_to_hot_cache,
+                remove_from_hot_cache=self.remove_from_hot_cache,
+            )
+            result = await self._s3_manager.process_result(result)
+        except ParamValidationError as e:
+            logger.error(f"Task {task_id} failed validation: {e}")
+            result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
         except CancelledError:
+            logger.info(f"Task {task_id} was cancelled.")
             result = {"status": "cancelled"}
+            # We must re-raise the exception to be handled by the outer gather
+            raise
         except Exception as e:
-
+            logger.exception(f"An unexpected error occurred while processing task {task_id}:")
+            result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
         finally:
-
-
+            if not result_sent:  # Only send if not already sent
+                payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
+                await self._send_result(payload, orchestrator)
             self._active_tasks.pop(task_id, None)
 
             self._current_load -= 1
```
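With `_prepare_task_params` in place, a handler can annotate its `params` argument with a dataclass (or a pydantic model) and receive a validated instance; validation failures surface as `INVALID_INPUT_ERROR` instead of a crashed task. A hypothetical handler illustrating the contract (the handler registration API itself is not part of this diff):

```python
from dataclasses import dataclass


@dataclass
class ResizeParams:
    source: str       # required: if absent, ParamValidationError is raised
    width: int = 640  # optional: dataclass defaults apply


# The argument must be named "params": _prepare_task_params reads
# sig.parameters.get("params").annotation to decide how to validate.
async def resize(params: ResizeParams, **kwargs) -> dict:
    # Unknown fields sent by the orchestrator are filtered out before
    # instantiation, so extra keys do not raise TypeError.
    return {"status": "success", "source": params.source, "width": params.width}
```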
```diff
@@ -218,14 +328,15 @@ class Worker:
             self._current_load_by_type[task_type_for_limit] -= 1
         self._schedule_heartbeat_debounce()
 
-    async def _send_result(self, payload: dict[str, Any],
+    async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
         """Sends the result to a specific orchestrator."""
-        url = f"{
-        delay = self._config.
-
+        url = f"{orchestrator['url']}/_worker/tasks/result"
+        delay = self._config.RESULT_RETRY_INITIAL_DELAY
+        headers = self._get_headers(orchestrator)
+        for i in range(self._config.RESULT_MAX_RETRIES):
             try:
                 if self._http_session and not self._http_session.closed:
-                    async with self._http_session.post(url, json=payload, headers=
+                    async with self._http_session.post(url, json=payload, headers=headers) as resp:
                         if resp.status == 200:
                             return
             except ClientError as e:
```
```diff
@@ -233,43 +344,44 @@ class Worker:
                 await sleep(delay * (2**i))
 
     async def _manage_orchestrator_communications(self):
-        print("Registering worker")
         """Registers the worker and sends heartbeats."""
         await self._register_with_all_orchestrators()
-
+
         self._registered_event.set()
-        if self._config.
+        if self._config.ENABLE_WEBSOCKETS:
             create_task(self._start_websocket_manager())
 
         while not self._shutdown_event.is_set():
             await self._send_heartbeats_to_all()
-            await sleep(self._config.
+            await sleep(self._config.HEARTBEAT_INTERVAL)
 
     async def _register_with_all_orchestrators(self):
         """Registers the worker with all orchestrators."""
         state = self._get_current_state()
         payload = {
-            "worker_id": self._config.
-            "worker_type": self._config.
+            "worker_id": self._config.WORKER_ID,
+            "worker_type": self._config.WORKER_TYPE,
             "supported_tasks": state["supported_tasks"],
-            "max_concurrent_tasks": self._config.
-            "
-            "
-            "
-            "
+            "max_concurrent_tasks": self._config.MAX_CONCURRENT_TASKS,
+            "cost_per_skill": self._config.COST_PER_SKILL,
+            "installed_models": self._config.INSTALLED_MODELS,
+            "hostname": self._config.HOSTNAME,
+            "ip_address": self._config.IP_ADDRESS,
+            "resources": self._config.RESOURCES,
         }
-        for orchestrator in self._config.
+        for orchestrator in self._config.ORCHESTRATORS:
             url = f"{orchestrator['url']}/_worker/workers/register"
             try:
                 if self._http_session:
-                    async with self._http_session.post(
+                    async with self._http_session.post(
+                        url, json=payload, headers=self._get_headers(orchestrator)
+                    ) as resp:
                         if resp.status >= 400:
                             logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
             except ClientError as e:
                 logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
 
     async def _send_heartbeats_to_all(self):
-        print("Sending heartbeats")
         """Sends heartbeat messages to all orchestrators."""
         state = self._get_current_state()
         payload = {
```
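For reference, the `await sleep(delay * (2**i))` line above gives `_send_result` an exponential backoff; with the 1.0b1 defaults the schedule works out to:

```python
# RESULT_RETRY_INITIAL_DELAY = 1.0, RESULT_MAX_RETRIES = 5 (defaults)
delay = 1.0
print([delay * (2**i) for i in range(5)])  # [1.0, 2.0, 4.0, 8.0, 16.0]
```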
```diff
@@ -287,27 +399,27 @@ class Worker:
         if hot_skills:
             payload["hot_skills"] = hot_skills
 
-        async def _send_single(
-            url = f"{
+        async def _send_single(orchestrator: dict[str, Any]):
+            url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
+            headers = self._get_headers(orchestrator)
             try:
                 if self._http_session and not self._http_session.closed:
-                    async with self._http_session.patch(url, json=payload, headers=
+                    async with self._http_session.patch(url, json=payload, headers=headers) as resp:
                         if resp.status >= 400:
-                            logger.warning(f"Heartbeat to {
+                            logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
             except ClientError as e:
-                logger.error(f"Error sending heartbeat to orchestrator {
+                logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
 
-        await gather(*[_send_single(o
+        await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
 
     async def main(self):
-        print("Main started")
         """The main asynchronous function."""
         self._validate_config()  # Validate config now that all tasks are registered
         if not self._http_session:
             self._http_session = ClientSession()
-
+
         comm_task = create_task(self._manage_orchestrator_communications())
-
+
         polling_task = create_task(self._start_polling())
         await self._shutdown_event.wait()
 
```
```diff
@@ -327,14 +439,13 @@ class Worker:
             run(self.main())
         except KeyboardInterrupt:
             self._shutdown_event.set()
-            run(sleep(1.5))
 
     async def _run_health_check_server(self):
         app = web.Application()
         app.router.add_get("/health", lambda r: web.Response(text="OK"))
         runner = web.AppRunner(app)
         await runner.setup()
-        site = web.TCPSite(runner, "0.0.0.0", self._config.
+        site = web.TCPSite(runner, "0.0.0.0", self._config.WORKER_PORT)
         await site.start()
         await self._shutdown_event.wait()
         await runner.cleanup()
```
```diff
@@ -347,17 +458,16 @@ class Worker:
             run(_main_wrapper())
         except KeyboardInterrupt:
             self._shutdown_event.set()
-            run(sleep(1.5))
 
     # WebSocket methods omitted for brevity as they are not relevant to the changes
     async def _start_websocket_manager(self):
         """Manages the WebSocket connection to the orchestrator."""
         while not self._shutdown_event.is_set():
-            for orchestrator in self._config.
+            for orchestrator in self._config.ORCHESTRATORS:
                 ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
                 try:
                     if self._http_session:
-                        async with self._http_session.ws_connect(ws_url, headers=self.
+                        async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
                             self._ws_connection = ws
                             logger.info(f"WebSocket connection established to {ws_url}")
                             await self._listen_for_commands()
@@ -367,7 +477,7 @@
                     self._ws_connection = None
                     logger.info(f"WebSocket connection to {ws_url} closed.")
                     await sleep(5)  # Reconnection delay
-            if not self._config.
+            if not self._config.ORCHESTRATORS:
                 await sleep(5)
 
     async def _listen_for_commands(self):
```