avtomatika-worker 1.0b1-py3-none-any.whl → 1.0b3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
avtomatika_worker/__init__.py CHANGED
@@ -2,9 +2,10 @@
 
  from importlib.metadata import PackageNotFoundError, version
 
+ from .task_files import TaskFiles
  from .worker import Worker
 
- __all__ = ["Worker"]
+ __all__ = ["Worker", "TaskFiles"]
 
  try:
      __version__ = version("avtomatika-worker")
avtomatika_worker/client.py ADDED
@@ -0,0 +1,93 @@
+ from asyncio import sleep
+ from logging import getLogger
+ from typing import Any
+
+ from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse
+
+ from .constants import AUTH_HEADER_WORKER
+
+ logger = getLogger(__name__)
+
+
+ class OrchestratorClient:
+     """
+     Dedicated client for communicating with a single Avtomatika Orchestrator instance.
+     Handles HTTP requests, retries, and authentication.
+     """
+
+     def __init__(self, session: ClientSession, base_url: str, worker_id: str, token: str):
+         self.session = session
+         self.base_url = base_url.rstrip("/")
+         self.worker_id = worker_id
+         self.token = token
+         self._headers = {AUTH_HEADER_WORKER: self.token}
+
+     async def register(self, payload: dict[str, Any]) -> bool:
+         """Registers the worker with the orchestrator."""
+         url = f"{self.base_url}/_worker/workers/register"
+         try:
+             async with self.session.post(url, json=payload, headers=self._headers) as resp:
+                 if resp.status >= 400:
+                     logger.error(f"Error registering with {self.base_url}: {resp.status}")
+                     return False
+                 return True
+         except ClientError as e:
+             logger.error(f"Error registering with orchestrator {self.base_url}: {e}")
+             return False
+
+     async def poll_task(self, timeout: float) -> dict[str, Any] | None:
+         """Polls for the next available task."""
+         url = f"{self.base_url}/_worker/workers/{self.worker_id}/tasks/next"
+         client_timeout = ClientTimeout(total=timeout + 5)
+         try:
+             async with self.session.get(url, headers=self._headers, timeout=client_timeout) as resp:
+                 if resp.status == 200:
+                     return await resp.json()
+                 elif resp.status != 204:
+                     logger.warning(f"Unexpected status from {self.base_url} during poll: {resp.status}")
+         except ClientError as e:
+             logger.error(f"Error polling for tasks from {self.base_url}: {e}")
+         except Exception as e:
+             logger.exception(f"Unexpected error polling from {self.base_url}: {e}")
+         return None
+
+     async def send_heartbeat(self, payload: dict[str, Any]) -> bool:
+         """Sends a heartbeat message to update worker state."""
+         url = f"{self.base_url}/_worker/workers/{self.worker_id}"
+         try:
+             async with self.session.patch(url, json=payload, headers=self._headers) as resp:
+                 if resp.status >= 400:
+                     logger.warning(f"Heartbeat to {self.base_url} failed with status: {resp.status}")
+                     return False
+                 return True
+         except ClientError as e:
+             logger.error(f"Error sending heartbeat to orchestrator {self.base_url}: {e}")
+             return False
+
+     async def send_result(self, payload: dict[str, Any], max_retries: int, initial_delay: float) -> bool:
+         """Sends task result with retries and exponential backoff."""
+         url = f"{self.base_url}/_worker/tasks/result"
+         delay = initial_delay
+         for i in range(max_retries):
+             try:
+                 async with self.session.post(url, json=payload, headers=self._headers) as resp:
+                     if resp.status == 200:
+                         return True
+                     logger.error(f"Error sending result to {self.base_url}: {resp.status}")
+             except ClientError as e:
+                 logger.error(f"Error sending result to {self.base_url}: {e}")
+
+             if i < max_retries - 1:
+                 await sleep(delay * (2**i))
+         return False
+
+     async def connect_websocket(self) -> ClientWebSocketResponse | None:
+         """Establishes a WebSocket connection for real-time commands."""
+         ws_url = self.base_url.replace("http", "ws", 1) + "/_worker/ws"
+         try:
+             ws = await self.session.ws_connect(ws_url, headers=self._headers)
+             logger.info(f"WebSocket connection established to {ws_url}")
+             return ws
+         except Exception as e:
+             logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
+             return None
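For orientation, the following is a minimal, hypothetical sketch of driving the new OrchestratorClient on its own, outside the Worker polling loop. It uses only the methods shown in the hunk above; the base URL, worker id, token, registration fields, and retry values are illustrative placeholders rather than values taken from the package.

```python
# Hypothetical standalone use of OrchestratorClient (all concrete values are placeholders).
from asyncio import run

from aiohttp import ClientSession

from avtomatika_worker.client import OrchestratorClient


async def demo() -> None:
    async with ClientSession() as session:
        client = OrchestratorClient(
            session=session,
            base_url="http://localhost:8080",  # assumed orchestrator address
            worker_id="worker-demo",
            token="your-secret-worker-token",
        )
        # Register, poll once, and report a result, mirroring the Worker's own flow.
        if await client.register({"worker_id": "worker-demo", "worker_type": "generic-cpu-worker"}):
            task = await client.poll_task(timeout=30)
            if task:
                payload = {
                    "job_id": task["job_id"],
                    "task_id": task["task_id"],
                    "worker_id": "worker-demo",
                    "result": {"status": "success", "data": {}},
                }
                await client.send_result(payload, max_retries=5, initial_delay=1.0)


run(demo())
```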
avtomatika_worker/config.py CHANGED
@@ -10,7 +10,7 @@ class WorkerConfig:
      Reads parameters from environment variables and provides default values.
      """
 
-     def __init__(self):
+     def __init__(self) -> None:
          # --- Basic worker information ---
          self.WORKER_ID: str = getenv("WORKER_ID", f"worker-{uuid4()}")
          self.WORKER_TYPE: str = getenv("WORKER_TYPE", "generic-cpu-worker")
@@ -49,11 +49,12 @@ class WorkerConfig:
          )
 
          # --- S3 Settings for payload offloading ---
-         self.WORKER_PAYLOAD_DIR: str = getenv("WORKER_PAYLOAD_DIR", "/tmp/payloads")
+         self.TASK_FILES_DIR: str = getenv("TASK_FILES_DIR", "/tmp/payloads")
          self.S3_ENDPOINT_URL: str | None = getenv("S3_ENDPOINT_URL")
          self.S3_ACCESS_KEY: str | None = getenv("S3_ACCESS_KEY")
          self.S3_SECRET_KEY: str | None = getenv("S3_SECRET_KEY")
          self.S3_DEFAULT_BUCKET: str = getenv("S3_DEFAULT_BUCKET", "avtomatika-payloads")
+         self.S3_REGION: str = getenv("S3_REGION", "us-east-1")
 
          # --- Tuning parameters ---
          self.HEARTBEAT_INTERVAL: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
@@ -70,13 +71,24 @@ class WorkerConfig:
          self.ENABLE_WEBSOCKETS: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
          self.MULTI_ORCHESTRATOR_MODE: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
 
+     def validate(self) -> None:
+         """Validates critical configuration parameters."""
+         if self.WORKER_TOKEN == "your-secret-worker-token":
+             print("Warning: WORKER_TOKEN is set to the default value. Tasks might fail authentication.")
+
+         if not self.ORCHESTRATORS:
+             raise ValueError("No orchestrators configured.")
+
+         for o in self.ORCHESTRATORS:
+             if not o.get("url"):
+                 raise ValueError("Orchestrator configuration missing URL.")
+
      def _get_orchestrators_config(self) -> list[dict[str, Any]]:
          """
          Loads orchestrator configuration from the ORCHESTRATORS_CONFIG environment variable.
          For backward compatibility, if it is not set, it uses ORCHESTRATOR_URL.
          """
-         orchestrators_json = getenv("ORCHESTRATORS_CONFIG")
-         if orchestrators_json:
+         if orchestrators_json := getenv("ORCHESTRATORS_CONFIG"):
              try:
                  orchestrators = loads(orchestrators_json)
                  if getenv("ORCHESTRATOR_URL"):
@@ -94,23 +106,23 @@ class WorkerConfig:
          orchestrator_url = getenv("ORCHESTRATOR_URL", "http://localhost:8080")
          return [{"url": orchestrator_url, "priority": 1, "weight": 1}]
 
-     def _get_gpu_info(self) -> dict[str, Any] | None:
+     @staticmethod
+     def _get_gpu_info() -> dict[str, Any] | None:
          """Collects GPU information from environment variables.
          Returns None if GPU is not configured.
          """
-         gpu_model = getenv("GPU_MODEL")
-         if not gpu_model:
+         if gpu_model := getenv("GPU_MODEL"):
+             return {
+                 "model": gpu_model,
+                 "vram_gb": int(getenv("GPU_VRAM_GB", "0")),
+             }
+         else:
              return None
 
-         return {
-             "model": gpu_model,
-             "vram_gb": int(getenv("GPU_VRAM_GB", "0")),
-         }
-
-     def _load_json_from_env(self, key: str, default: Any) -> Any:
+     @staticmethod
+     def _load_json_from_env(key: str, default: Any) -> Any:
          """Safely loads a JSON string from an environment variable."""
-         value = getenv(key)
-         if value:
+         if value := getenv(key):
              try:
                  return loads(value)
              except JSONDecodeError:
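As a reading aid for the loader above, here is a hypothetical shape for the `ORCHESTRATORS_CONFIG` environment variable, inferred from the keys the code reads (`url`, `priority`, `weight`, and the optional per-orchestrator `token`); the hosts and tokens are placeholders.

```python
# Hypothetical ORCHESTRATORS_CONFIG value matching the keys read by WorkerConfig and Worker.
from json import dumps
from os import environ

environ["ORCHESTRATORS_CONFIG"] = dumps(
    [
        {"url": "http://orchestrator-a:8080", "priority": 1, "weight": 3, "token": "token-a"},
        {"url": "http://orchestrator-b:8080", "priority": 2, "weight": 1},  # falls back to WORKER_TOKEN
    ]
)
```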
avtomatika_worker/constants.py ADDED
@@ -0,0 +1,22 @@
+ """
+ Centralized constants for the Avtomatika protocol (Worker SDK).
+ These should match the constants in the core `avtomatika` package.
+ """
+
+ # --- Auth Headers ---
+ AUTH_HEADER_CLIENT = "X-Avtomatika-Token"
+ AUTH_HEADER_WORKER = "X-Worker-Token"
+
+ # --- Error Codes ---
+ ERROR_CODE_TRANSIENT = "TRANSIENT_ERROR"
+ ERROR_CODE_PERMANENT = "PERMANENT_ERROR"
+ ERROR_CODE_INVALID_INPUT = "INVALID_INPUT_ERROR"
+
+ # --- Task Statuses ---
+ TASK_STATUS_SUCCESS = "success"
+ TASK_STATUS_FAILURE = "failure"
+ TASK_STATUS_CANCELLED = "cancelled"
+ TASK_STATUS_NEEDS_REVIEW = "needs_review"  # Example of a common custom status
+
+ # --- Commands (WebSocket) ---
+ COMMAND_CANCEL_TASK = "cancel_task"
File without changes
avtomatika_worker/s3.py CHANGED
@@ -1,62 +1,199 @@
- import asyncio
- import os
- from typing import Any
+ from asyncio import Semaphore, gather, to_thread
+ from logging import getLogger
+ from os import walk
+ from os.path import basename, dirname, join, relpath
+ from shutil import rmtree
+ from typing import Any, cast
  from urllib.parse import urlparse
 
- import boto3
- from botocore.client import Config
+ import obstore
+ from aiofiles import open as aio_open
+ from aiofiles.os import makedirs
+ from aiofiles.ospath import exists, isdir
+ from obstore.store import S3Store
 
  from .config import WorkerConfig
 
+ logger = getLogger(__name__)
+
+ # Limit concurrent S3 operations to avoid "Too many open files"
+ MAX_S3_CONCURRENCY = 50
+
 
  class S3Manager:
-     """Handles S3 payload offloading."""
+     """Handles S3 payload offloading using obstore (high-performance async S3 client)."""
 
      def __init__(self, config: WorkerConfig):
          self._config = config
-         self._s3 = boto3.client(
-             "s3",
-             endpoint_url=self._config.S3_ENDPOINT_URL,
-             aws_access_key_id=self._config.S3_ACCESS_KEY,
-             aws_secret_access_key=self._config.S3_SECRET_KEY,
-             config=Config(signature_version="s3v4"),
-         )
-
-     async def _process_s3_uri(self, uri: str) -> str:
-         """Downloads a file from S3 and returns the local path."""
-         parsed_url = urlparse(uri)
-         bucket_name = parsed_url.netloc
-         object_key = parsed_url.path.lstrip("/")
-         local_dir = self._config.WORKER_PAYLOAD_DIR
-         os.makedirs(local_dir, exist_ok=True)
-         local_path = os.path.join(local_dir, os.path.basename(object_key))
-
-         await asyncio.to_thread(self._s3.download_file, bucket_name, object_key, local_path)
-         return local_path
+         self._stores: dict[str, S3Store] = {}
+         self._semaphore = Semaphore(MAX_S3_CONCURRENCY)
+
+     def _get_store(self, bucket_name: str) -> S3Store:
+         """Creates or returns a cached S3Store for a specific bucket."""
+         if bucket_name in self._stores:
+             return self._stores[bucket_name]
+
+         config_kwargs = {
+             "aws_access_key_id": self._config.S3_ACCESS_KEY,
+             "aws_secret_access_key": self._config.S3_SECRET_KEY,
+             "region": "us-east-1",  # Default region if not specified, required by some clients
+         }
+
+         if self._config.S3_ENDPOINT_URL:
+             config_kwargs["endpoint"] = self._config.S3_ENDPOINT_URL
+             if self._config.S3_ENDPOINT_URL.startswith("http://"):
+                 config_kwargs["allow_http"] = "true"
+
+         # Filter out None values
+         config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
+
+         try:
+             store = S3Store(bucket_name, **config_kwargs)
+             self._stores[bucket_name] = store
+             return store
+         except Exception as e:
+             logger.error(f"Failed to create S3Store for bucket {bucket_name}: {e}")
+             raise
+
+     async def cleanup(self, task_id: str) -> None:
+         """Removes the task-specific payload directory."""
+         task_dir = join(self._config.TASK_FILES_DIR, task_id)
+         if await exists(task_dir):
+             await to_thread(lambda: rmtree(task_dir, ignore_errors=True))
+
+     async def _process_s3_uri(self, uri: str, task_id: str) -> str:
+         """Downloads a file or a folder (if uri ends with /) from S3 and returns the local path."""
+         try:
+             parsed_url = urlparse(uri)
+             bucket_name = parsed_url.netloc
+             object_key = parsed_url.path.lstrip("/")
+             store = self._get_store(bucket_name)
+
+             # Use task-specific directory for isolation
+             local_dir_root = join(self._config.TASK_FILES_DIR, task_id)
+             await makedirs(local_dir_root, exist_ok=True)
+
+             logger.info(f"Starting download from S3: {uri}")
+
+             # Handle folder download (prefix)
+             if uri.endswith("/"):
+                 folder_name = object_key.rstrip("/").split("/")[-1]
+                 local_folder_path = join(local_dir_root, folder_name)
+
+                 # List objects with prefix
+                 # obstore.list returns an async iterator of ObjectMeta
+                 files_to_download = []
+
+                 # Note: obstore.list returns an async iterator.
+                 async for obj in obstore.list(store, prefix=object_key):
+                     key = obj.key
+
+                     if key.endswith("/"):
+                         continue
+
+                     # Calculate relative path inside the folder
+                     rel_path = key[len(object_key) :]
+                     local_file_path = join(local_folder_path, rel_path)
+
+                     await makedirs(dirname(local_file_path), exist_ok=True)
+                     files_to_download.append((key, local_file_path))
+
+                 async def _download_file(key: str, path: str) -> None:
+                     async with self._semaphore:
+                         result = await obstore.get(store, key)
+                         async with aio_open(path, "wb") as f:
+                             async for chunk in result.stream():
+                                 await f.write(chunk)
+
+                 # Execute downloads in parallel
+                 if files_to_download:
+                     await gather(*[_download_file(k, p) for k, p in files_to_download])
+
+                 logger.info(f"Successfully downloaded folder from S3: {uri} ({len(files_to_download)} files)")
+                 return local_folder_path
+
+             # Handle single file download
+             local_path = join(local_dir_root, basename(object_key))
+
+             result = await obstore.get(store, object_key)
+             async with aio_open(local_path, "wb") as f:
+                 async for chunk in result.stream():
+                     await f.write(chunk)
+
+             logger.info(f"Successfully downloaded file from S3: {uri} -> {local_path}")
+             return local_path
+
+         except Exception as e:
+             # Catching generic Exception because obstore might raise different errors.
+             logger.exception(f"Error during download of {uri}: {e}")
+             raise
 
      async def _upload_to_s3(self, local_path: str) -> str:
-         """Uploads a file to S3 and returns the S3 URI."""
+         """Uploads a file or a folder to S3 and returns the S3 URI."""
          bucket_name = self._config.S3_DEFAULT_BUCKET
-         object_key = os.path.basename(local_path)
+         store = self._get_store(bucket_name)
+
+         logger.info(f"Starting upload to S3 from local path: {local_path}")
+
+         try:
+             # Handle folder upload
+             if await isdir(local_path):
+                 folder_name = basename(local_path.rstrip("/"))
+                 s3_prefix = f"{folder_name}/"
+
+                 # Use to_thread to avoid blocking event loop during file walk
+                 def _get_files_to_upload():
+                     files_to_upload = []
+                     for root, _, files in walk(local_path):
+                         for file in files:
+                             f_path = join(root, file)
+                             rel = relpath(f_path, local_path)
+                             files_to_upload.append((f_path, f"{s3_prefix}{rel}"))
+                     return files_to_upload
+
+                 files_list = await to_thread(_get_files_to_upload)
+
+                 async def _upload_file(path: str, key: str) -> None:
+                     async with self._semaphore:
+                         # obstore.put accepts bytes or file-like objects.
+                         # Since we are in async, reading small files is fine.
+                         with open(path, "rb") as f:
+                             await obstore.put(store, key, f)
+
+                 if files_list:
+                     # Upload in parallel
+                     await gather(*[_upload_file(f, k) for f, k in files_list])
+
+                 s3_uri = f"s3://{bucket_name}/{s3_prefix}"
+                 logger.info(f"Successfully uploaded folder to S3: {local_path} -> {s3_uri} ({len(files_list)} files)")
+                 return s3_uri
+
+             # Handle single file upload
+             object_key = basename(local_path)
+             with open(local_path, "rb") as f:
+                 await obstore.put(store, object_key, f)
+
+             s3_uri = f"s3://{bucket_name}/{object_key}"
+             logger.info(f"Successfully uploaded file to S3: {local_path} -> {s3_uri}")
+             return s3_uri
 
-         await asyncio.to_thread(self._s3.upload_file, local_path, bucket_name, object_key)
-         return f"s3://{bucket_name}/{object_key}"
+         except Exception as e:
+             logger.exception(f"Error during upload of {local_path}: {e}")
+             raise
 
-     async def process_params(self, params: dict[str, Any]) -> dict[str, Any]:
+     async def process_params(self, params: dict[str, Any], task_id: str) -> dict[str, Any]:
          """Recursively searches for S3 URIs in params and downloads the files."""
          if not self._config.S3_ENDPOINT_URL:
              return params
 
          async def _process(item: Any) -> Any:
              if isinstance(item, str) and item.startswith("s3://"):
-                 return await self._process_s3_uri(item)
+                 return await self._process_s3_uri(item, task_id)
              if isinstance(item, dict):
                  return {k: await _process(v) for k, v in item.items()}
-             if isinstance(item, list):
-                 return [await _process(i) for i in item]
-             return item
+             return [await _process(i) for i in item] if isinstance(item, list) else item
 
-         return await _process(params)
+         return cast(dict[str, Any], await _process(params))
 
      async def process_result(self, result: dict[str, Any]) -> dict[str, Any]:
          """Recursively searches for local file paths in the result and uploads them to S3."""
@@ -64,12 +201,10 @@ class S3Manager:
              return result
 
          async def _process(item: Any) -> Any:
-             if isinstance(item, str) and os.path.exists(item) and item.startswith(self._config.WORKER_PAYLOAD_DIR):
-                 return await self._upload_to_s3(item)
+             if isinstance(item, str) and item.startswith(self._config.TASK_FILES_DIR):
+                 return await self._upload_to_s3(item) if await exists(item) else item
              if isinstance(item, dict):
                  return {k: await _process(v) for k, v in item.items()}
-             if isinstance(item, list):
-                 return [await _process(i) for i in item]
-             return item
+             return [await _process(i) for i in item] if isinstance(item, list) else item
 
-         return await _process(result)
+         return cast(dict[str, Any], await _process(result))
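To make the recursive rewriting above concrete, this is a hypothetical before/after view of what `process_params` and `process_result` produce, assuming the default `TASK_FILES_DIR` (`/tmp/payloads`) and `S3_DEFAULT_BUCKET` (`avtomatika-payloads`); the bucket names, task id, and file names are placeholders.

```python
# Hypothetical illustration of the S3 offloading round-trip (paths and IDs are made up).

# Incoming params: every "s3://..." string is downloaded and replaced with a local path
# under TASK_FILES_DIR/<task_id>/; a trailing "/" marks a folder prefix to download recursively.
params_in = {"image": "s3://my-bucket/photo.jpg", "dataset": "s3://my-bucket/frames/"}
params_out = {
    "image": "/tmp/payloads/task-123/photo.jpg",
    "dataset": "/tmp/payloads/task-123/frames",
}

# Outgoing result: any string that points inside TASK_FILES_DIR and exists on disk
# is uploaded to S3_DEFAULT_BUCKET and replaced with an "s3://" URI.
result_in = {"data": {"report": "/tmp/payloads/task-123/report.pdf"}}
result_out = {"data": {"report": "s3://avtomatika-payloads/report.pdf"}}
```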
avtomatika_worker/task_files.py ADDED
@@ -0,0 +1,97 @@
+ from contextlib import asynccontextmanager
+ from os.path import dirname, join
+ from typing import AsyncGenerator
+
+ from aiofiles import open as aiopen
+ from aiofiles.os import listdir, makedirs
+ from aiofiles.ospath import exists as aio_exists
+
+
+ class TaskFiles:
+     """
+     A helper class for managing task-specific files.
+     Provides asynchronous lazy directory creation and high-level file operations
+     within an isolated workspace for each task.
+     """
+
+     def __init__(self, task_dir: str):
+         """
+         Initializes TaskFiles with a specific task directory.
+         The directory is not created until needed.
+         """
+         self._task_dir = task_dir
+
+     async def get_root(self) -> str:
+         """
+         Asynchronously returns the root directory for the task.
+         Creates the directory on disk if it doesn't exist.
+         """
+         await makedirs(self._task_dir, exist_ok=True)
+         return self._task_dir
+
+     async def path_to(self, filename: str) -> str:
+         """
+         Asynchronously returns an absolute path for a file within the task directory.
+         Guarantees that the task root directory exists.
+         """
+         root = await self.get_root()
+         return join(root, filename)
+
+     @asynccontextmanager
+     async def open(self, filename: str, mode: str = "r") -> AsyncGenerator:
+         """
+         An asynchronous context manager to open a file within the task directory.
+         Automatically creates the task root and any necessary subdirectories.
+
+         Args:
+             filename: Name or relative path of the file.
+             mode: File opening mode (e.g., 'r', 'w', 'a', 'rb', 'wb').
+         """
+         path = await self.path_to(filename)
+         # Ensure directory for the file itself exists if filename contains subdirectories
+         file_dir = dirname(path)
+         if file_dir != self._task_dir:
+             await makedirs(file_dir, exist_ok=True)
+
+         async with aiopen(path, mode) as f:
+             yield f
+
+     async def read(self, filename: str, mode: str = "r") -> str | bytes:
+         """
+         Asynchronously reads the entire content of a file.
+
+         Args:
+             filename: Name of the file to read.
+             mode: Mode to open the file in (defaults to 'r').
+         """
+         async with self.open(filename, mode) as f:
+             return await f.read()
+
+     async def write(self, filename: str, data: str | bytes, mode: str = "w") -> None:
+         """
+         Asynchronously writes data to a file. Creates or overwrites the file by default.
+
+         Args:
+             filename: Name of the file to write.
+             data: Content to write (string or bytes).
+             mode: Mode to open the file in (defaults to 'w').
+         """
+         async with self.open(filename, mode) as f:
+             await f.write(data)
+
+     async def list(self) -> list[str]:
+         """
+         Asynchronously lists all file and directory names within the task root.
+         """
+         root = await self.get_root()
+         return await listdir(root)
+
+     async def exists(self, filename: str) -> bool:
+         """
+         Asynchronously checks if a specific file or directory exists in the task root.
+         """
+         path = join(self._task_dir, filename)
+         return await aio_exists(path)
+
+     def __repr__(self):
+         return f"<TaskFiles root='{self._task_dir}'>"
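A small, hypothetical standalone sketch of the class above, highlighting the lazy directory creation and the nested-path handling of `open()`; in normal use the SDK constructs and injects `TaskFiles` itself, and the directory path here is a placeholder.

```python
# Hypothetical direct use of TaskFiles outside a handler (the SDK normally injects it).
from asyncio import run

from avtomatika_worker.task_files import TaskFiles


async def demo() -> None:
    files = TaskFiles("/tmp/payloads/task-demo")  # nothing is created on disk yet

    # Writing "logs/run.txt" lazily creates both the task root and the "logs" subdirectory.
    await files.write("logs/run.txt", "started\n")
    async with files.open("logs/run.txt", "a") as f:
        await f.write("finished\n")

    print(await files.read("logs/run.txt"))  # "started\nfinished\n"
    print(await files.list())                # ["logs"]
    print(await files.exists("logs"))        # True


run(demo())
```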
avtomatika_worker/types.py CHANGED
@@ -1,8 +1,21 @@
- # Error codes for worker task results
- TRANSIENT_ERROR = "TRANSIENT_ERROR"
- PERMANENT_ERROR = "PERMANENT_ERROR"
- INVALID_INPUT_ERROR = "INVALID_INPUT_ERROR"
+ from .constants import (
+     ERROR_CODE_INVALID_INPUT as INVALID_INPUT_ERROR,
+ )
+ from .constants import (
+     ERROR_CODE_PERMANENT as PERMANENT_ERROR,
+ )
+ from .constants import (
+     ERROR_CODE_TRANSIENT as TRANSIENT_ERROR,
+ )
 
 
  class ParamValidationError(Exception):
      """Custom exception for parameter validation errors."""
+
+
+ __all__ = [
+     "INVALID_INPUT_ERROR",
+     "PERMANENT_ERROR",
+     "TRANSIENT_ERROR",
+     "ParamValidationError",
+ ]
avtomatika_worker/worker.py CHANGED
@@ -1,16 +1,26 @@
  from asyncio import CancelledError, Event, Task, create_task, gather, run, sleep
- from asyncio import TimeoutError as AsyncTimeoutError
  from dataclasses import is_dataclass
  from inspect import Parameter, signature
  from json import JSONDecodeError
  from logging import getLogger
+ from os.path import join
  from typing import Any, Callable
 
- from aiohttp import ClientError, ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType, web
+ from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType, web
 
+ from .client import OrchestratorClient
  from .config import WorkerConfig
+ from .constants import (
+     COMMAND_CANCEL_TASK,
+     ERROR_CODE_INVALID_INPUT,
+     ERROR_CODE_PERMANENT,
+     ERROR_CODE_TRANSIENT,
+     TASK_STATUS_CANCELLED,
+     TASK_STATUS_FAILURE,
+ )
  from .s3 import S3Manager
- from .types import INVALID_INPUT_ERROR, PERMANENT_ERROR, TRANSIENT_ERROR, ParamValidationError
+ from .task_files import TaskFiles
+ from .types import ParamValidationError
 
  try:
      from pydantic import BaseModel, ValidationError
@@ -43,7 +53,7 @@ class Worker:
          self._s3_manager = S3Manager(self._config)
          self._config.WORKER_TYPE = worker_type  # Allow overriding worker_type
          if max_concurrent_tasks is not None:
-             self._config.max_concurrent_tasks = max_concurrent_tasks
+             self._config.MAX_CONCURRENT_TASKS = max_concurrent_tasks
 
          self._task_type_limits = task_type_limits or {}
          self._task_handlers: dict[str, dict[str, Any]] = {}
@@ -57,10 +67,8 @@ class Worker:
          self._http_session = http_session
          self._session_is_managed_externally = http_session is not None
          self._ws_connection: ClientWebSocketResponse | None = None
-         # Removed: self._headers = {"X-Worker-Token": self._config.WORKER_TOKEN}
          self._shutdown_event = Event()
          self._registered_event = Event()
-         self._round_robin_index = 0
          self._debounce_task: Task | None = None
 
          # --- Weighted Round-Robin State ---
@@ -70,7 +78,28 @@ class Worker:
              o["current_weight"] = 0
              self._total_orchestrator_weight += o.get("weight", 1)
 
-     def _validate_config(self):
+         self._clients: list[tuple[dict[str, Any], OrchestratorClient]] = []
+         if self._http_session:
+             self._init_clients()
+
+     def _init_clients(self):
+         """Initializes OrchestratorClient instances for each configured orchestrator."""
+         if not self._http_session:
+             return
+         self._clients = [
+             (
+                 o,
+                 OrchestratorClient(
+                     session=self._http_session,
+                     base_url=o["url"],
+                     worker_id=self._config.WORKER_ID,
+                     token=o.get("token", self._config.WORKER_TOKEN),
+                 ),
+             )
+             for o in self._config.ORCHESTRATORS
+         ]
+
+     def _validate_task_types(self):
          """Checks for unused task type limits and warns the user."""
          registered_task_types = {
              handler_data["type"] for handler_data in self._task_handlers.values() if handler_data["type"]
@@ -138,32 +167,31 @@ class Worker:
          status = "idle" if supported_tasks else "busy"
          return {"status": status, "supported_tasks": supported_tasks}
 
-     def _get_headers(self, orchestrator: dict[str, Any]) -> dict[str, str]:
-         """Builds authentication headers for a specific orchestrator."""
-         token = orchestrator.get("token", self._config.WORKER_TOKEN)
-         return {"X-Worker-Token": token}
-
-     def _get_next_orchestrator(self) -> dict[str, Any] | None:
+     def _get_next_client(self) -> OrchestratorClient | None:
          """
-         Selects the next orchestrator using a smooth weighted round-robin algorithm.
+         Selects the next orchestrator client using a smooth weighted round-robin algorithm.
          """
-         if not self._config.ORCHESTRATORS:
+         if not self._clients:
              return None
 
          # The orchestrator with the highest current_weight is selected.
-         selected_orchestrator = None
+         selected_client = None
          highest_weight = -1
 
-         for o in self._config.ORCHESTRATORS:
+         for o, client in self._clients:
              o["current_weight"] += o["weight"]
              if o["current_weight"] > highest_weight:
                  highest_weight = o["current_weight"]
-                 selected_orchestrator = o
+                 selected_client = client
 
-         if selected_orchestrator:
-             selected_orchestrator["current_weight"] -= self._total_orchestrator_weight
+         if selected_client:
+             # Find the config for the selected client to decrement its weight
+             for o, client in self._clients:
+                 if client == selected_client:
+                     o["current_weight"] -= self._total_orchestrator_weight
+                     break
 
-         return selected_orchestrator
+         return selected_client
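The selection logic above is the smooth weighted round-robin scheme; as a reading aid, here is a small self-contained sketch of the same idea with hypothetical weights, independent of the Worker class.

```python
# Standalone illustration of smooth weighted round-robin (weights are hypothetical).
orchestrators = [
    {"url": "http://a:8080", "weight": 3, "current_weight": 0},
    {"url": "http://b:8080", "weight": 1, "current_weight": 0},
]
total_weight = sum(o["weight"] for o in orchestrators)


def pick() -> dict:
    # Every candidate gains its weight, the current leader is chosen,
    # and the leader pays back the total weight so the others catch up over time.
    for o in orchestrators:
        o["current_weight"] += o["weight"]
    selected = max(orchestrators, key=lambda o: o["current_weight"])
    selected["current_weight"] -= total_weight
    return selected


print([pick()["url"][7] for _ in range(8)])  # a, a, b, a, a, a, b, a
```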
 
      async def _debounced_heartbeat_sender(self):
          """Waits for the debounce delay then sends a heartbeat."""
@@ -178,34 +206,27 @@
          # Schedule the new debounced call.
          self._debounce_task = create_task(self._debounced_heartbeat_sender())
 
-     async def _poll_for_tasks(self, orchestrator: dict[str, Any]):
+     async def _poll_for_tasks(self, client: OrchestratorClient):
          """Polls a specific Orchestrator for new tasks."""
-         url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}/tasks/next"
-         try:
-             if not self._http_session:
-                 return
-             timeout = ClientTimeout(total=self._config.TASK_POLL_TIMEOUT + 5)
-             headers = self._get_headers(orchestrator)
-             async with self._http_session.get(url, headers=headers, timeout=timeout) as resp:
-                 if resp.status == 200:
-                     task_data = await resp.json()
-                     task_data["orchestrator"] = orchestrator
-
-                     self._current_load += 1
-                     task_handler_info = self._task_handlers.get(task_data["type"])
-                     if task_handler_info:
-                         task_type_for_limit = task_handler_info.get("type")
-                         if task_type_for_limit:
-                             self._current_load_by_type[task_type_for_limit] += 1
-                     self._schedule_heartbeat_debounce()
-
-                     task = create_task(self._process_task(task_data))
-                     self._active_tasks[task_data["task_id"]] = task
-                 elif resp.status != 204:
-                     await sleep(self._config.TASK_POLL_ERROR_DELAY)
-         except (AsyncTimeoutError, ClientError) as e:
-             logger.error(f"Error polling for tasks: {e}")
-             await sleep(self._config.TASK_POLL_ERROR_DELAY)
+         task_data = await client.poll_task(timeout=self._config.TASK_POLL_TIMEOUT)
+         if task_data:
+             task_data["client"] = client
+
+             self._current_load += 1
+             if (task_handler_info := self._task_handlers.get(task_data["type"])) and (
+                 task_type_for_limit := task_handler_info.get("type")
+             ):
+                 self._current_load_by_type[task_type_for_limit] += 1
+             self._schedule_heartbeat_debounce()
+
+             task = create_task(self._process_task(task_data))
+             self._active_tasks[task_data["task_id"]] = task
+         else:
+             # If no task but it was a 204 or error, the client already handled/logged it.
+             # We might want a short sleep here if it was an error, but client.poll_task
+             # doesn't distinguish between 204 and error currently.
+             # However, the previous logic only slept on status != 204.
+             pass
 
      async def _start_polling(self):
          """The main loop for polling tasks."""
@@ -217,19 +238,19 @@
                  continue
 
              if self._config.MULTI_ORCHESTRATOR_MODE == "ROUND_ROBIN":
-                 orchestrator = self._get_next_orchestrator()
-                 if orchestrator:
-                     await self._poll_for_tasks(orchestrator)
+                 if client := self._get_next_client():
+                     await self._poll_for_tasks(client)
              else:
-                 for orchestrator in self._config.ORCHESTRATORS:
+                 for _, client in self._clients:
                      if self._get_current_state()["status"] == "busy":
                          break
-                     await self._poll_for_tasks(orchestrator)
+                     await self._poll_for_tasks(client)
 
              if self._current_load == 0:
                  await sleep(self._config.IDLE_POLL_DELAY)
 
-     def _prepare_task_params(self, handler: Callable, params: dict[str, Any]) -> Any:
+     @staticmethod
+     def _prepare_task_params(handler: Callable, params: dict[str, Any]) -> Any:
          """
          Inspects the handler's signature to validate and instantiate params.
          Supports dict, dataclasses, and optional pydantic models.
@@ -261,8 +282,7 @@ class Worker:
                  if f.default is Parameter.empty and f.default_factory is Parameter.empty
              ]
 
-             missing_fields = [f for f in required_fields if f not in filtered_params]
-             if missing_fields:
+             if missing_fields := [f for f in required_fields if f not in filtered_params]:
                  raise ParamValidationError(f"Missing required fields for dataclass: {', '.join(missing_fields)}")
 
              return params_annotation(**filtered_params)
@@ -272,10 +292,24 @@
 
          return params
 
+     def _prepare_dependencies(self, handler: Callable, task_id: str) -> dict[str, Any]:
+         """Injects dependencies based on type hints."""
+         deps = {}
+         task_dir = join(self._config.TASK_FILES_DIR, task_id)
+         # Always create the object, but directory is lazy
+         task_files = TaskFiles(task_dir)
+
+         sig = signature(handler)
+         for name, param in sig.parameters.items():
+             if param.annotation is TaskFiles:
+                 deps[name] = task_files
+
+         return deps
+
      async def _process_task(self, task_data: dict[str, Any]):
          """Executes the task logic."""
          task_id, job_id, task_name = task_data["task_id"], task_data["job_id"], task_data["type"]
-         params, orchestrator = task_data.get("params", {}), task_data["orchestrator"]
+         params, client = task_data.get("params", {}), task_data["client"]
 
          result: dict[str, Any] = {}
          handler_data = self._task_handlers.get(task_name)
@@ -287,14 +321,17 @@
              if not handler_data:
                  message = f"Unsupported task: {task_name}"
                  logger.warning(message)
-                 result = {"status": "failure", "error": {"code": PERMANENT_ERROR, "message": message}}
+                 result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_PERMANENT, "message": message}}
                  payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await self._send_result(payload, orchestrator)
+                 await client.send_result(
+                     payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+                 )
                  result_sent = True  # Mark result as sent
                  return
 
-             params = await self._s3_manager.process_params(params)
+             params = await self._s3_manager.process_params(params, task_id)
              validated_params = self._prepare_task_params(handler_data["func"], params)
+             deps = self._prepare_dependencies(handler_data["func"], task_id)
 
              result = await handler_data["func"](
                  validated_params,
@@ -304,23 +341,29 @@
                  send_progress=self.send_progress,
                  add_to_hot_cache=self.add_to_hot_cache,
                  remove_from_hot_cache=self.remove_from_hot_cache,
+                 **deps,
              )
              result = await self._s3_manager.process_result(result)
          except ParamValidationError as e:
              logger.error(f"Task {task_id} failed validation: {e}")
-             result = {"status": "failure", "error": {"code": INVALID_INPUT_ERROR, "message": str(e)}}
+             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_INVALID_INPUT, "message": str(e)}}
          except CancelledError:
              logger.info(f"Task {task_id} was cancelled.")
-             result = {"status": "cancelled"}
+             result = {"status": TASK_STATUS_CANCELLED}
              # We must re-raise the exception to be handled by the outer gather
              raise
          except Exception as e:
              logger.exception(f"An unexpected error occurred while processing task {task_id}:")
-             result = {"status": "failure", "error": {"code": TRANSIENT_ERROR, "message": str(e)}}
+             result = {"status": TASK_STATUS_FAILURE, "error": {"code": ERROR_CODE_TRANSIENT, "message": str(e)}}
          finally:
+             # Cleanup task workspace
+             await self._s3_manager.cleanup(task_id)
+
              if not result_sent:  # Only send if not already sent
                  payload = {"job_id": job_id, "task_id": task_id, "worker_id": self._config.WORKER_ID, "result": result}
-                 await self._send_result(payload, orchestrator)
+                 await client.send_result(
+                     payload, self._config.RESULT_MAX_RETRIES, self._config.RESULT_RETRY_INITIAL_DELAY
+                 )
              self._active_tasks.pop(task_id, None)
 
              self._current_load -= 1
@@ -328,21 +371,6 @@
                  self._current_load_by_type[task_type_for_limit] -= 1
              self._schedule_heartbeat_debounce()
 
-     async def _send_result(self, payload: dict[str, Any], orchestrator: dict[str, Any]):
-         """Sends the result to a specific orchestrator."""
-         url = f"{orchestrator['url']}/_worker/tasks/result"
-         delay = self._config.RESULT_RETRY_INITIAL_DELAY
-         headers = self._get_headers(orchestrator)
-         for i in range(self._config.RESULT_MAX_RETRIES):
-             try:
-                 if self._http_session and not self._http_session.closed:
-                     async with self._http_session.post(url, json=payload, headers=headers) as resp:
-                         if resp.status == 200:
-                             return
-             except ClientError as e:
-                 logger.error(f"Error sending result: {e}")
-             await sleep(delay * (2**i))
-
      async def _manage_orchestrator_communications(self):
          """Registers the worker and sends heartbeats."""
          await self._register_with_all_orchestrators()
@@ -369,17 +397,7 @@
              "ip_address": self._config.IP_ADDRESS,
              "resources": self._config.RESOURCES,
          }
-         for orchestrator in self._config.ORCHESTRATORS:
-             url = f"{orchestrator['url']}/_worker/workers/register"
-             try:
-                 if self._http_session:
-                     async with self._http_session.post(
-                         url, json=payload, headers=self._get_headers(orchestrator)
-                     ) as resp:
-                         if resp.status >= 400:
-                             logger.error(f"Error registering with {orchestrator['url']}: {resp.status}")
-             except ClientError as e:
-                 logger.error(f"Error registering with orchestrator {orchestrator['url']}: {e}")
+         await gather(*[client.register(payload) for _, client in self._clients])
 
      async def _send_heartbeats_to_all(self):
          """Sends heartbeat messages to all orchestrators."""
@@ -399,24 +417,15 @@
          if hot_skills:
              payload["hot_skills"] = hot_skills
 
-         async def _send_single(orchestrator: dict[str, Any]):
-             url = f"{orchestrator['url']}/_worker/workers/{self._config.WORKER_ID}"
-             headers = self._get_headers(orchestrator)
-             try:
-                 if self._http_session and not self._http_session.closed:
-                     async with self._http_session.patch(url, json=payload, headers=headers) as resp:
-                         if resp.status >= 400:
-                             logger.warning(f"Heartbeat to {orchestrator['url']} failed with status: {resp.status}")
-             except ClientError as e:
-                 logger.error(f"Error sending heartbeat to orchestrator {orchestrator['url']}: {e}")
-
-         await gather(*[_send_single(o) for o in self._config.ORCHESTRATORS])
+         await gather(*[client.send_heartbeat(payload) for _, client in self._clients])
 
      async def main(self):
          """The main asynchronous function."""
-         self._validate_config()  # Validate config now that all tasks are registered
+         self._config.validate()
+         self._validate_task_types()  # Validate config now that all tasks are registered
          if not self._http_session:
              self._http_session = ClientSession()
+             self._init_clients()
 
          comm_task = create_task(self._manage_orchestrator_communications())
 
@@ -442,7 +451,11 @@
 
      async def _run_health_check_server(self):
          app = web.Application()
-         app.router.add_get("/health", lambda r: web.Response(text="OK"))
+
+         async def health_handler(_):
+             return web.Response(text="OK")
+
+         app.router.add_get("/health", health_handler)
          runner = web.AppRunner(app)
          await runner.setup()
          site = web.TCPSite(runner, "0.0.0.0", self._config.WORKER_PORT)
@@ -459,25 +472,20 @@
          except KeyboardInterrupt:
              self._shutdown_event.set()
 
-     # WebSocket methods omitted for brevity as they are not relevant to the changes
      async def _start_websocket_manager(self):
          """Manages the WebSocket connection to the orchestrator."""
          while not self._shutdown_event.is_set():
-             for orchestrator in self._config.ORCHESTRATORS:
-                 ws_url = orchestrator["url"].replace("http", "ws", 1) + "/_worker/ws"
+             # In multi-orchestrator mode, we currently only connect to the first one available
+             for _, client in self._clients:
                  try:
-                     if self._http_session:
-                         async with self._http_session.ws_connect(ws_url, headers=self._get_headers(orchestrator)) as ws:
-                             self._ws_connection = ws
-                             logger.info(f"WebSocket connection established to {ws_url}")
-                             await self._listen_for_commands()
-                 except (ClientError, AsyncTimeoutError) as e:
-                     logger.warning(f"WebSocket connection to {ws_url} failed: {e}")
+                     ws = await client.connect_websocket()
+                     if ws:
+                         self._ws_connection = ws
+                         await self._listen_for_commands()
                  finally:
                      self._ws_connection = None
-                     logger.info(f"WebSocket connection to {ws_url} closed.")
                      await sleep(5)  # Reconnection delay
-             if not self._config.ORCHESTRATORS:
+             if not self._clients:
                  await sleep(5)
 
      async def _listen_for_commands(self):
@@ -490,7 +498,7 @@
              if msg.type == WSMsgType.TEXT:
                  try:
                      command = msg.json()
-                     if command.get("type") == "cancel_task":
+                     if command.get("type") == COMMAND_CANCEL_TASK:
                          task_id = command.get("task_id")
                          if task_id in self._active_tasks:
                              self._active_tasks[task_id].cancel()
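For reference, the command handling above implies a WebSocket message of roughly this shape; only the `type` and `task_id` fields are read by the code, and the values below are placeholders.

```python
# Hypothetical cancel command as consumed by _listen_for_commands (values are placeholders).
cancel_command = {
    "type": "cancel_task",  # compared against COMMAND_CANCEL_TASK
    "task_id": "task-123",  # looked up in the worker's active task table
}
```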
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: avtomatika-worker
- Version: 1.0b1
+ Version: 1.0b3
  Summary: Worker SDK for the Avtomatika orchestrator.
- Project-URL: Homepage, https://github.com/avtomatila-ai/avtomatika-worker
- Project-URL: Bug Tracker, https://github.com/avtomatila-ai/avtomatika-worker/issues
+ Project-URL: Homepage, https://github.com/avtomatika-ai/avtomatika-worker
+ Project-URL: Bug Tracker, https://github.com/avtomatika-ai/avtomatika-worker/issues
  Classifier: Development Status :: 4 - Beta
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: MIT License
@@ -13,22 +13,22 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: aiohttp~=3.13.2
  Requires-Dist: python-json-logger~=4.0.0
- Requires-Dist: aioboto3~=13.0.0
+ Requires-Dist: obstore>=0.1
+ Requires-Dist: aiofiles~=25.1.0
  Provides-Extra: test
  Requires-Dist: pytest; extra == "test"
  Requires-Dist: pytest-asyncio; extra == "test"
  Requires-Dist: aioresponses; extra == "test"
  Requires-Dist: pytest-mock; extra == "test"
  Requires-Dist: pydantic; extra == "test"
- Requires-Dist: moto[server]; extra == "test"
- Requires-Dist: aiofiles; extra == "test"
+ Requires-Dist: types-aiofiles; extra == "test"
  Provides-Extra: pydantic
  Requires-Dist: pydantic; extra == "pydantic"
  Dynamic: license-file
 
  # Avtomatika Worker SDK
 
- This is an SDK for creating workers compatible with the **Avtomatika** orchestrator. The SDK handles all the complexity of interacting with the orchestrator, allowing you to focus on writing your business logic.
+ This is the official SDK for creating workers compatible with the **[Avtomatika Orchestrator](https://github.com/avtomatika-ai/avtomatika)**. It implements the **[RCA Protocol](https://github.com/avtomatika-ai/rca)**, handling all communication complexity (polling, heartbeats, S3 offloading) so you can focus on writing your business logic.
 
  ## Installation
 
@@ -434,18 +434,92 @@ The `ORCHESTRATORS_CONFIG` variable must contain a JSON string. Each object in t
 
 
 
- ### 5. Handling Large Files (S3 Payload Offloading)
 
- The SDK supports working with large files "out of the box" via S3-compatible storage.
 
- - **Automatic Download**: If a value in `params` is a URI of the form `s3://...`, the SDK will automatically download the file to the local disk and replace the URI in `params` with the local path.
- - **Automatic Upload**: If your handler returns a local file path in `data` (located within the `WORKER_PAYLOAD_DIR` directory), the SDK will automatically upload this file to S3 and replace the path with an `s3://` URI in the final result.
+ ### 5. File System Helper (TaskFiles)
 
- This functionality is transparent to your code and only requires configuring environment variables for S3 access.
+ To simplify working with temporary files and paths, the SDK provides a `TaskFiles` helper class. It automatically manages directory creation within the isolated task folder and provides an asynchronous interface for file operations. Just add an argument typed as `TaskFiles` to your handler:
 
- ### 6. WebSocket Support
+ ```python
+ from avtomatika_worker import Worker, TaskFiles
+
+ @worker.task("generate_report")
+ async def generate_report(params: dict, files: TaskFiles, **kwargs):
+     # 1. Easy read/write
+     await files.write("data.json", '{"status": "ok"}')
+     content = await files.read("data.json")
+
+     # 2. Get path (directory is created automatically)
+     output_path = await files.path_to("report.pdf")
+
+     # 3. Check and list files
+     if await files.exists("input.jpg"):
+         file_list = await files.list()
+
+     return {"data": {"report": output_path}}
+ ```
+
+ **Available Methods (all asynchronous):**
+ - `await path_to(name)` — returns the full path to a file (ensures the task directory exists).
+ - `await read(name, mode='r')` — reads the entire file.
+ - `await write(name, data, mode='w')` — writes data to a file.
+ - `await list()` — lists filenames in the task directory.
+ - `await exists(name)` — checks if a file exists.
+ - `async with open(name, mode)` — async context manager for advanced usage.
+
+ > **Note: Automatic Cleanup**
+ >
+ > The SDK automatically deletes the entire task directory (including everything created via `TaskFiles`) immediately after the task completes and the result is sent.
+
+ ### 6. Handling Large Files (S3 Payload Offloading)
+
+ The SDK supports working with large files "out of the box" via S3-compatible storage, using the high-performance **`obstore`** library (Rust-based).
+
+ - **Automatic Download**: If a value in `params` is a URI of the form `s3://...`, the SDK will automatically download the file to the local disk and replace the URI in `params` with the local path. **If the URI ends with `/` (e.g., `s3://bucket/data/`), the SDK treats it as a folder prefix and recursively downloads all matching objects into a local directory.**
+ - **Automatic Upload**: If your handler returns a local file path in `data` (located within the `TASK_FILES_DIR` directory), the SDK will automatically upload this file to S3 and replace the path with an `s3://` URI in the final result. **If the path is a directory, the SDK recursively uploads all files within it.**
+
+ This functionality is transparent to your code.
+
+ #### S3 Example
+
+ Suppose the orchestrator sends a task with `{"input_image": "s3://my-bucket/photo.jpg"}`:
+
+ ```python
+ import os
+ from avtomatika_worker import Worker, TaskFiles
+
+ worker = Worker(worker_type="image-worker")
+
+ @worker.task("process_image")
+ async def handle_image(params: dict, files: TaskFiles, **kwargs):
+     # SDK has already downloaded the file.
+     # 'input_image' now contains a local path like '/tmp/payloads/task-id/photo.jpg'
+     local_input = params["input_image"]
+     local_output = await files.path_to("processed.png")
+
+     # Your logic here (using local files)
+     # ... image processing ...
+
+     # Return the local path of the result.
+     # The SDK will upload it back to S3 automatically.
+     return {
+         "status": "success",
+         "data": {
+             "output_image": local_output
+         }
+     }
+ ```
+
+ This only requires configuring environment variables for S3 access (see Full Configuration Reference).
+
+ > **Important: S3 Consistency**
+ >
+ > The SDK **does not validate** that the Worker and Orchestrator share the same storage backend. You must ensure that:
+ > 1. The Worker can reach the `S3_ENDPOINT_URL` used by the Orchestrator.
+ > 2. The Worker's credentials allow reading from the buckets referenced in the incoming `s3://` URIs.
+ > 3. The Worker's credentials allow writing to the `S3_DEFAULT_BUCKET`.
 
- If enabled, the SDK establishes a persistent WebSocket connection with the orchestrator to receive real-time commands, such as canceling an ongoing task.
+ ### 7. WebSocket Support
 
  ## Advanced Features
 
@@ -522,11 +596,12 @@ The worker is fully configured via environment variables.
  | `TASK_POLL_TIMEOUT` | The timeout in seconds for polling for new tasks. | `30` |
  | `TASK_POLL_ERROR_DELAY` | The delay in seconds before retrying after a polling error. | `5.0` |
  | `IDLE_POLL_DELAY` | The delay in seconds between polls when the worker is idle. | `0.01` |
- | `WORKER_PAYLOAD_DIR` | The directory for temporarily storing files when working with S3. | `/tmp/payloads` |
+ | `TASK_FILES_DIR` | The directory for temporarily storing files when working with S3. | `/tmp/payloads` |
  | `S3_ENDPOINT_URL` | The URL of the S3-compatible storage. | - |
  | `S3_ACCESS_KEY` | The access key for S3. | - |
  | `S3_SECRET_KEY` | The secret key for S3. | - |
  | `S3_DEFAULT_BUCKET` | The default bucket name for uploading results. | `avtomatika-payloads` |
+ | `S3_REGION` | The region for S3 storage (required by some providers). | `us-east-1` |
 
  ## Development
 
@@ -0,0 +1,14 @@
+ avtomatika_worker/__init__.py,sha256=y_s5KlsgFu7guemZfjLVQ3Jzq7DyLG168-maVGwWRC4,334
+ avtomatika_worker/client.py,sha256=mkvwrMY8tAaZN_lwMSxWHmAoWsDemD-WiKSeH5fM6GI,4173
+ avtomatika_worker/config.py,sha256=NaAhufpwyG6CsHW-cXmqR3MfGp_5SdDZ_vEhmmV8G3g,5819
+ avtomatika_worker/constants.py,sha256=DfGR_YkW9rbioCorKpNGfZ0i_0iGgMq2swyJhVl9nNA,669
+ avtomatika_worker/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ avtomatika_worker/s3.py,sha256=sAuXBp__XTVhyNwK9gsxy1Jm_udJM6ypssGTZ00pa6U,8847
+ avtomatika_worker/task_files.py,sha256=ucjBuI78UmtMvfucTzDTNJ1g0KJaRIwyshRNTipIZSU,3351
+ avtomatika_worker/types.py,sha256=dSNsHgqV6hZhOt4eUK2PDWB6lrrwCA5_T_iIBI_wTZ0,442
+ avtomatika_worker/worker.py,sha256=XSRfLO-W0J6WG128Iu-rL_w3-PqmsWQMUElVLi3Z1gk,21904
+ avtomatika_worker-1.0b3.dist-info/licenses/LICENSE,sha256=tqCjw9Y1vbU-hLcWi__7wQstLbt2T1XWPdbQYqCxuWY,1072
+ avtomatika_worker-1.0b3.dist-info/METADATA,sha256=yiEtJuMv5WHYHfScna7cF5QAvAUhMCJdUDENHvrMRFY,29601
+ avtomatika_worker-1.0b3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ avtomatika_worker-1.0b3.dist-info/top_level.txt,sha256=d3b5BUeUrHM1Cn-cbStz-hpucikEBlPOvtcmQ_j3qAs,18
+ avtomatika_worker-1.0b3.dist-info/RECORD,,
@@ -1,10 +0,0 @@
- avtomatika_worker/__init__.py,sha256=j0up34aVy7xyI67xg04TVbXSSSKGdO49vsBKhtH_D0M,287
- avtomatika_worker/config.py,sha256=k1p1Njh7CVWU1PaJDqA6jsf9GXoH0iM4og_o0V0gPHI,5260
- avtomatika_worker/s3.py,sha256=7aC7k90kLUEwBVLWvvcLapaF6gIph_7b3XGXYdtCiNU,2895
- avtomatika_worker/types.py,sha256=MqXaX0NUatYDna3GgBWj73-WOT1EfaX1ei4i7eUsZR0,255
- avtomatika_worker/worker.py,sha256=afePnSQp_aFAb6qDk6HAuI81PlwIZiK52fI6QY6aQQE,23234
- avtomatika_worker-1.0b1.dist-info/licenses/LICENSE,sha256=tqCjw9Y1vbU-hLcWi__7wQstLbt2T1XWPdbQYqCxuWY,1072
- avtomatika_worker-1.0b1.dist-info/METADATA,sha256=IMgLX7vO1jq_iQtm4TSU-XEf3hUxoViMfMzQNHfbjgg,26326
- avtomatika_worker-1.0b1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- avtomatika_worker-1.0b1.dist-info/top_level.txt,sha256=d3b5BUeUrHM1Cn-cbStz-hpucikEBlPOvtcmQ_j3qAs,18
- avtomatika_worker-1.0b1.dist-info/RECORD,,