avtomatika-worker 1.0b1-py3-none-any.whl → 1.0b3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika_worker/__init__.py +2 -1
- avtomatika_worker/client.py +93 -0
- avtomatika_worker/config.py +27 -15
- avtomatika_worker/constants.py +22 -0
- avtomatika_worker/py.typed +0 -0
- avtomatika_worker/s3.py +176 -41
- avtomatika_worker/task_files.py +97 -0
- avtomatika_worker/types.py +17 -4
- avtomatika_worker/worker.py +125 -117
- {avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/METADATA +90 -15
- avtomatika_worker-1.0b3.dist-info/RECORD +14 -0
- avtomatika_worker-1.0b1.dist-info/RECORD +0 -10
- {avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/WHEEL +0 -0
- {avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/licenses/LICENSE +0 -0
- {avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/top_level.txt +0 -0
avtomatika_worker/client.py
ADDED
@@ -0,0 +1,93 @@
+from asyncio import sleep
+from logging import getLogger
+from typing import Any
+
+from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse
+
+from .constants import AUTH_HEADER_WORKER
+
+logger = getLogger(__name__)
+
+
+class OrchestratorClient:
+    """
+    Dedicated client for communicating with a single Avtomatika Orchestrator instance.
+    Handles HTTP requests, retries, and authentication.
+    """
+
+    def __init__(self, session: ClientSession, base_url: str, worker_id: str, token: str):
+        self.session = session
+        self.base_url = base_url.rstrip("/")
+        self.worker_id = worker_id
+        self.token = token
+        self._headers = {AUTH_HEADER_WORKER: self.token}
+
+    async def register(self, payload: dict[str, Any]) -> bool:
+        """Registers the worker with the orchestrator."""
+        url = f"{self.base_url}/_worker/workers/register"
+        try:
+            async with self.session.post(url, json=payload, headers=self._headers) as resp:
+                if resp.status >= 400:
+                    logger.error(f"Error registering with {self.base_url}: {resp.status}")
+                    return False
+                return True
+        except ClientError as e:
+            logger.error(f"Error registering with orchestrator {self.base_url}: {e}")
+            return False
+
+    async def poll_task(self, timeout: float) -> dict[str, Any] | None:
+        """Polls for the next available task."""
+        url = f"{self.base_url}/_worker/workers/{self.worker_id}/tasks/next"
+        client_timeout = ClientTimeout(total=timeout + 5)
+        try:
+            async with self.session.get(url, headers=self._headers, timeout=client_timeout) as resp:
+                if resp.status == 200:
+                    return await resp.json()
+                elif resp.status != 204:
+                    logger.warning(f"Unexpected status from {self.base_url} during poll: {resp.status}")
+        except ClientError as e:
+            logger.error(f"Error polling for tasks from {self.base_url}: {e}")
+        except Exception as e:
+            logger.exception(f"Unexpected error polling from {self.base_url}: {e}")
+        return None
+
+    async def send_heartbeat(self, payload: dict[str, Any]) -> bool:
+        """Sends a heartbeat message to update worker state."""
+        url = f"{self.base_url}/_worker/workers/{self.worker_id}"
+        try:
+            async with self.session.patch(url, json=payload, headers=self._headers) as resp:
+                if resp.status >= 400:
+                    logger.warning(f"Heartbeat to {self.base_url} failed with status: {resp.status}")
+                    return False
+                return True
+        except ClientError as e:
+            logger.error(f"Error sending heartbeat to orchestrator {self.base_url}: {e}")
+            return False
+
+    async def send_result(self, payload: dict[str, Any], max_retries: int, initial_delay: float) -> bool:
+        """Sends task result with retries and exponential backoff."""
+        url = f"{self.base_url}/_worker/tasks/result"
+        delay = initial_delay
+        for i in range(max_retries):
+            try:
+                async with self.session.post(url, json=payload, headers=self._headers) as resp:
+                    if resp.status == 200:
+                        return True
+                    logger.error(f"Error sending result to {self.base_url}: {resp.status}")
+            except ClientError as e:
+                logger.error(f"Error sending result to {self.base_url}: {e}")
+
+            if i < max_retries - 1:
+                await sleep(delay * (2**i))
+        return False
+
+    async def connect_websocket(self) -> ClientWebSocketResponse | None:
+        """Establishes a WebSocket connection for real-time commands."""
+        ws_url = self.base_url.replace("http", "ws", 1) + "/_worker/ws"
+        try:
+            ws = await self.session.ws_connect(ws_url, headers=self._headers)
+            logger.info(f"WebSocket connection established to {ws_url}")
+            return ws
+        except Exception as e:
+            logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
+            return None
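For orientation, here is a minimal sketch of how the new client could be driven directly, outside the `Worker` loop. It assumes a local orchestrator at `http://localhost:8080`; the registration payload is abbreviated (the real one, built in `worker.py` below, also carries `ip_address`, `resources`, etc.), and the result payload mirrors the `{"job_id", "task_id", "worker_id", "result"}` shape used there.

```python
# Illustrative sketch only — the Worker class wires OrchestratorClient up automatically.
from asyncio import run

from aiohttp import ClientSession

from avtomatika_worker.client import OrchestratorClient


async def poll_once() -> None:
    async with ClientSession() as session:
        client = OrchestratorClient(
            session=session,
            base_url="http://localhost:8080",  # assumed local orchestrator
            worker_id="worker-demo",
            token="your-secret-worker-token",
        )
        # Register first (payload abbreviated for the example), then ask for one task.
        if await client.register({"worker_id": "worker-demo", "worker_type": "generic-cpu-worker"}):
            task = await client.poll_task(timeout=30)
            if task:
                payload = {
                    "job_id": task["job_id"],
                    "task_id": task["task_id"],
                    "worker_id": "worker-demo",
                    "result": {"status": "success", "data": {}},
                }
                # Up to 3 attempts with exponential backoff starting at 1 second.
                await client.send_result(payload, max_retries=3, initial_delay=1.0)


run(poll_once())
```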
avtomatika_worker/config.py
CHANGED
@@ -10,7 +10,7 @@ class WorkerConfig:
     Reads parameters from environment variables and provides default values.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         # --- Basic worker information ---
         self.WORKER_ID: str = getenv("WORKER_ID", f"worker-{uuid4()}")
         self.WORKER_TYPE: str = getenv("WORKER_TYPE", "generic-cpu-worker")
@@ -49,11 +49,12 @@ class WorkerConfig:
         )

         # --- S3 Settings for payload offloading ---
-        self.
+        self.TASK_FILES_DIR: str = getenv("TASK_FILES_DIR", "/tmp/payloads")
         self.S3_ENDPOINT_URL: str | None = getenv("S3_ENDPOINT_URL")
         self.S3_ACCESS_KEY: str | None = getenv("S3_ACCESS_KEY")
         self.S3_SECRET_KEY: str | None = getenv("S3_SECRET_KEY")
         self.S3_DEFAULT_BUCKET: str = getenv("S3_DEFAULT_BUCKET", "avtomatika-payloads")
+        self.S3_REGION: str = getenv("S3_REGION", "us-east-1")

         # --- Tuning parameters ---
         self.HEARTBEAT_INTERVAL: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
@@ -70,13 +71,24 @@ class WorkerConfig:
         self.ENABLE_WEBSOCKETS: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
         self.MULTI_ORCHESTRATOR_MODE: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")

+    def validate(self) -> None:
+        """Validates critical configuration parameters."""
+        if self.WORKER_TOKEN == "your-secret-worker-token":
+            print("Warning: WORKER_TOKEN is set to the default value. Tasks might fail authentication.")
+
+        if not self.ORCHESTRATORS:
+            raise ValueError("No orchestrators configured.")
+
+        for o in self.ORCHESTRATORS:
+            if not o.get("url"):
+                raise ValueError("Orchestrator configuration missing URL.")
+
     def _get_orchestrators_config(self) -> list[dict[str, Any]]:
         """
         Loads orchestrator configuration from the ORCHESTRATORS_CONFIG environment variable.
         For backward compatibility, if it is not set, it uses ORCHESTRATOR_URL.
         """
-        orchestrators_json
-        if orchestrators_json:
+        if orchestrators_json := getenv("ORCHESTRATORS_CONFIG"):
             try:
                 orchestrators = loads(orchestrators_json)
                 if getenv("ORCHESTRATOR_URL"):
@@ -94,23 +106,23 @@ class WorkerConfig:
         orchestrator_url = getenv("ORCHESTRATOR_URL", "http://localhost:8080")
         return [{"url": orchestrator_url, "priority": 1, "weight": 1}]

-
+    @staticmethod
+    def _get_gpu_info() -> dict[str, Any] | None:
         """Collects GPU information from environment variables.
         Returns None if GPU is not configured.
         """
-        gpu_model
-
+        if gpu_model := getenv("GPU_MODEL"):
+            return {
+                "model": gpu_model,
+                "vram_gb": int(getenv("GPU_VRAM_GB", "0")),
+            }
+        else:
             return None

-
-
-        "vram_gb": int(getenv("GPU_VRAM_GB", "0")),
-        }
-
-    def _load_json_from_env(self, key: str, default: Any) -> Any:
+    @staticmethod
+    def _load_json_from_env(key: str, default: Any) -> Any:
         """Safely loads a JSON string from an environment variable."""
-        value
-        if value:
+        if value := getenv(key):
             try:
                 return loads(value)
             except JSONDecodeError:
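As a quick illustration of the multi-orchestrator configuration and the new `validate()` hook, the sketch below sets `ORCHESTRATORS_CONFIG` before constructing `WorkerConfig`. The URLs, weights, and tokens are invented for the example; the `url`/`priority`/`weight`/`token` keys follow the structure read above.

```python
# Illustrative only: values are placeholders, not a recommended deployment.
from json import dumps
from os import environ

from avtomatika_worker.config import WorkerConfig

environ["ORCHESTRATORS_CONFIG"] = dumps(
    [
        {"url": "http://orchestrator-a:8080", "priority": 1, "weight": 3, "token": "token-a"},
        {"url": "http://orchestrator-b:8080", "priority": 2, "weight": 1, "token": "token-b"},
    ]
)

config = WorkerConfig()
config.validate()  # raises ValueError if no orchestrators are configured or a "url" is missing
print(config.ORCHESTRATORS)
```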
avtomatika_worker/constants.py
ADDED

@@ -0,0 +1,22 @@
+"""
+Centralized constants for the Avtomatika protocol (Worker SDK).
+These should match the constants in the core `avtomatika` package.
+"""
+
+# --- Auth Headers ---
+AUTH_HEADER_CLIENT = "X-Avtomatika-Token"
+AUTH_HEADER_WORKER = "X-Worker-Token"
+
+# --- Error Codes ---
+ERROR_CODE_TRANSIENT = "TRANSIENT_ERROR"
+ERROR_CODE_PERMANENT = "PERMANENT_ERROR"
+ERROR_CODE_INVALID_INPUT = "INVALID_INPUT_ERROR"
+
+# --- Task Statuses ---
+TASK_STATUS_SUCCESS = "success"
+TASK_STATUS_FAILURE = "failure"
+TASK_STATUS_CANCELLED = "cancelled"
+TASK_STATUS_NEEDS_REVIEW = "needs_review"  # Example of a common custom status
+
+# --- Commands (WebSocket) ---
+COMMAND_CANCEL_TASK = "cancel_task"
avtomatika_worker/py.typed
File without changes
avtomatika_worker/s3.py
CHANGED
@@ -1,62 +1,199 @@
-import
-import
-from
+from asyncio import Semaphore, gather, to_thread
+from logging import getLogger
+from os import walk
+from os.path import basename, dirname, join, relpath
+from shutil import rmtree
+from typing import Any, cast
 from urllib.parse import urlparse

-import
-from
+import obstore
+from aiofiles import open as aio_open
+from aiofiles.os import makedirs
+from aiofiles.ospath import exists, isdir
+from obstore.store import S3Store

 from .config import WorkerConfig

+logger = getLogger(__name__)
+
+# Limit concurrent S3 operations to avoid "Too many open files"
+MAX_S3_CONCURRENCY = 50
+

 class S3Manager:
-    """Handles S3 payload offloading."""
+    """Handles S3 payload offloading using obstore (high-performance async S3 client)."""

     def __init__(self, config: WorkerConfig):
         self._config = config
-        self.
-    (… removed lines truncated in the registry diff view …)
+        self._stores: dict[str, S3Store] = {}
+        self._semaphore = Semaphore(MAX_S3_CONCURRENCY)
+
+    def _get_store(self, bucket_name: str) -> S3Store:
+        """Creates or returns a cached S3Store for a specific bucket."""
+        if bucket_name in self._stores:
+            return self._stores[bucket_name]
+
+        config_kwargs = {
+            "aws_access_key_id": self._config.S3_ACCESS_KEY,
+            "aws_secret_access_key": self._config.S3_SECRET_KEY,
+            "region": "us-east-1",  # Default region if not specified, required by some clients
+        }
+
+        if self._config.S3_ENDPOINT_URL:
+            config_kwargs["endpoint"] = self._config.S3_ENDPOINT_URL
+            if self._config.S3_ENDPOINT_URL.startswith("http://"):
+                config_kwargs["allow_http"] = "true"
+
+        # Filter out None values
+        config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
+
+        try:
+            store = S3Store(bucket_name, **config_kwargs)
+            self._stores[bucket_name] = store
+            return store
+        except Exception as e:
+            logger.error(f"Failed to create S3Store for bucket {bucket_name}: {e}")
+            raise
+
+    async def cleanup(self, task_id: str) -> None:
+        """Removes the task-specific payload directory."""
+        task_dir = join(self._config.TASK_FILES_DIR, task_id)
+        if await exists(task_dir):
+            await to_thread(lambda: rmtree(task_dir, ignore_errors=True))
+
+    async def _process_s3_uri(self, uri: str, task_id: str) -> str:
+        """Downloads a file or a folder (if uri ends with /) from S3 and returns the local path."""
+        try:
+            parsed_url = urlparse(uri)
+            bucket_name = parsed_url.netloc
+            object_key = parsed_url.path.lstrip("/")
+            store = self._get_store(bucket_name)
+
+            # Use task-specific directory for isolation
+            local_dir_root = join(self._config.TASK_FILES_DIR, task_id)
+            await makedirs(local_dir_root, exist_ok=True)
+
+            logger.info(f"Starting download from S3: {uri}")
+
+            # Handle folder download (prefix)
+            if uri.endswith("/"):
+                folder_name = object_key.rstrip("/").split("/")[-1]
+                local_folder_path = join(local_dir_root, folder_name)
+
+                # List objects with prefix
+                # obstore.list returns an async iterator of ObjectMeta
+                files_to_download = []
+
+                # Note: obstore.list returns an async iterator.
+                async for obj in obstore.list(store, prefix=object_key):
+                    key = obj.key
+
+                    if key.endswith("/"):
+                        continue
+
+                    # Calculate relative path inside the folder
+                    rel_path = key[len(object_key) :]
+                    local_file_path = join(local_folder_path, rel_path)
+
+                    await makedirs(dirname(local_file_path), exist_ok=True)
+                    files_to_download.append((key, local_file_path))
+
+                async def _download_file(key: str, path: str) -> None:
+                    async with self._semaphore:
+                        result = await obstore.get(store, key)
+                        async with aio_open(path, "wb") as f:
+                            async for chunk in result.stream():
+                                await f.write(chunk)
+
+                # Execute downloads in parallel
+                if files_to_download:
+                    await gather(*[_download_file(k, p) for k, p in files_to_download])
+
+                logger.info(f"Successfully downloaded folder from S3: {uri} ({len(files_to_download)} files)")
+                return local_folder_path
+
+            # Handle single file download
+            local_path = join(local_dir_root, basename(object_key))
+
+            result = await obstore.get(store, object_key)
+            async with aio_open(local_path, "wb") as f:
+                async for chunk in result.stream():
+                    await f.write(chunk)
+
+            logger.info(f"Successfully downloaded file from S3: {uri} -> {local_path}")
+            return local_path
+
+        except Exception as e:
+            # Catching generic Exception because obstore might raise different errors.
+            logger.exception(f"Error during download of {uri}: {e}")
+            raise

     async def _upload_to_s3(self, local_path: str) -> str:
-        """Uploads a file to S3 and returns the S3 URI."""
+        """Uploads a file or a folder to S3 and returns the S3 URI."""
         bucket_name = self._config.S3_DEFAULT_BUCKET
-
+        store = self._get_store(bucket_name)
+
+        logger.info(f"Starting upload to S3 from local path: {local_path}")
+
+        try:
+            # Handle folder upload
+            if await isdir(local_path):
+                folder_name = basename(local_path.rstrip("/"))
+                s3_prefix = f"{folder_name}/"
+
+                # Use to_thread to avoid blocking event loop during file walk
+                def _get_files_to_upload():
+                    files_to_upload = []
+                    for root, _, files in walk(local_path):
+                        for file in files:
+                            f_path = join(root, file)
+                            rel = relpath(f_path, local_path)
+                            files_to_upload.append((f_path, f"{s3_prefix}{rel}"))
+                    return files_to_upload
+
+                files_list = await to_thread(_get_files_to_upload)
+
+                async def _upload_file(path: str, key: str) -> None:
+                    async with self._semaphore:
+                        # obstore.put accepts bytes or file-like objects.
+                        # Since we are in async, reading small files is fine.
+                        with open(path, "rb") as f:
+                            await obstore.put(store, key, f)
+
+                if files_list:
+                    # Upload in parallel
+                    await gather(*[_upload_file(f, k) for f, k in files_list])
+
+                s3_uri = f"s3://{bucket_name}/{s3_prefix}"
+                logger.info(f"Successfully uploaded folder to S3: {local_path} -> {s3_uri} ({len(files_list)} files)")
+                return s3_uri
+
+            # Handle single file upload
+            object_key = basename(local_path)
+            with open(local_path, "rb") as f:
+                await obstore.put(store, object_key, f)
+
+            s3_uri = f"s3://{bucket_name}/{object_key}"
+            logger.info(f"Successfully uploaded file to S3: {local_path} -> {s3_uri}")
+            return s3_uri

-
-
+        except Exception as e:
+            logger.exception(f"Error during upload of {local_path}: {e}")
+            raise

-    async def process_params(self, params: dict[str, Any]) -> dict[str, Any]:
+    async def process_params(self, params: dict[str, Any], task_id: str) -> dict[str, Any]:
         """Recursively searches for S3 URIs in params and downloads the files."""
         if not self._config.S3_ENDPOINT_URL:
             return params

         async def _process(item: Any) -> Any:
             if isinstance(item, str) and item.startswith("s3://"):
-                return await self._process_s3_uri(item)
+                return await self._process_s3_uri(item, task_id)
             if isinstance(item, dict):
                 return {k: await _process(v) for k, v in item.items()}
-            if isinstance(item, list)
-                return [await _process(i) for i in item]
-            return item
+            return [await _process(i) for i in item] if isinstance(item, list) else item

-        return await _process(params)
+        return cast(dict[str, Any], await _process(params))

     async def process_result(self, result: dict[str, Any]) -> dict[str, Any]:
         """Recursively searches for local file paths in the result and uploads them to S3."""
@@ -64,12 +201,10 @@ class S3Manager:
             return result

         async def _process(item: Any) -> Any:
-            if isinstance(item, str) and
-                return await self._upload_to_s3(item)
+            if isinstance(item, str) and item.startswith(self._config.TASK_FILES_DIR):
+                return await self._upload_to_s3(item) if await exists(item) else item
             if isinstance(item, dict):
                 return {k: await _process(v) for k, v in item.items()}
-            if isinstance(item, list)
-                return [await _process(i) for i in item]
-            return item
+            return [await _process(i) for i in item] if isinstance(item, list) else item

-        return await _process(result)
+        return cast(dict[str, Any], await _process(result))
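To make the offloading flow concrete, the sketch below shows roughly what `process_params` / `process_result` do to a task's data. The bucket names and paths are invented, and `S3Manager` is normally driven by the `Worker`, not called directly.

```python
# Sketch only: requires S3_ENDPOINT_URL and credentials in the environment,
# otherwise process_params returns the params untouched.
from avtomatika_worker.config import WorkerConfig
from avtomatika_worker.s3 import S3Manager


async def demo(task_id: str) -> None:
    s3 = S3Manager(WorkerConfig())

    params = {"input_image": "s3://my-bucket/photo.jpg", "dataset": "s3://my-bucket/frames/"}
    # Each s3:// URI becomes a local path under TASK_FILES_DIR/<task_id>;
    # a trailing "/" means "download the whole prefix into a folder".
    local_params = await s3.process_params(params, task_id)

    result = {"status": "success", "data": {"report": f"/tmp/payloads/{task_id}/report.pdf"}}
    # Local paths under TASK_FILES_DIR are uploaded to S3_DEFAULT_BUCKET and
    # replaced by s3:// URIs; everything else passes through unchanged.
    final_result = await s3.process_result(result)

    # The per-task workspace is removed once the result has been sent.
    await s3.cleanup(task_id)
```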
avtomatika_worker/task_files.py
ADDED

@@ -0,0 +1,97 @@
+from contextlib import asynccontextmanager
+from os.path import dirname, join
+from typing import AsyncGenerator
+
+from aiofiles import open as aiopen
+from aiofiles.os import listdir, makedirs
+from aiofiles.ospath import exists as aio_exists
+
+
+class TaskFiles:
+    """
+    A helper class for managing task-specific files.
+    Provides asynchronous lazy directory creation and high-level file operations
+    within an isolated workspace for each task.
+    """
+
+    def __init__(self, task_dir: str):
+        """
+        Initializes TaskFiles with a specific task directory.
+        The directory is not created until needed.
+        """
+        self._task_dir = task_dir
+
+    async def get_root(self) -> str:
+        """
+        Asynchronously returns the root directory for the task.
+        Creates the directory on disk if it doesn't exist.
+        """
+        await makedirs(self._task_dir, exist_ok=True)
+        return self._task_dir
+
+    async def path_to(self, filename: str) -> str:
+        """
+        Asynchronously returns an absolute path for a file within the task directory.
+        Guarantees that the task root directory exists.
+        """
+        root = await self.get_root()
+        return join(root, filename)
+
+    @asynccontextmanager
+    async def open(self, filename: str, mode: str = "r") -> AsyncGenerator:
+        """
+        An asynchronous context manager to open a file within the task directory.
+        Automatically creates the task root and any necessary subdirectories.
+
+        Args:
+            filename: Name or relative path of the file.
+            mode: File opening mode (e.g., 'r', 'w', 'a', 'rb', 'wb').
+        """
+        path = await self.path_to(filename)
+        # Ensure directory for the file itself exists if filename contains subdirectories
+        file_dir = dirname(path)
+        if file_dir != self._task_dir:
+            await makedirs(file_dir, exist_ok=True)
+
+        async with aiopen(path, mode) as f:
+            yield f
+
+    async def read(self, filename: str, mode: str = "r") -> str | bytes:
+        """
+        Asynchronously reads the entire content of a file.
+
+        Args:
+            filename: Name of the file to read.
+            mode: Mode to open the file in (defaults to 'r').
+        """
+        async with self.open(filename, mode) as f:
+            return await f.read()
+
+    async def write(self, filename: str, data: str | bytes, mode: str = "w") -> None:
+        """
+        Asynchronously writes data to a file. Creates or overwrites the file by default.
+
+        Args:
+            filename: Name of the file to write.
+            data: Content to write (string or bytes).
+            mode: Mode to open the file in (defaults to 'w').
+        """
+        async with self.open(filename, mode) as f:
+            await f.write(data)
+
+    async def list(self) -> list[str]:
+        """
+        Asynchronously lists all file and directory names within the task root.
+        """
+        root = await self.get_root()
+        return await listdir(root)
+
+    async def exists(self, filename: str) -> bool:
+        """
+        Asynchronously checks if a specific file or directory exists in the task root.
+        """
+        path = join(self._task_dir, filename)
+        return await aio_exists(path)
+
+    def __repr__(self):
+        return f"<TaskFiles root='{self._task_dir}'>"
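Because `open()` creates any intermediate directories, nested relative paths work directly. A small sketch (the task directory and filenames are invented; inside a handler the `TaskFiles` instance is injected for you):

```python
from avtomatika_worker.task_files import TaskFiles


async def demo() -> None:
    files = TaskFiles("/tmp/payloads/demo-task")  # directory is created lazily

    # Subdirectories in the relative path are created automatically by open().
    async with files.open("frames/0001.png", "wb") as f:
        await f.write(b"\x89PNG...")

    # Convenience wrappers for whole-file access.
    await files.write("meta.json", '{"frames": 1}')
    print(await files.list())              # e.g. ['frames', 'meta.json']
    print(await files.exists("meta.json"))  # True
```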
avtomatika_worker/types.py
CHANGED
@@ -1,8 +1,21 @@
-    (… removed lines truncated in the registry diff view …)
+from .constants import (
+    ERROR_CODE_INVALID_INPUT as INVALID_INPUT_ERROR,
+)
+from .constants import (
+    ERROR_CODE_PERMANENT as PERMANENT_ERROR,
+)
+from .constants import (
+    ERROR_CODE_TRANSIENT as TRANSIENT_ERROR,
+)


 class ParamValidationError(Exception):
     """Custom exception for parameter validation errors."""
+
+
+__all__ = [
+    "INVALID_INPUT_ERROR",
+    "PERMANENT_ERROR",
+    "TRANSIENT_ERROR",
+    "ParamValidationError",
+]
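The re-exported error codes plug into the failure-result format that `worker.py` builds below. A hedged example of a handler using them explicitly (the `fetch` helper is hypothetical):

```python
from avtomatika_worker.types import TRANSIENT_ERROR, ParamValidationError


async def flaky_handler(params: dict, **kwargs) -> dict:
    if "url" not in params:
        # Raising ParamValidationError makes the SDK report an INVALID_INPUT_ERROR result.
        raise ParamValidationError("params must contain 'url'")
    try:
        data = await fetch(params["url"])  # hypothetical helper, not part of the SDK
    except TimeoutError as e:
        # Or build the failure result yourself with the exported code constants.
        return {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
    return {"status": "success", "data": {"size": len(data)}}
```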
avtomatika_worker/worker.py
CHANGED
@@ -1,16 +1,26 @@
 from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
-from asyncio import TimeoutError as AsyncTimeoutError
 from dataclasses import is_dataclass
 from inspect import Parameter, signature
 from json import JSONDecodeError
 from logging import getLogger
+from os.path import join
 from typing import Any, Callable

-from aiohttp import
+from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType, web

+from .client import OrchestratorClient
 from .config import WorkerConfig
+from .constants import (
+    COMMAND_CANCEL_TASK,
+    ERROR_CODE_INVALID_INPUT,
+    ERROR_CODE_PERMANENT,
+    ERROR_CODE_TRANSIENT,
+    TASK_STATUS_CANCELLED,
+    TASK_STATUS_FAILURE,
+)
 from .s3 import S3Manager
-from .
+from .task_files import TaskFiles
+from .types import ParamValidationError

 try:
     from pydantic import BaseModel, ValidationError
@@ -43,7 +53,7 @@ class Worker:
         self._s3_manager = S3Manager(self._config)
         self._config.WORKER_TYPE = worker_type  # Allow overriding worker_type
         if max_concurrent_tasks is not None:
-            self._config.
+            self._config.MAX_CONCURRENT_TASKS = max_concurrent_tasks

         self._task_type_limits = task_type_limits or {}
         self._task_handlers: dict[str, dict[str, Any]] = {}
@@ -57,10 +67,8 @@ class Worker:
         self._http_session = http_session
         self._session_is_managed_externally = http_session is not None
         self._ws_connection: ClientWebSocketResponse | None = None
-        # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
         self._shutdown_event = Event()
         self._registered_event = Event()
-        self._round_robin_index = 0
         self._debounce_task: Task | None = None

         # --- Weighted Round-Robin State ---
@@ -70,7 +78,28 @@ class Worker:
             o["current_weight"] = 0
             self._total_orchestrator_weight += o.get("weight", 1)

-
+        self._clients: list[tuple[dict[str, Any], OrchestratorClient]] = []
+        if self._http_session:
+            self._init_clients()
+
+    def _init_clients(self):
+        """Initializes OrchestratorClient instances for each configured orchestrator."""
+        if not self._http_session:
+            return
+        self._clients = [
+            (
+                o,
+                OrchestratorClient(
+                    session=self._http_session,
+                    base_url=o["url"],
+                    worker_id=self._config.WORKER_ID,
+                    token=o.get("token", self._config.WORKER_TOKEN),
+                ),
+            )
+            for o in self._config.ORCHESTRATORS
+        ]
+
     def _validate_task_types(self):
         """Checks for unused task type limits and warns the user."""
         registered_task_types = {
             handler_data["type"] for handler_data in self._task_handlers.values() if handler_data["type"]
@@ -138,32 +167,31 @@
         status = "idle" if supported_tasks else "busy"
         return {"status": status, "supported_tasks": supported_tasks}

-    def
-        """Builds authentication headers for a specific orchestrator."""
-        token = orchestrator.get("token", self._config.WORKER_TOKEN)
-        return {"X-Worker-Token": token}
-
-    def _get_next_orchestrator(self) -> dict[str, Any] | None:
+    def _get_next_client(self) -> OrchestratorClient | None:
         """
-        Selects the next orchestrator using a smooth weighted round-robin algorithm.
+        Selects the next orchestrator client using a smooth weighted round-robin algorithm.
         """
-        if not self.
+        if not self._clients:
             return None

         # The orchestrator with the highest current_weight is selected.
-
+        selected_client = None
         highest_weight = -1

-        for o in self.
+        for o, client in self._clients:
             o["current_weight"] += o["weight"]
             if o["current_weight"] > highest_weight:
                 highest_weight = o["current_weight"]
-
+                selected_client = client

-        if
-
+        if selected_client:
+            # Find the config for the selected client to decrement its weight
+            for o, client in self._clients:
+                if client == selected_client:
+                    o["current_weight"] -= self._total_orchestrator_weight
+                    break

-        return
+        return selected_client

     async def _debounced_heartbeat_sender(self):
         """Waits for the debounce delay then sends a heartbeat."""
@@ -178,34 +206,27 @@
         # Schedule the new debounced call.
         self._debounce_task = create_task(self._debounced_heartbeat_sender())

-    async def _poll_for_tasks(self,
+    async def _poll_for_tasks(self, client: OrchestratorClient):
         """Polls a specific Orchestrator for new tasks."""
-    (… removed lines truncated in the registry diff view …)
-                        task = create_task(self._process_task(task_data))
-                        self._active_tasks[task_data["task_id"]] = task
-                    elif resp.status != 204:
-                        await sleep(self._config.TASK_POLL_ERROR_DELAY)
-            except (AsyncTimeoutError, ClientError) as e:
-                logger.error(f"Error polling for tasks: {e}")
-                await sleep(self._config.TASK_POLL_ERROR_DELAY)
+        task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+        if task_data:
+            task_data["client"] = client
+
+            self._current_load += 1
+            if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
+                task_type_for_limit := task_handler_info.get("type")
+            ):
+                self._current_load_by_type[task_type_for_limit] += 1
+            self._schedule_heartbeat_debounce()
+
+            task = create_task(self._process_task(task_data))
+            self._active_tasks[task_data["task_id"]] = task
+        else:
+            # If no task but it was a 204 or error, the client already handled/logged it.
+            # We might want a short sleep here if it was an error, but client.poll_task
+            # doesn't distinguish between 204 and error currently.
+            # However, the previous logic only slept on status != 204.
+            pass

     async def _start_polling(self):
         """The main loop for polling tasks."""
@@ -217,19 +238,19 @@
             continue

             if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
-
-
-                    await self._poll_for_tasks(orchestrator)
+                if client := self._get_next_client():
+                    await self._poll_for_tasks(client)
             else:
-                for
+                for _, client in self._clients:
                     if self._get_current_state()["status"] == "busy":
                         break
-                    await self._poll_for_tasks(
+                    await self._poll_for_tasks(client)

             if self._current_load == 0:
                 await sleep(self._config.IDLE_POLL_DELAY)

-
+    @staticmethod
+    def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
         """
         Inspects the handler's signature to validate and instantiate params.
         Supports dict, dataclasses, and optional pydantic models.
@@ -261,8 +282,7 @@
             if f.default is Parameter.empty and f.default_factory is Parameter.empty
         ]

-        missing_fields
-        if missing_fields:
+        if missing_fields := [f for f in required_fields if f not in filtered_params]:
             raise ParamValidationError(f"Missing required fields for dataclass: {', '.join(missing_fields)}")

         return params_annotation(**filtered_params)
@@ -272,10 +292,24 @@

         return params

+    def _prepare_dependencies(self, handler: Callable, task_id: str) -> dict[str, Any]:
+        """Injects dependencies based on type hints."""
+        deps = {}
+        task_dir = join(self._config.TASK_FILES_DIR, task_id)
+        # Always create the object, but directory is lazy
+        task_files = TaskFiles(task_dir)
+
+        sig = signature(handler)
+        for name, param in sig.parameters.items():
+            if param.annotation is TaskFiles:
+                deps[name] = task_files
+
+        return deps
+
     async def _process_task(self, task_data: dict[str, Any]):
         """Executes the task logic."""
         task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-        params,
+        params, client = task_data.get("params", {}), task_data["client"]

         result: dict[str, Any] = {}
         handler_data = self._task_handlers.get(task_name)
@@ -287,14 +321,17 @@
             if not handler_data:
                 message = f"Unsupported task: {task_name}"
                 logger.warning(message)
-                result = {"status":
+                result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_PERMANENT, "message": message}}
                 payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                await
+                await client.send_result(
+                    payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+                )
                 result_sent = True  # Mark result as sent
                 return

-            params = await self._s3_manager.process_params(params)
+            params = await self._s3_manager.process_params(params, task_id)
             validated_params = self._prepare_task_params(handler_data["func"], params)
+            deps = self._prepare_dependencies(handler_data["func"], task_id)

             result = await handler_data["func"](
                 validated_params,
@@ -304,23 +341,29 @@
                 send_progress=self.send_progress,
                 add_to_hot_cache=self.add_to_hot_cache,
                 remove_from_hot_cache=self.remove_from_hot_cache,
+                **deps,
             )
             result = await self._s3_manager.process_result(result)
         except ParamValidationError as e:
             logger.error(f"Task {task_id} failed validation: {e}")
-            result = {"status":
+            result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_INVALID_INPUT, "message": str(e)}}
         except CancelledError:
             logger.info(f"Task {task_id} was cancelled.")
-            result = {"status":
+            result = {"status": TASK_STATUS_CANCELLED}
             # We must re-raise the exception to be handled by the outer gather
             raise
         except Exception as e:
             logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-            result = {"status":
+            result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_TRANSIENT, "message": str(e)}}
         finally:
+            # Cleanup task workspace
+            await self._s3_manager.cleanup(task_id)
+
             if not result_sent:  # Only send if not already sent
                 payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                await
+                await client.send_result(
+                    payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+                )
             self._active_tasks.pop(task_id, None)

             self._current_load -= 1
@@ -328,21 +371,6 @@
                 self._current_load_by_type[task_type_for_limit] -= 1
             self._schedule_heartbeat_debounce()

-    async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
-        """Sends the result to a specific orchestrator."""
-        url = f"{orchestrator['url']}/_worker/tasks/result"
-        delay = self._config.RESULT_RETRY_INITIAL_DELAY
-        headers = self._get_headers(orchestrator)
-        for i in range(self._config.RESULT_MAX_RETRIES):
-            try:
-                if self._http_session and not self._http_session.closed:
-                    async with self._http_session.post(url, json=payload, headers=headers) as resp:
-                        if resp.status == 200:
-                            return
-            except ClientError as e:
-                logger.error(f"Error sending result: {e}")
-            await sleep(delay * (2**i))
-
     async def _manage_orchestrator_communications(self):
         """Registers the worker and sends heartbeats."""
         await self._register_with_all_orchestrators()
@@ -369,17 +397,7 @@
             "ip_address": self._config.IP_ADDRESS,
             "resources": self._config.RESOURCES,
         }
-        for
-            url = f"{orchestrator['url']}/_worker/workers/register"
-            try:
-                if self._http_session:
-                    async with self._http_session.post(
-                        url, json=payload, headers=self._get_headers(orchestrator)
-                    ) as resp:
-                        if resp.status >= 400:
-                            logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
-            except ClientError as e:
-                logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
+        await gather(*[client.register(payload) for _, client in self._clients])

     async def _send_heartbeats_to_all(self):
         """Sends heartbeat messages to all orchestrators."""
@@ -399,24 +417,15 @@
         if hot_skills:
             payload["hot_skills"] = hot_skills

-
-            url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
-            headers = self._get_headers(orchestrator)
-            try:
-                if self._http_session and not self._http_session.closed:
-                    async with self._http_session.patch(url, json=payload, headers=headers) as resp:
-                        if resp.status >= 400:
-                            logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
-            except ClientError as e:
-                logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
-
-        await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
+        await gather(*[client.send_heartbeat(payload) for _, client in self._clients])

     async def main(self):
         """The main asynchronous function."""
-        self.
+        self._config.validate()
+        self._validate_task_types()  # Validate config now that all tasks are registered
         if not self._http_session:
             self._http_session = ClientSession()
+            self._init_clients()

         comm_task = create_task(self._manage_orchestrator_communications())
@@ -442,7 +451,11 @@

     async def _run_health_check_server(self):
         app = web.Application()
-
+
+        async def health_handler(_):
+            return web.Response(text="OK")
+
+        app.router.add_get("/health", health_handler)
         runner = web.AppRunner(app)
         await runner.setup()
         site = web.TCPSite(runner, "0.0.0.0", self._config.WORKER_PORT)
@@ -459,25 +472,20 @@
         except KeyboardInterrupt:
             self._shutdown_event.set()

-    # WebSocket methods omitted for brevity as they are not relevant to the changes
     async def _start_websocket_manager(self):
         """Manages the WebSocket connection to the orchestrator."""
         while not self._shutdown_event.is_set():
-
-
+            # In multi-orchestrator mode, we currently only connect to the first one available
+            for _, client in self._clients:
                 try:
-
-
-
-
-                    await self._listen_for_commands()
-                except (ClientError, AsyncTimeoutError) as e:
-                    logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
+                    ws = await client.connect_websocket()
+                    if ws:
+                        self._ws_connection = ws
+                        await self._listen_for_commands()
                 finally:
                     self._ws_connection = None
-                    logger.info(f"WebSocket connection to {ws_url} closed.")
                     await sleep(5)  # Reconnection delay
-            if not self.
+            if not self._clients:
                 await sleep(5)

     async def _listen_for_commands(self):
@@ -490,7 +498,7 @@
                 if msg.type == WSMsgType.TEXT:
                     try:
                         command = msg.json()
-                        if command.get("type") ==
+                        if command.get("type") == COMMAND_CANCEL_TASK:
                             task_id = command.get("task_id")
                             if task_id in self._active_tasks:
                                 self._active_tasks[task_id].cancel()
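The selection logic in `_get_next_client` above is the classic smooth weighted round-robin. A standalone sketch of the same idea (orchestrator names and weights invented for illustration):

```python
# Standalone illustration of smooth weighted round-robin, mirroring _get_next_client.
def pick(orchestrators: list[dict]) -> dict:
    total = sum(o["weight"] for o in orchestrators)
    best = None
    for o in orchestrators:
        # Every candidate gains its weight each round ...
        o["current_weight"] = o.get("current_weight", 0) + o["weight"]
        if best is None or o["current_weight"] > best["current_weight"]:
            best = o
    # ... and the winner pays back the total, so heavier nodes win more often
    # without long bursts on a single node.
    best["current_weight"] -= total
    return best


nodes = [{"url": "a", "weight": 3}, {"url": "b", "weight": 1}]
print([pick(nodes)["url"] for _ in range(8)])
# -> ['a', 'a', 'b', 'a', 'a', 'a', 'b', 'a']
```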
avtomatika_worker-1.0b1.dist-info/METADATA → avtomatika_worker-1.0b3.dist-info/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: avtomatika-worker
-Version: 1.
+Version: 1.0b3
 Summary: Worker SDK for the Avtomatika orchestrator.
-Project-URL: Homepage, https://github.com/
-Project-URL: Bug Tracker, https://github.com/
+Project-URL: Homepage, https://github.com/avtomatika-ai/avtomatika-worker
+Project-URL: Bug Tracker, https://github.com/avtomatika-ai/avtomatika-worker/issues
 Classifier: Development Status :: 4 - Beta
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
@@ -13,22 +13,22 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: aiohttp~=3.13.2
 Requires-Dist: python-json-logger~=4.0.0
-Requires-Dist:
+Requires-Dist: obstore>=0.1
+Requires-Dist: aiofiles~=25.1.0
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-asyncio; extra == "test"
 Requires-Dist: aioresponses; extra == "test"
 Requires-Dist: pytest-mock; extra == "test"
 Requires-Dist: pydantic; extra == "test"
-Requires-Dist:
-Requires-Dist: aiofiles; extra == "test"
+Requires-Dist: types-aiofiles; extra == "test"
 Provides-Extra: pydantic
 Requires-Dist: pydantic; extra == "pydantic"
 Dynamic: license-file

 # Avtomatika Worker SDK

-This is
+This is the official SDK for creating workers compatible with the **[Avtomatika Orchestrator](https://github.com/avtomatika-ai/avtomatika)**. It implements the **[RCA Protocol](https://github.com/avtomatika-ai/rca)**, handling all communication complexity (polling, heartbeats, S3 offloading) so you can focus on writing your business logic.

 ## Installation

@@ -434,18 +434,92 @@ The `ORCHESTRATORS_CONFIG` variable must contain a JSON string. Each object in t



-### 5. Handling Large Files (S3 Payload Offloading)
-
-The SDK supports working with large files "out of the box" via S3-compatible storage.
-
-- **Automatic Upload**: If your handler returns a local file path in `data` (located within the `WORKER_PAYLOAD_DIR` directory), the SDK will automatically upload this file to S3 and replace the path with an `s3://` URI in the final result.
-
+### 5. File System Helper (TaskFiles)
+
+To simplify working with temporary files and paths, the SDK provides a `TaskFiles` helper class. It automatically manages directory creation within the isolated task folder and provides an asynchronous interface for file operations. Just add an argument typed as `TaskFiles` to your handler:
+
+```python
+from avtomatika_worker import Worker, TaskFiles
+
+@worker.task("generate_report")
+async def generate_report(params: dict, files: TaskFiles, **kwargs):
+    # 1. Easy read/write
+    await files.write("data.json", '{"status": "ok"}')
+    content = await files.read("data.json")
+
+    # 2. Get path (directory is created automatically)
+    output_path = await files.path_to("report.pdf")
+
+    # 3. Check and list files
+    if await files.exists("input.jpg"):
+        file_list = await files.list()
+
+    return {"data": {"report": output_path}}
+```
+
+**Available Methods (all asynchronous):**
+- `await path_to(name)` — returns the full path to a file (ensures the task directory exists).
+- `await read(name, mode='r')` — reads the entire file.
+- `await write(name, data, mode='w')` — writes data to a file.
+- `await list()` — lists filenames in the task directory.
+- `await exists(name)` — checks if a file exists.
+- `async with open(name, mode)` — async context manager for advanced usage.
+
+> **Note: Automatic Cleanup**
+>
+> The SDK automatically deletes the entire task directory (including everything created via `TaskFiles`) immediately after the task completes and the result is sent.
+
+### 6. Handling Large Files (S3 Payload Offloading)
+
+The SDK supports working with large files "out of the box" via S3-compatible storage, using the high-performance **`obstore`** library (Rust-based).
+
+- **Automatic Download**: If a value in `params` is a URI of the form `s3://...`, the SDK will automatically download the file to the local disk and replace the URI in `params` with the local path. **If the URI ends with `/` (e.g., `s3://bucket/data/`), the SDK treats it as a folder prefix and recursively downloads all matching objects into a local directory.**
+- **Automatic Upload**: If your handler returns a local file path in `data` (located within the `TASK_FILES_DIR` directory), the SDK will automatically upload this file to S3 and replace the path with an `s3://` URI in the final result. **If the path is a directory, the SDK recursively uploads all files within it.**
+
+This functionality is transparent to your code.
+
+#### S3 Example
+
+Suppose the orchestrator sends a task with `{"input_image": "s3://my-bucket/photo.jpg"}`:
+
+```python
+import os
+from avtomatika_worker import Worker, TaskFiles
+
+worker = Worker(worker_type="image-worker")
+
+@worker.task("process_image")
+async def handle_image(params: dict, files: TaskFiles, **kwargs):
+    # SDK has already downloaded the file.
+    # 'input_image' now contains a local path like '/tmp/payloads/task-id/photo.jpg'
+    local_input = params["input_image"]
+    local_output = await files.path_to("processed.png")
+
+    # Your logic here (using local files)
+    # ... image processing ...
+
+    # Return the local path of the result.
+    # The SDK will upload it back to S3 automatically.
+    return {
+        "status": "success",
+        "data": {
+            "output_image": local_output
+        }
+    }
+```
+
+This only requires configuring environment variables for S3 access (see Full Configuration Reference).
+
+> **Important: S3 Consistency**
+>
+> The SDK **does not validate** that the Worker and Orchestrator share the same storage backend. You must ensure that:
+> 1. The Worker can reach the `S3_ENDPOINT_URL` used by the Orchestrator.
+> 2. The Worker's credentials allow reading from the buckets referenced in the incoming `s3://` URIs.
+> 3. The Worker's credentials allow writing to the `S3_DEFAULT_BUCKET`.

-
+### 7. WebSocket Support

 ## Advanced Features
@@ -522,11 +596,12 @@ The worker is fully configured via environment variables.
 | `TASK_POLL_TIMEOUT` | The timeout in seconds for polling for new tasks. | `30` |
 | `TASK_POLL_ERROR_DELAY` | The delay in seconds before retrying after a polling error. | `5.0` |
 | `IDLE_POLL_DELAY` | The delay in seconds between polls when the worker is idle. | `0.01` |
-| `
+| `TASK_FILES_DIR` | The directory for temporarily storing files when working with S3. | `/tmp/payloads` |
 | `S3_ENDPOINT_URL` | The URL of the S3-compatible storage. | - |
 | `S3_ACCESS_KEY` | The access key for S3. | - |
 | `S3_SECRET_KEY` | The secret key for S3. | - |
 | `S3_DEFAULT_BUCKET` | The default bucket name for uploading results. | `avtomatika-payloads` |
+| `S3_REGION` | The region for S3 storage (required by some providers). | `us-east-1` |

 ## Development
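For completeness, a sketch of the S3-related environment using the variable names from the table above; the endpoint, keys, and bucket are placeholders for a local S3-compatible setup.

```python
# Example environment for S3 payload offloading (placeholder values only).
from os import environ

environ.update(
    {
        "S3_ENDPOINT_URL": "http://minio.local:9000",  # http:// endpoints are allowed for local setups
        "S3_ACCESS_KEY": "minio-access-key",
        "S3_SECRET_KEY": "minio-secret-key",
        "S3_DEFAULT_BUCKET": "avtomatika-payloads",
        "S3_REGION": "us-east-1",
        "TASK_FILES_DIR": "/tmp/payloads",
    }
)
```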
avtomatika_worker-1.0b3.dist-info/RECORD
ADDED

@@ -0,0 +1,14 @@
+avtomatika_worker/__init__.py,sha256=y_s5KlsgFu7guemZfjLVQ3Jzq7DyLG168-maVGwWRC4,334
+avtomatika_worker/client.py,sha256=mkvwrMY8tAaZN_lwMSxWHmAoWsDemD-WiKSeH5fM6GI,4173
+avtomatika_worker/config.py,sha256=NaAhufpwyG6CsHW-cXmqR3MfGp_5SdDZ_vEhmmV8G3g,5819
+avtomatika_worker/constants.py,sha256=DfGR_YkW9rbioCorKpNGfZ0i_0iGgMq2swyJhVl9nNA,669
+avtomatika_worker/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+avtomatika_worker/s3.py,sha256=sAuXBp__XTVhyNwK9gsxy1Jm_udJM6ypssGTZ00pa6U,8847
+avtomatika_worker/task_files.py,sha256=ucjBuI78UmtMvfucTzDTNJ1g0KJaRIwyshRNTipIZSU,3351
+avtomatika_worker/types.py,sha256=dSNsHgqV6hZhOt4eUK2PDWB6lrrwCA5_T_iIBI_wTZ0,442
+avtomatika_worker/worker.py,sha256=XSRfLO-W0J6WG128Iu-rL_w3-PqmsWQMUElVLi3Z1gk,21904
+avtomatika_worker-1.0b3.dist-info/licenses/LICENSE,sha256=tqCjw9Y1vbU-hLcWi__7wQstLbt2T1XWPdbQYqCxuWY,1072
+avtomatika_worker-1.0b3.dist-info/METADATA,sha256=yiEtJuMv5WHYHfScna7cF5QAvAUhMCJdUDENHvrMRFY,29601
+avtomatika_worker-1.0b3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+avtomatika_worker-1.0b3.dist-info/top_level.txt,sha256=d3b5BUeUrHM1Cn-cbStz-hpucikEBlPOvtcmQ_j3qAs,18
+avtomatika_worker-1.0b3.dist-info/RECORD,,
@@ -1,10 +0,0 @@
|
|
|
1
|
-
avtomatika_worker/__init__.py,sha256=j0up34aVy7xyI67xg04TVbXSSSKGdO49vsBKhtH_D0M,287
|
|
2
|
-
avtomatika_worker/config.py,sha256=k1p1Njh7CVWU1PaJDqA6jsf9GXoH0iM4og_o0V0gPHI,5260
|
|
3
|
-
avtomatika_worker/s3.py,sha256=7aC7k90kLUEwBVLWvvcLapaF6gIph_7b3XGXYdtCiNU,2895
|
|
4
|
-
avtomatika_worker/types.py,sha256=MqXaX0NUatYDna3GgBWj73-WOT1EfaX1ei4i7eUsZR0,255
|
|
5
|
-
avtomatika_worker/worker.py,sha256=afePnSQp_aFAb6qDk6HAuI81PlwIZiK52fI6QY6aQQE,23234
|
|
6
|
-
avtomatika_worker-1.0b1.dist-info/licenses/LICENSE,sha256=tqCjw9Y1vbU-hLcWi__7wQstLbt2T1XWPdbQYqCxuWY,1072
|
|
7
|
-
avtomatika_worker-1.0b1.dist-info/METADATA,sha256=IMgLX7vO1jq_iQtm4TSU-XEf3hUxoViMfMzQNHfbjgg,26326
|
|
8
|
-
avtomatika_worker-1.0b1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
-
avtomatika_worker-1.0b1.dist-info/top_level.txt,sha256=d3b5BUeUrHM1Cn-cbStz-hpucikEBlPOvtcmQ_j3qAs,18
|
|
10
|
-
avtomatika_worker-1.0b1.dist-info/RECORD,,
|
|
{avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/WHEEL
File without changes

{avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/licenses/LICENSE
File without changes

{avtomatika_worker-1.0b1.dist-info → avtomatika_worker-1.0b3.dist-info}/top_level.txt
File without changes