avtomatika-worker 1.0a2__py3-none-any.whl → 1.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika_worker/__init__.py +2 -1
- avtomatika_worker/config.py +46 -36
- avtomatika_worker/s3.py +141 -0
- avtomatika_worker/task_files.py +97 -0
- avtomatika_worker/types.py +4 -0
- avtomatika_worker/worker.py +211 -78
- avtomatika_worker-1.0b2.dist-info/METADATA +610 -0
- avtomatika_worker-1.0b2.dist-info/RECORD +11 -0
- avtomatika_worker-1.0a2.dist-info/METADATA +0 -307
- avtomatika_worker-1.0a2.dist-info/RECORD +0 -9
- {avtomatika_worker-1.0a2.dist-info → avtomatika_worker-1.0b2.dist-info}/WHEEL +0 -0
- {avtomatika_worker-1.0a2.dist-info → avtomatika_worker-1.0b2.dist-info}/licenses/LICENSE +0 -0
- {avtomatika_worker-1.0a2.dist-info → avtomatika_worker-1.0b2.dist-info}/top_level.txt +0 -0
avtomatika_worker/worker.py
CHANGED
|
@@ -1,12 +1,25 @@
|
|
|
1
1
|
from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
|
|
2
2
|
from asyncio import TimeoutError as AsyncTimeoutError
|
|
3
|
+
from dataclasses import is_dataclass
|
|
4
|
+
from inspect import Parameter, signature
|
|
3
5
|
from json import JSONDecodeError
|
|
4
6
|
from logging import getLogger
|
|
7
|
+
from os.path import join
|
|
5
8
|
from typing import Any, Callable
|
|
6
9
|
|
|
7
10
|
from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
|
|
8
11
|
|
|
9
12
|
from .config import WorkerConfig
|
|
13
|
+
from .s3 import S3Manager
|
|
14
|
+
from .task_files import TaskFiles
|
|
15
|
+
from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from pydantic import BaseModel, ValidationError
|
|
19
|
+
|
|
20
|
+
_PYDANTIC_INSTALLED = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
_PYDANTIC_INSTALLED = False
|
|
10
23
|
|
|
11
24
|
# Logging setup
|
|
12
25
|
logger = getLogger(__name__)
|
|
@@ -26,9 +39,11 @@ class Worker:
|
|
|
26
39
|
task_type_limits: dict[str, int] | None = None,
|
|
27
40
|
http_session: ClientSession | None = None,
|
|
28
41
|
skill_dependencies: dict[str, list[str]] | None = None,
|
|
42
|
+
config: WorkerConfig | None = None,
|
|
29
43
|
):
|
|
30
|
-
self._config = WorkerConfig()
|
|
31
|
-
self.
|
|
44
|
+
self._config = config or WorkerConfig()
|
|
45
|
+
self._s3_manager = S3Manager(self._config)
|
|
46
|
+
self._config.WORKER_TYPE = worker_type # Allow overriding worker_type
|
|
32
47
|
if max_concurrent_tasks is not None:
|
|
33
48
|
self._config.max_concurrent_tasks = max_concurrent_tasks
|
|
34
49
|
|
|
@@ -44,12 +59,19 @@ class Worker:
|
|
|
44
59
|
self._http_session = http_session
|
|
45
60
|
self._session_is_managed_externally = http_session is not None
|
|
46
61
|
self._ws_connection: ClientWebSocketResponse | None = None
|
|
47
|
-
self._headers = {"X-Worker-Token": self._config.
|
|
62
|
+
# Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
|
|
48
63
|
self._shutdown_event = Event()
|
|
49
64
|
self._registered_event = Event()
|
|
50
65
|
self._round_robin_index = 0
|
|
51
66
|
self._debounce_task: Task | None = None
|
|
52
67
|
|
|
68
|
+
# --- Weighted Round-Robin State ---
|
|
69
|
+
self._total_orchestrator_weight = 0
|
|
70
|
+
if self._config.ORCHESTRATORS:
|
|
71
|
+
for o in self._config.ORCHESTRATORS:
|
|
72
|
+
o["current_weight"] = 0
|
|
73
|
+
self._total_orchestrator_weight += o.get("weight", 1)
|
|
74
|
+
|
|
53
75
|
def _validate_config(self):
|
|
54
76
|
"""Checks for unused task type limits and warns the user."""
|
|
55
77
|
registered_task_types = {
|
|
@@ -98,7 +120,7 @@ class Worker:
|
|
|
98
120
|
"""
|
|
99
121
|
Calculates the current worker state including status and available tasks.
|
|
100
122
|
"""
|
|
101
|
-
if self._current_load >= self._config.
|
|
123
|
+
if self._current_load >= self._config.MAX_CONCURRENT_TASKS:
|
|
102
124
|
return {"status": "busy", "supported_tasks": []}
|
|
103
125
|
|
|
104
126
|
supported_tasks = []
|
|
@@ -118,9 +140,36 @@ class Worker:
|
|
|
118
140
|
status = "idle" if supported_tasks else "busy"
|
|
119
141
|
return {"status": status, "supported_tasks": supported_tasks}
|
|
120
142
|
|
|
143
|
+
def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
|
|
144
|
+
"""Builds authentication headers for a specific orchestrator."""
|
|
145
|
+
token = orchestrator.get("token", self._config.WORKER_TOKEN)
|
|
146
|
+
return {"X-Worker-Token": token}
|
|
147
|
+
|
|
148
|
+
def _get_next_orchestrator(self) -> dict[str, Any] | None:
|
|
149
|
+
"""
|
|
150
|
+
Selects the next orchestrator using a smooth weighted round-robin algorithm.
|
|
151
|
+
"""
|
|
152
|
+
if not self._config.ORCHESTRATORS:
|
|
153
|
+
return None
|
|
154
|
+
|
|
155
|
+
# The orchestrator with the highest current_weight is selected.
|
|
156
|
+
selected_orchestrator = None
|
|
157
|
+
highest_weight = -1
|
|
158
|
+
|
|
159
|
+
for o in self._config.ORCHESTRATORS:
|
|
160
|
+
o["current_weight"] += o["weight"]
|
|
161
|
+
if o["current_weight"] > highest_weight:
|
|
162
|
+
highest_weight = o["current_weight"]
|
|
163
|
+
selected_orchestrator = o
|
|
164
|
+
|
|
165
|
+
if selected_orchestrator:
|
|
166
|
+
selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
|
|
167
|
+
|
|
168
|
+
return selected_orchestrator
|
|
169
|
+
|
|
121
170
|
async def _debounced_heartbeat_sender(self):
|
|
122
171
|
"""Waits for the debounce delay then sends a heartbeat."""
|
|
123
|
-
await sleep(self._config.
|
|
172
|
+
await sleep(self._config.HEARTBEAT_DEBOUNCE_DELAY)
|
|
124
173
|
await self._send_heartbeats_to_all()
|
|
125
174
|
|
|
126
175
|
def _schedule_heartbeat_debounce(self):
|
|
@@ -131,86 +180,166 @@ class Worker:
|
|
|
131
180
|
# Schedule the new debounced call.
|
|
132
181
|
self._debounce_task = create_task(self._debounced_heartbeat_sender())
|
|
133
182
|
|
|
134
|
-
async def _poll_for_tasks(self,
|
|
183
|
+
async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
|
|
135
184
|
"""Polls a specific Orchestrator for new tasks."""
|
|
136
|
-
url = f"{
|
|
185
|
+
url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
|
|
137
186
|
try:
|
|
138
187
|
if not self._http_session:
|
|
139
188
|
return
|
|
140
|
-
timeout = ClientTimeout(total=self._config.
|
|
141
|
-
|
|
189
|
+
timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
|
|
190
|
+
headers = self._get_headers(orchestrator)
|
|
191
|
+
async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
|
|
142
192
|
if resp.status == 200:
|
|
143
193
|
task_data = await resp.json()
|
|
144
|
-
task_data["
|
|
194
|
+
task_data["orchestrator"] = orchestrator
|
|
145
195
|
|
|
146
196
|
self._current_load += 1
|
|
147
|
-
task_handler_info
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
self._current_load_by_type[task_type_for_limit] += 1
|
|
197
|
+
if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
|
|
198
|
+
task_type_for_limit := task_handler_info.get("type")
|
|
199
|
+
):
|
|
200
|
+
self._current_load_by_type[task_type_for_limit] += 1
|
|
152
201
|
self._schedule_heartbeat_debounce()
|
|
153
202
|
|
|
154
203
|
task = create_task(self._process_task(task_data))
|
|
155
204
|
self._active_tasks[task_data["task_id"]] = task
|
|
156
205
|
elif resp.status != 204:
|
|
157
|
-
await sleep(self._config.
|
|
206
|
+
await sleep(self._config.TASK_POLL_ERROR_DELAY)
|
|
158
207
|
except (AsyncTimeoutError, ClientError) as e:
|
|
159
208
|
logger.error(f"Error polling for tasks: {e}")
|
|
160
|
-
await sleep(self._config.
|
|
209
|
+
await sleep(self._config.TASK_POLL_ERROR_DELAY)
|
|
161
210
|
|
|
162
211
|
async def _start_polling(self):
|
|
163
|
-
print("Waiting for registration")
|
|
164
212
|
"""The main loop for polling tasks."""
|
|
165
213
|
await self._registered_event.wait()
|
|
166
|
-
|
|
214
|
+
|
|
167
215
|
while not self._shutdown_event.is_set():
|
|
168
216
|
if self._get_current_state()["status"] == "busy":
|
|
169
|
-
await sleep(self._config.
|
|
217
|
+
await sleep(self._config.IDLE_POLL_DELAY)
|
|
170
218
|
continue
|
|
171
219
|
|
|
172
|
-
if self._config.
|
|
173
|
-
orchestrator
|
|
174
|
-
|
|
175
|
-
self._round_robin_index = (self._round_robin_index + 1) % len(self._config.orchestrators)
|
|
220
|
+
if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
|
|
221
|
+
if orchestrator := self._get_next_orchestrator():
|
|
222
|
+
await self._poll_for_tasks(orchestrator)
|
|
176
223
|
else:
|
|
177
|
-
for orchestrator in self._config.
|
|
224
|
+
for orchestrator in self._config.ORCHESTRATORS:
|
|
178
225
|
if self._get_current_state()["status"] == "busy":
|
|
179
226
|
break
|
|
180
|
-
await self._poll_for_tasks(orchestrator
|
|
227
|
+
await self._poll_for_tasks(orchestrator)
|
|
181
228
|
|
|
182
229
|
if self._current_load == 0:
|
|
183
|
-
await sleep(self._config.
|
|
230
|
+
await sleep(self._config.IDLE_POLL_DELAY)
|
|
231
|
+
|
|
232
|
+
@staticmethod
|
|
233
|
+
def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
|
|
234
|
+
"""
|
|
235
|
+
Inspects the handler's signature to validate and instantiate params.
|
|
236
|
+
Supports dict, dataclasses, and optional pydantic models.
|
|
237
|
+
"""
|
|
238
|
+
sig = signature(handler)
|
|
239
|
+
params_annotation = sig.parameters.get("params").annotation
|
|
240
|
+
|
|
241
|
+
if params_annotation is sig.empty or params_annotation is dict:
|
|
242
|
+
return params
|
|
243
|
+
|
|
244
|
+
# Pydantic Model Validation
|
|
245
|
+
if _PYDANTIC_INSTALLED and isinstance(params_annotation, type) and issubclass(params_annotation, BaseModel):
|
|
246
|
+
try:
|
|
247
|
+
return params_annotation.model_validate(params)
|
|
248
|
+
except ValidationError as e:
|
|
249
|
+
raise ParamValidationError(str(e)) from e
|
|
250
|
+
|
|
251
|
+
# Dataclass Instantiation
|
|
252
|
+
if isinstance(params_annotation, type) and is_dataclass(params_annotation):
|
|
253
|
+
try:
|
|
254
|
+
# Filter unknown fields to prevent TypeError on dataclass instantiation
|
|
255
|
+
known_fields = {f.name for f in params_annotation.__dataclass_fields__.values()}
|
|
256
|
+
filtered_params = {k: v for k, v in params.items() if k in known_fields}
|
|
257
|
+
|
|
258
|
+
# Explicitly check for missing required fields
|
|
259
|
+
required_fields = [
|
|
260
|
+
f.name
|
|
261
|
+
for f in params_annotation.__dataclass_fields__.values()
|
|
262
|
+
if f.default is Parameter.empty and f.default_factory is Parameter.empty
|
|
263
|
+
]
|
|
264
|
+
|
|
265
|
+
if missing_fields := [f for f in required_fields if f not in filtered_params]:
|
|
266
|
+
raise ParamValidationError(f"Missing required fields for dataclass: {', '.join(missing_fields)}")
|
|
267
|
+
|
|
268
|
+
return params_annotation(**filtered_params)
|
|
269
|
+
except (TypeError, ValueError) as e:
|
|
270
|
+
# TypeError for missing/extra args, ValueError from __post_init__
|
|
271
|
+
raise ParamValidationError(str(e)) from e
|
|
272
|
+
|
|
273
|
+
return params
|
|
274
|
+
|
|
275
|
+
def _prepare_dependencies(self, handler: Callable, task_id: str) -> dict[str, Any]:
|
|
276
|
+
"""Injects dependencies based on type hints."""
|
|
277
|
+
deps = {}
|
|
278
|
+
task_dir = join(self._config.TASK_FILES_DIR, task_id)
|
|
279
|
+
# Always create the object, but directory is lazy
|
|
280
|
+
task_files = TaskFiles(task_dir)
|
|
281
|
+
|
|
282
|
+
sig = signature(handler)
|
|
283
|
+
for name, param in sig.parameters.items():
|
|
284
|
+
if param.annotation is TaskFiles:
|
|
285
|
+
deps[name] = task_files
|
|
286
|
+
|
|
287
|
+
return deps
|
|
184
288
|
|
|
185
289
|
async def _process_task(self, task_data: dict[str, Any]):
|
|
186
290
|
"""Executes the task logic."""
|
|
187
291
|
task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
|
|
188
|
-
params,
|
|
292
|
+
params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
|
|
189
293
|
|
|
190
294
|
result: dict[str, Any] = {}
|
|
191
295
|
handler_data = self._task_handlers.get(task_name)
|
|
192
296
|
task_type_for_limit = handler_data.get("type") if handler_data else None
|
|
193
297
|
|
|
298
|
+
result_sent = False # Flag to track if result has been sent
|
|
299
|
+
|
|
194
300
|
try:
|
|
195
|
-
if handler_data:
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
301
|
+
if not handler_data:
|
|
302
|
+
message = f"Unsupported task: {task_name}"
|
|
303
|
+
logger.warning(message)
|
|
304
|
+
result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
|
|
305
|
+
payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
|
|
306
|
+
await self._send_result(payload, orchestrator)
|
|
307
|
+
result_sent = True # Mark result as sent
|
|
308
|
+
return
|
|
309
|
+
|
|
310
|
+
params = await self._s3_manager.process_params(params, task_id)
|
|
311
|
+
validated_params = self._prepare_task_params(handler_data["func"], params)
|
|
312
|
+
deps = self._prepare_dependencies(handler_data["func"], task_id)
|
|
313
|
+
|
|
314
|
+
result = await handler_data["func"](
|
|
315
|
+
validated_params,
|
|
316
|
+
task_id=task_id,
|
|
317
|
+
job_id=job_id,
|
|
318
|
+
priority=task_data.get("priority", 0),
|
|
319
|
+
send_progress=self.send_progress,
|
|
320
|
+
add_to_hot_cache=self.add_to_hot_cache,
|
|
321
|
+
remove_from_hot_cache=self.remove_from_hot_cache,
|
|
322
|
+
**deps,
|
|
323
|
+
)
|
|
324
|
+
result = await self._s3_manager.process_result(result)
|
|
325
|
+
except ParamValidationError as e:
|
|
326
|
+
logger.error(f"Task {task_id} failed validation: {e}")
|
|
327
|
+
result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
|
|
207
328
|
except CancelledError:
|
|
329
|
+
logger.info(f"Task {task_id} was cancelled.")
|
|
208
330
|
result = {"status": "cancelled"}
|
|
331
|
+
# We must re-raise the exception to be handled by the outer gather
|
|
332
|
+
raise
|
|
209
333
|
except Exception as e:
|
|
210
|
-
|
|
334
|
+
logger.exception(f"An unexpected error occurred while processing task {task_id}:")
|
|
335
|
+
result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
|
|
211
336
|
finally:
|
|
212
|
-
|
|
213
|
-
await self.
|
|
337
|
+
# Cleanup task workspace
|
|
338
|
+
await self._s3_manager.cleanup(task_id)
|
|
339
|
+
|
|
340
|
+
if not result_sent: # Only send if not already sent
|
|
341
|
+
payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
|
|
342
|
+
await self._send_result(payload, orchestrator)
|
|
214
343
|
self._active_tasks.pop(task_id, None)
|
|
215
344
|
|
|
216
345
|
self._current_load -= 1
|
|
@@ -218,14 +347,15 @@ class Worker:
|
|
|
218
347
|
self._current_load_by_type[task_type_for_limit] -= 1
|
|
219
348
|
self._schedule_heartbeat_debounce()
|
|
220
349
|
|
|
221
|
-
async def _send_result(self, payload: dict[str, Any],
|
|
350
|
+
async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
|
|
222
351
|
"""Sends the result to a specific orchestrator."""
|
|
223
|
-
url = f"{
|
|
224
|
-
delay = self._config.
|
|
225
|
-
|
|
352
|
+
url = f"{orchestrator['url']}/_worker/tasks/result"
|
|
353
|
+
delay = self._config.RESULT_RETRY_INITIAL_DELAY
|
|
354
|
+
headers = self._get_headers(orchestrator)
|
|
355
|
+
for i in range(self._config.RESULT_MAX_RETRIES):
|
|
226
356
|
try:
|
|
227
357
|
if self._http_session and not self._http_session.closed:
|
|
228
|
-
async with self._http_session.post(url, json=payload, headers=
|
|
358
|
+
async with self._http_session.post(url, json=payload, headers=headers) as resp:
|
|
229
359
|
if resp.status == 200:
|
|
230
360
|
return
|
|
231
361
|
except ClientError as e:
|
|
@@ -233,43 +363,44 @@ class Worker:
|
|
|
233
363
|
await sleep(delay * (2**i))
|
|
234
364
|
|
|
235
365
|
async def _manage_orchestrator_communications(self):
|
|
236
|
-
print("Registering worker")
|
|
237
366
|
"""Registers the worker and sends heartbeats."""
|
|
238
367
|
await self._register_with_all_orchestrators()
|
|
239
|
-
|
|
368
|
+
|
|
240
369
|
self._registered_event.set()
|
|
241
|
-
if self._config.
|
|
370
|
+
if self._config.ENABLE_WEBSOCKETS:
|
|
242
371
|
create_task(self._start_websocket_manager())
|
|
243
372
|
|
|
244
373
|
while not self._shutdown_event.is_set():
|
|
245
374
|
await self._send_heartbeats_to_all()
|
|
246
|
-
await sleep(self._config.
|
|
375
|
+
await sleep(self._config.HEARTBEAT_INTERVAL)
|
|
247
376
|
|
|
248
377
|
async def _register_with_all_orchestrators(self):
|
|
249
378
|
"""Registers the worker with all orchestrators."""
|
|
250
379
|
state = self._get_current_state()
|
|
251
380
|
payload = {
|
|
252
|
-
"worker_id": self._config.
|
|
253
|
-
"worker_type": self._config.
|
|
381
|
+
"worker_id": self._config.WORKER_ID,
|
|
382
|
+
"worker_type": self._config.WORKER_TYPE,
|
|
254
383
|
"supported_tasks": state["supported_tasks"],
|
|
255
|
-
"max_concurrent_tasks": self._config.
|
|
256
|
-
"
|
|
257
|
-
"
|
|
258
|
-
"
|
|
259
|
-
"
|
|
384
|
+
"max_concurrent_tasks": self._config.MAX_CONCURRENT_TASKS,
|
|
385
|
+
"cost_per_skill": self._config.COST_PER_SKILL,
|
|
386
|
+
"installed_models": self._config.INSTALLED_MODELS,
|
|
387
|
+
"hostname": self._config.HOSTNAME,
|
|
388
|
+
"ip_address": self._config.IP_ADDRESS,
|
|
389
|
+
"resources": self._config.RESOURCES,
|
|
260
390
|
}
|
|
261
|
-
for orchestrator in self._config.
|
|
391
|
+
for orchestrator in self._config.ORCHESTRATORS:
|
|
262
392
|
url = f"{orchestrator['url']}/_worker/workers/register"
|
|
263
393
|
try:
|
|
264
394
|
if self._http_session:
|
|
265
|
-
async with self._http_session.post(
|
|
395
|
+
async with self._http_session.post(
|
|
396
|
+
url, json=payload, headers=self._get_headers(orchestrator)
|
|
397
|
+
) as resp:
|
|
266
398
|
if resp.status >= 400:
|
|
267
399
|
logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
|
|
268
400
|
except ClientError as e:
|
|
269
401
|
logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
|
|
270
402
|
|
|
271
403
|
async def _send_heartbeats_to_all(self):
|
|
272
|
-
print("Sending heartbeats")
|
|
273
404
|
"""Sends heartbeat messages to all orchestrators."""
|
|
274
405
|
state = self._get_current_state()
|
|
275
406
|
payload = {
|
|
@@ -287,27 +418,27 @@ class Worker:
|
|
|
287
418
|
if hot_skills:
|
|
288
419
|
payload["hot_skills"] = hot_skills
|
|
289
420
|
|
|
290
|
-
async def _send_single(
|
|
291
|
-
url = f"{
|
|
421
|
+
async def _send_single(orchestrator: dict[str, Any]):
|
|
422
|
+
url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
|
|
423
|
+
headers = self._get_headers(orchestrator)
|
|
292
424
|
try:
|
|
293
425
|
if self._http_session and not self._http_session.closed:
|
|
294
|
-
async with self._http_session.patch(url, json=payload, headers=
|
|
426
|
+
async with self._http_session.patch(url, json=payload, headers=headers) as resp:
|
|
295
427
|
if resp.status >= 400:
|
|
296
|
-
logger.warning(f"Heartbeat to {
|
|
428
|
+
logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
|
|
297
429
|
except ClientError as e:
|
|
298
|
-
logger.error(f"Error sending heartbeat to orchestrator {
|
|
430
|
+
logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
|
|
299
431
|
|
|
300
|
-
await gather(*[_send_single(o
|
|
432
|
+
await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
|
|
301
433
|
|
|
302
434
|
async def main(self):
|
|
303
|
-
print("Main started")
|
|
304
435
|
"""The main asynchronous function."""
|
|
305
436
|
self._validate_config() # Validate config now that all tasks are registered
|
|
306
437
|
if not self._http_session:
|
|
307
438
|
self._http_session = ClientSession()
|
|
308
|
-
|
|
439
|
+
|
|
309
440
|
comm_task = create_task(self._manage_orchestrator_communications())
|
|
310
|
-
|
|
441
|
+
|
|
311
442
|
polling_task = create_task(self._start_polling())
|
|
312
443
|
await self._shutdown_event.wait()
|
|
313
444
|
|
|
@@ -327,14 +458,17 @@ class Worker:
|
|
|
327
458
|
run(self.main())
|
|
328
459
|
except KeyboardInterrupt:
|
|
329
460
|
self._shutdown_event.set()
|
|
330
|
-
run(sleep(1.5))
|
|
331
461
|
|
|
332
462
|
async def _run_health_check_server(self):
|
|
333
463
|
app = web.Application()
|
|
334
|
-
|
|
464
|
+
|
|
465
|
+
async def health_handler(_):
|
|
466
|
+
return web.Response(text="OK")
|
|
467
|
+
|
|
468
|
+
app.router.add_get("/health", health_handler)
|
|
335
469
|
runner = web.AppRunner(app)
|
|
336
470
|
await runner.setup()
|
|
337
|
-
site = web.TCPSite(runner, "0.0.0.0", self._config.
|
|
471
|
+
site = web.TCPSite(runner, "0.0.0.0", self._config.WORKER_PORT)
|
|
338
472
|
await site.start()
|
|
339
473
|
await self._shutdown_event.wait()
|
|
340
474
|
await runner.cleanup()
|
|
@@ -347,17 +481,16 @@ class Worker:
|
|
|
347
481
|
run(_main_wrapper())
|
|
348
482
|
except KeyboardInterrupt:
|
|
349
483
|
self._shutdown_event.set()
|
|
350
|
-
run(sleep(1.5))
|
|
351
484
|
|
|
352
485
|
# WebSocket methods omitted for brevity as they are not relevant to the changes
|
|
353
486
|
async def _start_websocket_manager(self):
|
|
354
487
|
"""Manages the WebSocket connection to the orchestrator."""
|
|
355
488
|
while not self._shutdown_event.is_set():
|
|
356
|
-
for orchestrator in self._config.
|
|
489
|
+
for orchestrator in self._config.ORCHESTRATORS:
|
|
357
490
|
ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
|
|
358
491
|
try:
|
|
359
492
|
if self._http_session:
|
|
360
|
-
async with self._http_session.ws_connect(ws_url, headers=self.
|
|
493
|
+
async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
|
|
361
494
|
self._ws_connection = ws
|
|
362
495
|
logger.info(f"WebSocket connection established to {ws_url}")
|
|
363
496
|
await self._listen_for_commands()
|
|
@@ -367,7 +500,7 @@ class Worker:
|
|
|
367
500
|
self._ws_connection = None
|
|
368
501
|
logger.info(f"WebSocket connection to {ws_url} closed.")
|
|
369
502
|
await sleep(5) # Reconnection delay
|
|
370
|
-
if not self._config.
|
|
503
|
+
if not self._config.ORCHESTRATORS:
|
|
371
504
|
await sleep(5)
|
|
372
505
|
|
|
373
506
|
async def _listen_for_commands(self):
|