avtomatika-worker 1.0b2-py3-none-any.whl → 1.0b3-py3-none-any.whl

This diff shows the content changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only.
avtomatika_worker/client.py ADDED
@@ -0,0 +1,93 @@
+ from asyncio import sleep
+ from logging import getLogger
+ from typing import Any
+
+ from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse
+
+ from .constants import AUTH_HEADER_WORKER
+
+ logger = getLogger(__name__)
+
+
+ class OrchestratorClient:
+     """
+     Dedicated client for communicating with a single Avtomatika Orchestrator instance.
+     Handles HTTP requests, retries, and authentication.
+     """
+
+     def __init__(self, session: ClientSession, base_url: str, worker_id: str, token: str):
+         self.session = session
+         self.base_url = base_url.rstrip("/")
+         self.worker_id = worker_id
+         self.token = token
+         self._headers = {AUTH_HEADER_WORKER: self.token}
+
+     async def register(self, payload: dict[str, Any]) -> bool:
+         """Registers the worker with the orchestrator."""
+         url = f"{self.base_url}/_worker/workers/register"
+         try:
+             async with self.session.post(url, json=payload, headers=self._headers) as resp:
+                 if resp.status >= 400:
+                     logger.error(f"Error registering with {self.base_url}: {resp.status}")
+                     return False
+                 return True
+         except ClientError as e:
+             logger.error(f"Error registering with orchestrator {self.base_url}: {e}")
+             return False
+
+     async def poll_task(self, timeout: float) -> dict[str, Any] | None:
+         """Polls for the next available task."""
+         url = f"{self.base_url}/_worker/workers/{self.worker_id}/tasks/next"
+         client_timeout = ClientTimeout(total=timeout + 5)
+         try:
+             async with self.session.get(url, headers=self._headers, timeout=client_timeout) as resp:
+                 if resp.status == 200:
+                     return await resp.json()
+                 elif resp.status != 204:
+                     logger.warning(f"Unexpected status from {self.base_url} during poll: {resp.status}")
+         except ClientError as e:
+             logger.error(f"Error polling for tasks from {self.base_url}: {e}")
+         except Exception as e:
+             logger.exception(f"Unexpected error polling from {self.base_url}: {e}")
+         return None
+
+     async def send_heartbeat(self, payload: dict[str, Any]) -> bool:
+         """Sends a heartbeat message to update worker state."""
+         url = f"{self.base_url}/_worker/workers/{self.worker_id}"
+         try:
+             async with self.session.patch(url, json=payload, headers=self._headers) as resp:
+                 if resp.status >= 400:
+                     logger.warning(f"Heartbeat to {self.base_url} failed with status: {resp.status}")
+                     return False
+                 return True
+         except ClientError as e:
+             logger.error(f"Error sending heartbeat to orchestrator {self.base_url}: {e}")
+             return False
+
+     async def send_result(self, payload: dict[str, Any], max_retries: int, initial_delay: float) -> bool:
+         """Sends task result with retries and exponential backoff."""
+         url = f"{self.base_url}/_worker/tasks/result"
+         delay = initial_delay
+         for i in range(max_retries):
+             try:
+                 async with self.session.post(url, json=payload, headers=self._headers) as resp:
+                     if resp.status == 200:
+                         return True
+                     logger.error(f"Error sending result to {self.base_url}: {resp.status}")
+             except ClientError as e:
+                 logger.error(f"Error sending result to {self.base_url}: {e}")
+
+             if i < max_retries - 1:
+                 await sleep(delay * (2**i))
+         return False
+
+     async def connect_websocket(self) -> ClientWebSocketResponse | None:
+         """Establishes a WebSocket connection for real-time commands."""
+         ws_url = self.base_url.replace("http", "ws", 1) + "/_worker/ws"
+         try:
+             ws = await self.session.ws_connect(ws_url, headers=self._headers)
+             logger.info(f"WebSocket connection established to {ws_url}")
+             return ws
+         except Exception as e:
+             logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
+             return None
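For orientation, here is a minimal, hedged sketch of how the new OrchestratorClient added above could be driven on its own (the Worker class normally does this internally). The orchestrator URL, token, and payload fields below are illustrative placeholders, not values taken from this package.

# Hedged sketch only: exercises OrchestratorClient from client.py directly.
# URL, token and payload contents are placeholders.
from asyncio import run

from aiohttp import ClientSession

from avtomatika_worker.client import OrchestratorClient


async def main() -> None:
    async with ClientSession() as session:
        client = OrchestratorClient(
            session=session,
            base_url="http://localhost:8080",  # assumed orchestrator address
            worker_id="worker-example",
            token="example-token",
        )
        # register() returns False (and logs) on HTTP or connection errors
        if await client.register({"worker_id": "worker-example", "worker_type": "generic-cpu-worker"}):
            # poll_task() returns the task dict on HTTP 200, or None on 204/errors
            task = await client.poll_task(timeout=30.0)
            if task is not None:
                print("received task", task.get("task_id"))


run(main())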
avtomatika_worker/config.py CHANGED
@@ -10,7 +10,7 @@ class WorkerConfig:
      Reads parameters from environment variables and provides default values.
      """
 
-     def __init__(self):
+     def __init__(self) -> None:
          # --- Basic worker information ---
          self.WORKER_ID: str = getenv("WORKER_ID", f"worker-{uuid4()}")
          self.WORKER_TYPE: str = getenv("WORKER_TYPE", "generic-cpu-worker")
@@ -54,6 +54,7 @@ class WorkerConfig:
          self.S3_ACCESS_KEY: str | None = getenv("S3_ACCESS_KEY")
          self.S3_SECRET_KEY: str | None = getenv("S3_SECRET_KEY")
          self.S3_DEFAULT_BUCKET: str = getenv("S3_DEFAULT_BUCKET", "avtomatika-payloads")
+         self.S3_REGION: str = getenv("S3_REGION", "us-east-1")
 
          # --- Tuning parameters ---
          self.HEARTBEAT_INTERVAL: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
@@ -70,6 +71,18 @@ class WorkerConfig:
          self.ENABLE_WEBSOCKETS: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
          self.MULTI_ORCHESTRATOR_MODE: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
 
+     def validate(self) -> None:
+         """Validates critical configuration parameters."""
+         if self.WORKER_TOKEN == "your-secret-worker-token":
+             print("Warning: WORKER_TOKEN is set to the default value. Tasks might fail authentication.")
+
+         if not self.ORCHESTRATORS:
+             raise ValueError("No orchestrators configured.")
+
+         for o in self.ORCHESTRATORS:
+             if not o.get("url"):
+                 raise ValueError("Orchestrator configuration missing URL.")
+
      def _get_orchestrators_config(self) -> list[dict[str, Any]]:
          """
          Loads orchestrator configuration from the ORCHESTRATORS_CONFIG environment variable.
avtomatika_worker/constants.py ADDED
@@ -0,0 +1,22 @@
+ """
+ Centralized constants for the Avtomatika protocol (Worker SDK).
+ These should match the constants in the core `avtomatika` package.
+ """
+
+ # --- Auth Headers ---
+ AUTH_HEADER_CLIENT = "X-Avtomatika-Token"
+ AUTH_HEADER_WORKER = "X-Worker-Token"
+
+ # --- Error Codes ---
+ ERROR_CODE_TRANSIENT = "TRANSIENT_ERROR"
+ ERROR_CODE_PERMANENT = "PERMANENT_ERROR"
+ ERROR_CODE_INVALID_INPUT = "INVALID_INPUT_ERROR"
+
+ # --- Task Statuses ---
+ TASK_STATUS_SUCCESS = "success"
+ TASK_STATUS_FAILURE = "failure"
+ TASK_STATUS_CANCELLED = "cancelled"
+ TASK_STATUS_NEEDS_REVIEW = "needs_review" # Example of a common custom status
+
+ # --- Commands (WebSocket) ---
+ COMMAND_CANCEL_TASK = "cancel_task"
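As a quick, hedged illustration (not part of the package itself), task-result payloads are assembled from these constants in the same shape that worker.py uses further down in this diff:

# Hedged illustration: a failure result built from the constants above,
# mirroring the {"status": ..., "error": {...}} shape used in worker.py.
from avtomatika_worker.constants import ERROR_CODE_TRANSIENT, TASK_STATUS_FAILURE

result = {
    "status": TASK_STATUS_FAILURE,
    "error": {"code": ERROR_CODE_TRANSIENT, "message": "example transient failure"},
}
print(result)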
File without changes
avtomatika_worker/s3.py CHANGED
@@ -1,36 +1,61 @@
- from asyncio import gather, to_thread
+ from asyncio import Semaphore, gather, to_thread
+ from logging import getLogger
  from os import walk
  from os.path import basename, dirname, join, relpath
  from shutil import rmtree
- from typing import Any
+ from typing import Any, cast
  from urllib.parse import urlparse
 
- from aioboto3 import Session
+ import obstore
+ from aiofiles import open as aio_open
  from aiofiles.os import makedirs
  from aiofiles.ospath import exists, isdir
- from botocore.client import Config
+ from obstore.store import S3Store
 
  from .config import WorkerConfig
 
+ logger = getLogger(__name__)
+
+ # Limit concurrent S3 operations to avoid "Too many open files"
+ MAX_S3_CONCURRENCY = 50
+
 
  class S3Manager:
-     """Handles S3 payload offloading."""
+     """Handles S3 payload offloading using obstore (high-performance async S3 client)."""
 
      def __init__(self, config: WorkerConfig):
          self._config = config
-         self._session = Session()
+         self._stores: dict[str, S3Store] = {}
+         self._semaphore = Semaphore(MAX_S3_CONCURRENCY)
+
+     def _get_store(self, bucket_name: str) -> S3Store:
+         """Creates or returns a cached S3Store for a specific bucket."""
+         if bucket_name in self._stores:
+             return self._stores[bucket_name]
 
-     def _get_client_args(self) -> dict[str, Any]:
-         """Returns standard arguments for S3 client creation."""
-         return {
-             "service_name": "s3",
-             "endpoint_url": self._config.S3_ENDPOINT_URL,
+         config_kwargs = {
              "aws_access_key_id": self._config.S3_ACCESS_KEY,
              "aws_secret_access_key": self._config.S3_SECRET_KEY,
-             "config": Config(signature_version="s3v4"),
+             "region": "us-east-1", # Default region if not specified, required by some clients
          }
 
-     async def cleanup(self, task_id: str):
+         if self._config.S3_ENDPOINT_URL:
+             config_kwargs["endpoint"] = self._config.S3_ENDPOINT_URL
+             if self._config.S3_ENDPOINT_URL.startswith("http://"):
+                 config_kwargs["allow_http"] = "true"
+
+         # Filter out None values
+         config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
+
+         try:
+             store = S3Store(bucket_name, **config_kwargs)
+             self._stores[bucket_name] = store
+             return store
+         except Exception as e:
+             logger.error(f"Failed to create S3Store for bucket {bucket_name}: {e}")
+             raise
+
+     async def cleanup(self, task_id: str) -> None:
          """Removes the task-specific payload directory."""
          task_dir = join(self._config.TASK_FILES_DIR, task_id)
          if await exists(task_dir):
@@ -38,54 +63,83 @@ class S3Manager:
 
      async def _process_s3_uri(self, uri: str, task_id: str) -> str:
          """Downloads a file or a folder (if uri ends with /) from S3 and returns the local path."""
-         parsed_url = urlparse(uri)
-         bucket_name = parsed_url.netloc
-         object_key = parsed_url.path.lstrip("/")
+         try:
+             parsed_url = urlparse(uri)
+             bucket_name = parsed_url.netloc
+             object_key = parsed_url.path.lstrip("/")
+             store = self._get_store(bucket_name)
+
+             # Use task-specific directory for isolation
+             local_dir_root = join(self._config.TASK_FILES_DIR, task_id)
+             await makedirs(local_dir_root, exist_ok=True)
 
-         # Use task-specific directory for isolation
-         local_dir_root = join(self._config.TASK_FILES_DIR, task_id)
-         await makedirs(local_dir_root, exist_ok=True)
+             logger.info(f"Starting download from S3: {uri}")
 
-         async with self._session.client(**self._get_client_args()) as s3:
              # Handle folder download (prefix)
              if uri.endswith("/"):
                  folder_name = object_key.rstrip("/").split("/")[-1]
                  local_folder_path = join(local_dir_root, folder_name)
 
-                 paginator = s3.get_paginator("list_objects_v2")
-                 tasks = []
-                 async for page in paginator.paginate(Bucket=bucket_name, Prefix=object_key):
-                     for obj in page.get("Contents", []):
-                         key = obj["Key"]
-                         if key.endswith("/"):
-                             continue
+                 # List objects with prefix
+                 # obstore.list returns an async iterator of ObjectMeta
+                 files_to_download = []
+
+                 # Note: obstore.list returns an async iterator.
+                 async for obj in obstore.list(store, prefix=object_key):
+                     key = obj.key
+
+                     if key.endswith("/"):
+                         continue
 
-                         # Calculate relative path inside the folder
-                         rel_path = key[len(object_key) :]
-                         local_file_path = join(local_folder_path, rel_path)
+                     # Calculate relative path inside the folder
+                     rel_path = key[len(object_key) :]
+                     local_file_path = join(local_folder_path, rel_path)
 
-                         await makedirs(dirname(local_file_path), exist_ok=True)
-                         tasks.append(s3.download_file(bucket_name, key, local_file_path))
+                     await makedirs(dirname(local_file_path), exist_ok=True)
+                     files_to_download.append((key, local_file_path))
 
-                 if tasks:
-                     await gather(*tasks)
+                 async def _download_file(key: str, path: str) -> None:
+                     async with self._semaphore:
+                         result = await obstore.get(store, key)
+                         async with aio_open(path, "wb") as f:
+                             async for chunk in result.stream():
+                                 await f.write(chunk)
+
+                 # Execute downloads in parallel
+                 if files_to_download:
+                     await gather(*[_download_file(k, p) for k, p in files_to_download])
+
+                 logger.info(f"Successfully downloaded folder from S3: {uri} ({len(files_to_download)} files)")
                  return local_folder_path
 
              # Handle single file download
              local_path = join(local_dir_root, basename(object_key))
-             await s3.download_file(bucket_name, object_key, local_path)
+
+             result = await obstore.get(store, object_key)
+             async with aio_open(local_path, "wb") as f:
+                 async for chunk in result.stream():
+                     await f.write(chunk)
+
+             logger.info(f"Successfully downloaded file from S3: {uri} -> {local_path}")
              return local_path
 
+         except Exception as e:
+             # Catching generic Exception because obstore might raise different errors.
+             logger.exception(f"Error during download of {uri}: {e}")
+             raise
+
      async def _upload_to_s3(self, local_path: str) -> str:
          """Uploads a file or a folder to S3 and returns the S3 URI."""
          bucket_name = self._config.S3_DEFAULT_BUCKET
+         store = self._get_store(bucket_name)
+
+         logger.info(f"Starting upload to S3 from local path: {local_path}")
 
-         async with self._session.client(**self._get_client_args()) as s3:
+         try:
              # Handle folder upload
              if await isdir(local_path):
                  folder_name = basename(local_path.rstrip("/"))
                  s3_prefix = f"{folder_name}/"
-                 tasks = []
 
                  # Use to_thread to avoid blocking event loop during file walk
                  def _get_files_to_upload():
@@ -99,18 +153,33 @@
 
                  files_list = await to_thread(_get_files_to_upload)
 
-                 for full_path, key in files_list:
-                     tasks.append(s3.upload_file(full_path, bucket_name, key))
+                 async def _upload_file(path: str, key: str) -> None:
+                     async with self._semaphore:
+                         # obstore.put accepts bytes or file-like objects.
+                         # Since we are in async, reading small files is fine.
+                         with open(path, "rb") as f:
+                             await obstore.put(store, key, f)
 
-                 if tasks:
-                     await gather(*tasks)
+                 if files_list:
+                     # Upload in parallel
+                     await gather(*[_upload_file(f, k) for f, k in files_list])
 
-                 return f"s3://{bucket_name}/{s3_prefix}"
+                 s3_uri = f"s3://{bucket_name}/{s3_prefix}"
+                 logger.info(f"Successfully uploaded folder to S3: {local_path} -> {s3_uri} ({len(files_list)} files)")
+                 return s3_uri
 
              # Handle single file upload
              object_key = basename(local_path)
-             await s3.upload_file(local_path, bucket_name, object_key)
-             return f"s3://{bucket_name}/{object_key}"
+             with open(local_path, "rb") as f:
+                 await obstore.put(store, object_key, f)
+
+             s3_uri = f"s3://{bucket_name}/{object_key}"
+             logger.info(f"Successfully uploaded file to S3: {local_path} -> {s3_uri}")
+             return s3_uri
+
+         except Exception as e:
+             logger.exception(f"Error during upload of {local_path}: {e}")
+             raise
 
      async def process_params(self, params: dict[str, Any], task_id: str) -> dict[str, Any]:
          """Recursively searches for S3 URIs in params and downloads the files."""
@@ -124,7 +193,7 @@
                  return {k: await _process(v) for k, v in item.items()}
              return [await _process(i) for i in item] if isinstance(item, list) else item
 
-         return await _process(params)
+         return cast(dict[str, Any], await _process(params))
 
      async def process_result(self, result: dict[str, Any]) -> dict[str, Any]:
          """Recursively searches for local file paths in the result and uploads them to S3."""
@@ -138,4 +207,4 @@
                  return {k: await _process(v) for k, v in item.items()}
              return [await _process(i) for i in item] if isinstance(item, list) else item
 
-         return await _process(result)
+         return cast(dict[str, Any], await _process(result))
avtomatika_worker/types.py CHANGED
@@ -1,8 +1,21 @@
- # Error codes for worker task results
- TRANSIENT_ERROR = "TRANSIENT_ERROR"
- PERMANENT_ERROR = "PERMANENT_ERROR"
- INVALID_INPUT_ERROR = "INVALID_INPUT_ERROR"
+ from .constants import (
+     ERROR_CODE_INVALID_INPUT as INVALID_INPUT_ERROR,
+ )
+ from .constants import (
+     ERROR_CODE_PERMANENT as PERMANENT_ERROR,
+ )
+ from .constants import (
+     ERROR_CODE_TRANSIENT as TRANSIENT_ERROR,
+ )
 
 
  class ParamValidationError(Exception):
      """Custom exception for parameter validation errors."""
+
+
+ __all__ = [
+     "INVALID_INPUT_ERROR",
+     "PERMANENT_ERROR",
+     "TRANSIENT_ERROR",
+     "ParamValidationError",
+ ]
avtomatika_worker/worker.py CHANGED
@@ -1,5 +1,4 @@
  from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
- from asyncio import TimeoutError as AsyncTimeoutError
  from dataclasses import is_dataclass
  from inspect import Parameter, signature
  from json import JSONDecodeError
@@ -7,12 +6,21 @@ from logging import getLogger
  from os.path import join
  from typing import Any, Callable
 
- from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
+ from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType, web
 
+ from .client import OrchestratorClient
  from .config import WorkerConfig
+ from .constants import (
+     COMMAND_CANCEL_TASK,
+     ERROR_CODE_INVALID_INPUT,
+     ERROR_CODE_PERMANENT,
+     ERROR_CODE_TRANSIENT,
+     TASK_STATUS_CANCELLED,
+     TASK_STATUS_FAILURE,
+ )
  from .s3 import S3Manager
  from .task_files import TaskFiles
- from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
+ from .types import ParamValidationError
 
  try:
      from pydantic import BaseModel, ValidationError
@@ -45,7 +53,7 @@ class Worker:
          self._s3_manager = S3Manager(self._config)
          self._config.WORKER_TYPE = worker_type # Allow overriding worker_type
          if max_concurrent_tasks is not None:
-             self._config.max_concurrent_tasks = max_concurrent_tasks
+             self._config.MAX_CONCURRENT_TASKS = max_concurrent_tasks
 
          self._task_type_limits = task_type_limits or {}
          self._task_handlers: dict[str, dict[str, Any]] = {}
@@ -59,10 +67,8 @@
          self._http_session = http_session
          self._session_is_managed_externally = http_session is not None
          self._ws_connection: ClientWebSocketResponse | None = None
-         # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
          self._shutdown_event = Event()
          self._registered_event = Event()
-         self._round_robin_index = 0
          self._debounce_task: Task | None = None
 
          # --- Weighted Round-Robin State ---
@@ -72,7 +78,28 @@
              o["current_weight"] = 0
              self._total_orchestrator_weight += o.get("weight", 1)
 
-     def _validate_config(self):
+         self._clients: list[tuple[dict[str, Any], OrchestratorClient]] = []
+         if self._http_session:
+             self._init_clients()
+
+     def _init_clients(self):
+         """Initializes OrchestratorClient instances for each configured orchestrator."""
+         if not self._http_session:
+             return
+         self._clients = [
+             (
+                 o,
+                 OrchestratorClient(
+                     session=self._http_session,
+                     base_url=o["url"],
+                     worker_id=self._config.WORKER_ID,
+                     token=o.get("token", self._config.WORKER_TOKEN),
+                 ),
+             )
+             for o in self._config.ORCHESTRATORS
+         ]
+
+     def _validate_task_types(self):
          """Checks for unused task type limits and warns the user."""
          registered_task_types = {
              handler_data["type"] for handler_data in self._task_handlers.values() if handler_data["type"]
@@ -140,32 +167,31 @@
          status = "idle" if supported_tasks else "busy"
          return {"status": status, "supported_tasks": supported_tasks}
 
-     def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
-         """Builds authentication headers for a specific orchestrator."""
-         token = orchestrator.get("token", self._config.WORKER_TOKEN)
-         return {"X-Worker-Token": token}
-
-     def _get_next_orchestrator(self) -> dict[str, Any] | None:
+     def _get_next_client(self) -> OrchestratorClient | None:
          """
-         Selects the next orchestrator using a smooth weighted round-robin algorithm.
+         Selects the next orchestrator client using a smooth weighted round-robin algorithm.
          """
-         if not self._config.ORCHESTRATORS:
+         if not self._clients:
              return None
 
          # The orchestrator with the highest current_weight is selected.
-         selected_orchestrator = None
+         selected_client = None
          highest_weight = -1
 
-         for o in self._config.ORCHESTRATORS:
+         for o, client in self._clients:
              o["current_weight"] += o["weight"]
              if o["current_weight"] > highest_weight:
                  highest_weight = o["current_weight"]
-                 selected_orchestrator = o
+                 selected_client = client
 
-         if selected_orchestrator:
-             selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
+         if selected_client:
+             # Find the config for the selected client to decrement its weight
+             for o, client in self._clients:
+                 if client == selected_client:
+                     o["current_weight"] -= self._total_orchestrator_weight
+                     break
 
-         return selected_orchestrator
+         return selected_client
 
      async def _debounced_heartbeat_sender(self):
          """Waits for the debounce delay then sends a heartbeat."""
@@ -180,33 +206,27 @@
          # Schedule the new debounced call.
          self._debounce_task = create_task(self._debounced_heartbeat_sender())
 
-     async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
+     async def _poll_for_tasks(self, client: OrchestratorClient):
          """Polls a specific Orchestrator for new tasks."""
-         url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
-         try:
-             if not self._http_session:
-                 return
-             timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
-             headers = self._get_headers(orchestrator)
-             async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
-                 if resp.status == 200:
-                     task_data = await resp.json()
-                     task_data["orchestrator"] = orchestrator
-
-                     self._current_load += 1
-                     if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
-                         task_type_for_limit := task_handler_info.get("type")
-                     ):
-                         self._current_load_by_type[task_type_for_limit] += 1
-                     self._schedule_heartbeat_debounce()
-
-                     task = create_task(self._process_task(task_data))
-                     self._active_tasks[task_data["task_id"]] = task
-                 elif resp.status != 204:
-                     await sleep(self._config.TASK_POLL_ERROR_DELAY)
-         except (AsyncTimeoutError, ClientError) as e:
-             logger.error(f"Error polling for tasks: {e}")
-             await sleep(self._config.TASK_POLL_ERROR_DELAY)
+         task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+         if task_data:
+             task_data["client"] = client
+
+             self._current_load += 1
+             if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
+                 task_type_for_limit := task_handler_info.get("type")
+             ):
+                 self._current_load_by_type[task_type_for_limit] += 1
+             self._schedule_heartbeat_debounce()
+
+             task = create_task(self._process_task(task_data))
+             self._active_tasks[task_data["task_id"]] = task
+         else:
+             # If no task but it was a 204 or error, the client already handled/logged it.
+             # We might want a short sleep here if it was an error, but client.poll_task
+             # doesn't distinguish between 204 and error currently.
+             # However, the previous logic only slept on status != 204.
+             pass
 
      async def _start_polling(self):
          """The main loop for polling tasks."""
@@ -218,13 +238,13 @@
                  continue
 
              if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
-                 if orchestrator := self._get_next_orchestrator():
-                     await self._poll_for_tasks(orchestrator)
+                 if client := self._get_next_client():
+                     await self._poll_for_tasks(client)
              else:
-                 for orchestrator in self._config.ORCHESTRATORS:
+                 for _, client in self._clients:
                      if self._get_current_state()["status"] == "busy":
                          break
-                     await self._poll_for_tasks(orchestrator)
+                     await self._poll_for_tasks(client)
 
              if self._current_load == 0:
                  await sleep(self._config.IDLE_POLL_DELAY)
@@ -289,7 +309,7 @@
      async def _process_task(self, task_data: dict[str, Any]):
          """Executes the task logic."""
          task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-         params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
+         params, client = task_data.get("params", {}), task_data["client"]
 
          result: dict[str, Any] = {}
          handler_data = self._task_handlers.get(task_name)
@@ -301,9 +321,11 @@
          if not handler_data:
              message = f"Unsupported task: {task_name}"
              logger.warning(message)
-             result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
+             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_PERMANENT, "message": message}}
              payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-             await self._send_result(payload, orchestrator)
+             await client.send_result(
+                 payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+             )
              result_sent = True # Mark result as sent
              return
 
@@ -324,22 +346,24 @@
              result = await self._s3_manager.process_result(result)
          except ParamValidationError as e:
              logger.error(f"Task {task_id} failed validation: {e}")
-             result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
+             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_INVALID_INPUT, "message": str(e)}}
          except CancelledError:
              logger.info(f"Task {task_id} was cancelled.")
-             result = {"status": "cancelled"}
+             result = {"status": TASK_STATUS_CANCELLED}
              # We must re-raise the exception to be handled by the outer gather
              raise
          except Exception as e:
              logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-             result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
+             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_TRANSIENT, "message": str(e)}}
          finally:
              # Cleanup task workspace
              await self._s3_manager.cleanup(task_id)
 
              if not result_sent: # Only send if not already sent
                  payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await self._send_result(payload, orchestrator)
+                 await client.send_result(
+                     payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+                 )
              self._active_tasks.pop(task_id, None)
 
              self._current_load -= 1
@@ -347,21 +371,6 @@
                  self._current_load_by_type[task_type_for_limit] -= 1
              self._schedule_heartbeat_debounce()
 
-     async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
-         """Sends the result to a specific orchestrator."""
-         url = f"{orchestrator['url']}/_worker/tasks/result"
-         delay = self._config.RESULT_RETRY_INITIAL_DELAY
-         headers = self._get_headers(orchestrator)
-         for i in range(self._config.RESULT_MAX_RETRIES):
-             try:
-                 if self._http_session and not self._http_session.closed:
-                     async with self._http_session.post(url, json=payload, headers=headers) as resp:
-                         if resp.status == 200:
-                             return
-             except ClientError as e:
-                 logger.error(f"Error sending result: {e}")
-             await sleep(delay * (2**i))
-
      async def _manage_orchestrator_communications(self):
          """Registers the worker and sends heartbeats."""
          await self._register_with_all_orchestrators()
@@ -388,17 +397,7 @@
              "ip_address": self._config.IP_ADDRESS,
              "resources": self._config.RESOURCES,
          }
-         for orchestrator in self._config.ORCHESTRATORS:
-             url = f"{orchestrator['url']}/_worker/workers/register"
-             try:
-                 if self._http_session:
-                     async with self._http_session.post(
-                         url, json=payload, headers=self._get_headers(orchestrator)
-                     ) as resp:
-                         if resp.status >= 400:
-                             logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
-             except ClientError as e:
-                 logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
+         await gather(*[client.register(payload) for _, client in self._clients])
 
      async def _send_heartbeats_to_all(self):
          """Sends heartbeat messages to all orchestrators."""
@@ -418,24 +417,15 @@
          if hot_skills:
              payload["hot_skills"] = hot_skills
 
-         async def _send_single(orchestrator: dict[str, Any]):
-             url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
-             headers = self._get_headers(orchestrator)
-             try:
-                 if self._http_session and not self._http_session.closed:
-                     async with self._http_session.patch(url, json=payload, headers=headers) as resp:
-                         if resp.status >= 400:
-                             logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
-             except ClientError as e:
-                 logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
-
-         await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
+         await gather(*[client.send_heartbeat(payload) for _, client in self._clients])
 
      async def main(self):
          """The main asynchronous function."""
-         self._validate_config() # Validate config now that all tasks are registered
+         self._config.validate()
+         self._validate_task_types() # Validate config now that all tasks are registered
          if not self._http_session:
              self._http_session = ClientSession()
+             self._init_clients()
 
          comm_task = create_task(self._manage_orchestrator_communications())
 
@@ -482,25 +472,20 @@
          except KeyboardInterrupt:
              self._shutdown_event.set()
 
-     # WebSocket methods omitted for brevity as they are not relevant to the changes
      async def _start_websocket_manager(self):
          """Manages the WebSocket connection to the orchestrator."""
          while not self._shutdown_event.is_set():
-             for orchestrator in self._config.ORCHESTRATORS:
-                 ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
+             # In multi-orchestrator mode, we currently only connect to the first one available
+             for _, client in self._clients:
                  try:
-                     if self._http_session:
-                         async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
-                             self._ws_connection = ws
-                             logger.info(f"WebSocket connection established to {ws_url}")
-                             await self._listen_for_commands()
-                 except (ClientError, AsyncTimeoutError) as e:
-                     logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
+                     ws = await client.connect_websocket()
+                     if ws:
+                         self._ws_connection = ws
+                         await self._listen_for_commands()
                  finally:
                      self._ws_connection = None
-                     logger.info(f"WebSocket connection to {ws_url} closed.")
                      await sleep(5) # Reconnection delay
-             if not self._config.ORCHESTRATORS:
+             if not self._clients:
                  await sleep(5)
 
      async def _listen_for_commands(self):
@@ -513,7 +498,7 @@
              if msg.type == WSMsgType.TEXT:
                  try:
                      command = msg.json()
-                     if command.get("type") == "cancel_task":
+                     if command.get("type") == COMMAND_CANCEL_TASK:
                          task_id = command.get("task_id")
                          if task_id in self._active_tasks:
                              self._active_tasks[task_id].cancel()
avtomatika_worker-1.0b3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: avtomatika-worker
- Version: 1.0b2
+ Version: 1.0b3
  Summary: Worker SDK for the Avtomatika orchestrator.
  Project-URL: Homepage, https://github.com/avtomatika-ai/avtomatika-worker
  Project-URL: Bug Tracker, https://github.com/avtomatika-ai/avtomatika-worker/issues
@@ -13,7 +13,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: aiohttp~=3.13.2
  Requires-Dist: python-json-logger~=4.0.0
- Requires-Dist: aioboto3~=15.5.0
+ Requires-Dist: obstore>=0.1
  Requires-Dist: aiofiles~=25.1.0
  Provides-Extra: test
  Requires-Dist: pytest; extra == "test"
@@ -21,13 +21,14 @@ Requires-Dist: pytest-asyncio; extra == "test"
  Requires-Dist: aioresponses; extra == "test"
  Requires-Dist: pytest-mock; extra == "test"
  Requires-Dist: pydantic; extra == "test"
+ Requires-Dist: types-aiofiles; extra == "test"
  Provides-Extra: pydantic
  Requires-Dist: pydantic; extra == "pydantic"
  Dynamic: license-file
 
  # Avtomatika Worker SDK
 
- This is an SDK for creating workers compatible with the **Avtomatika** orchestrator. The SDK handles all the complexity of interacting with the orchestrator, allowing you to focus on writing your business logic.
+ This is the official SDK for creating workers compatible with the **[Avtomatika Orchestrator](https://github.com/avtomatika-ai/avtomatika)**. It implements the **[RCA Protocol](https://github.com/avtomatika-ai/rca)**, handling all communication complexity (polling, heartbeats, S3 offloading) so you can focus on writing your business logic.
 
  ## Installation
 
@@ -472,7 +473,7 @@ async def generate_report(params: dict, files: TaskFiles, **kwargs):
 
  ### 6. Handling Large Files (S3 Payload Offloading)
 
- The SDK supports working with large files "out of the box" via S3-compatible storage.
+ The SDK supports working with large files "out of the box" via S3-compatible storage, using the high-performance **`obstore`** library (Rust-based).
 
  - **Automatic Download**: If a value in `params` is a URI of the form `s3://...`, the SDK will automatically download the file to the local disk and replace the URI in `params` with the local path. **If the URI ends with `/` (e.g., `s3://bucket/data/`), the SDK treats it as a folder prefix and recursively downloads all matching objects into a local directory.**
  - **Automatic Upload**: If your handler returns a local file path in `data` (located within the `TASK_FILES_DIR` directory), the SDK will automatically upload this file to S3 and replace the path with an `s3://` URI in the final result. **If the path is a directory, the SDK recursively uploads all files within it.**
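To make the offloading behaviour concrete, here is a hedged sketch that calls the package's S3Manager directly (the Worker normally does this automatically around each handler); the bucket, prefix, and task id below are placeholders.

# Hedged sketch: exercising S3Manager.process_params / process_result directly.
# Bucket, prefix and task id are placeholders; normally the Worker calls these.
from asyncio import run

from avtomatika_worker.config import WorkerConfig
from avtomatika_worker.s3 import S3Manager


async def demo() -> None:
    s3 = S3Manager(WorkerConfig())
    # s3:// URIs found in params are downloaded and replaced with local paths
    params = await s3.process_params({"dataset": "s3://example-bucket/data/"}, task_id="task-123")
    print(params["dataset"])  # now a local directory under TASK_FILES_DIR/task-123
    # local paths under TASK_FILES_DIR found in the result are uploaded back to S3
    result = await s3.process_result({"report": params["dataset"]})
    print(result["report"])  # now an s3:// URI in S3_DEFAULT_BUCKET


run(demo())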
@@ -600,6 +601,7 @@ The worker is fully configured via environment variables.
  | `S3_ACCESS_KEY` | The access key for S3. | - |
  | `S3_SECRET_KEY` | The secret key for S3. | - |
  | `S3_DEFAULT_BUCKET` | The default bucket name for uploading results. | `avtomatika-payloads` |
+ | `S3_REGION` | The region for S3 storage (required by some providers). | `us-east-1` |
 
  ## Development
 
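For reference, a hedged example of supplying the S3 settings listed in the table above as environment variables before starting a worker, including the S3_REGION option added in 1.0b3; all values are placeholders.

# Hedged example: S3-related configuration via environment variables.
# All values below are placeholders.
from os import environ

environ.update({
    "S3_ACCESS_KEY": "example-access-key",
    "S3_SECRET_KEY": "example-secret-key",
    "S3_DEFAULT_BUCKET": "avtomatika-payloads",
    "S3_REGION": "us-east-1",
})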
avtomatika_worker-1.0b3.dist-info/RECORD ADDED
@@ -0,0 +1,14 @@
+ avtomatika_worker/__init__.py,sha256=y_s5KlsgFu7guemZfjLVQ3Jzq7DyLG168-maVGwWRC4,334
+ avtomatika_worker/client.py,sha256=mkvwrMY8tAaZN_lwMSxWHmAoWsDemD-WiKSeH5fM6GI,4173
+ avtomatika_worker/config.py,sha256=NaAhufpwyG6CsHW-cXmqR3MfGp_5SdDZ_vEhmmV8G3g,5819
+ avtomatika_worker/constants.py,sha256=DfGR_YkW9rbioCorKpNGfZ0i_0iGgMq2swyJhVl9nNA,669
+ avtomatika_worker/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ avtomatika_worker/s3.py,sha256=sAuXBp__XTVhyNwK9gsxy1Jm_udJM6ypssGTZ00pa6U,8847
+ avtomatika_worker/task_files.py,sha256=ucjBuI78UmtMvfucTzDTNJ1g0KJaRIwyshRNTipIZSU,3351
+ avtomatika_worker/types.py,sha256=dSNsHgqV6hZhOt4eUK2PDWB6lrrwCA5_T_iIBI_wTZ0,442
+ avtomatika_worker/worker.py,sha256=XSRfLO-W0J6WG128Iu-rL_w3-PqmsWQMUElVLi3Z1gk,21904
+ avtomatika_worker-1.0b3.dist-info/licenses/LICENSE,sha256=tqCjw9Y1vbU-hLcWi__7wQstLbt2T1XWPdbQYqCxuWY,1072
+ avtomatika_worker-1.0b3.dist-info/METADATA,sha256=yiEtJuMv5WHYHfScna7cF5QAvAUhMCJdUDENHvrMRFY,29601
+ avtomatika_worker-1.0b3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ avtomatika_worker-1.0b3.dist-info/top_level.txt,sha256=d3b5BUeUrHM1Cn-cbStz-hpucikEBlPOvtcmQ_j3qAs,18
+ avtomatika_worker-1.0b3.dist-info/RECORD,,
@@ -1,11 +0,0 @@
- avtomatika_worker/__init__.py,sha256=y_s5KlsgFu7guemZfjLVQ3Jzq7DyLG168-maVGwWRC4,334
- avtomatika_worker/config.py,sha256=v-2XGIcCIMr9S2SPAVKOMTpU8QSLeUm-udNOKWSxjQQ,5247
- avtomatika_worker/s3.py,sha256=ySwEOrP2ZslJ-Mg4_9vyxsnRzX0LIe78FmP2nlq8n9s,5930
- avtomatika_worker/task_files.py,sha256=ucjBuI78UmtMvfucTzDTNJ1g0KJaRIwyshRNTipIZSU,3351
- avtomatika_worker/types.py,sha256=MqXaX0NUatYDna3GgBWj73-WOT1EfaX1ei4i7eUsZR0,255
- avtomatika_worker/worker.py,sha256=sghM9Y8wwB4uJv_MPlQ3noBejDiwh4MYTZySNwjcZ3w,23968
- avtomatika_worker-1.0b2.dist-info/licenses/LICENSE,sha256=tqCjw9Y1vbU-hLcWi__7wQstLbt2T1XWPdbQYqCxuWY,1072
- avtomatika_worker-1.0b2.dist-info/METADATA,sha256=QdZKdrT-HP-u94gP7NVUt_TWGWJX6LIek2llHM1xj8Q,29184
- avtomatika_worker-1.0b2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- avtomatika_worker-1.0b2.dist-info/top_level.txt,sha256=d3b5BUeUrHM1Cn-cbStz-hpucikEBlPOvtcmQ_j3qAs,18
- avtomatika_worker-1.0b2.dist-info/RECORD,,