avtomatika-worker 1.0a2-py3-none-any.whl → 1.0b2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,9 +2,10 @@
 
 from importlib.metadata import PackageNotFoundError, version
 
+from .task_files import TaskFiles
 from .worker import Worker
 
-__all__ = ["Worker"]
+__all__ = ["Worker", "TaskFiles"]
 
 try:
     __version__ = version("avtomatika-worker")
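
The new top-level export lets callers import the files helper next to `Worker`. A minimal sketch, assuming the import package is named `avtomatika_worker` to match the distribution name:

# Hypothetical usage of the new export; the import package name is assumed.
from avtomatika_worker import TaskFiles

files = TaskFiles("/tmp/payloads/example-task")  # helper is now part of the public API
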
@@ -12,94 +12,104 @@ class WorkerConfig:
 
     def __init__(self):
         # --- Basic worker information ---
-        self.worker_id: str = getenv("WORKER_ID", f"worker-{uuid4()}")
-        self.worker_type: str = getenv("WORKER_TYPE", "generic-cpu-worker")
-        self.worker_port: int = int(getenv("WORKER_PORT", "8083"))
-        self.hostname: str = gethostname()
+        self.WORKER_ID: str = getenv("WORKER_ID", f"worker-{uuid4()}")
+        self.WORKER_TYPE: str = getenv("WORKER_TYPE", "generic-cpu-worker")
+        self.WORKER_PORT: int = int(getenv("WORKER_PORT", "8083"))
+        self.HOSTNAME: str = gethostname()
         try:
-            self.ip_address: str = gethostbyname(self.hostname)
+            self.IP_ADDRESS: str = gethostbyname(self.HOSTNAME)
         except gaierror:
-            self.ip_address: str = "127.0.0.1"
+            self.IP_ADDRESS: str = "127.0.0.1"
 
         # --- Orchestrator settings ---
-        self.orchestrators: list[dict[str, Any]] = self._get_orchestrators_config()
+        self.ORCHESTRATORS: list[dict[str, Any]] = self._get_orchestrators_config()
 
         # --- Security ---
-        self.worker_token: str = getenv(
+        self.WORKER_TOKEN: str = getenv(
             "WORKER_INDIVIDUAL_TOKEN",
             getenv("WORKER_TOKEN", "your-secret-worker-token"),
         )
 
         # --- Resources and performance ---
-        self.cost_per_second: float = float(getenv("WORKER_COST_PER_SECOND", "0.01"))
-        self.max_concurrent_tasks: int = int(getenv("MAX_CONCURRENT_TASKS", "10"))
-        self.resources: dict[str, Any] = {
+        self.COST_PER_SKILL: dict[str, float] = self._load_json_from_env("COST_PER_SKILL", default={})
+        self.MAX_CONCURRENT_TASKS: int = int(getenv("MAX_CONCURRENT_TASKS", "10"))
+        self.RESOURCES: dict[str, Any] = {
             "cpu_cores": int(getenv("CPU_CORES", "4")),
             "gpu_info": self._get_gpu_info(),
         }
 
         # --- Installed software and models (read as JSON strings) ---
-        self.installed_software: dict[str, str] = self._load_json_from_env(
+        self.INSTALLED_SOFTWARE: dict[str, str] = self._load_json_from_env(
             "INSTALLED_SOFTWARE",
             default={"python": "3.9"},
         )
-        self.installed_models: list[dict[str, str]] = self._load_json_from_env(
+        self.INSTALLED_MODELS: list[dict[str, str]] = self._load_json_from_env(
             "INSTALLED_MODELS",
             default=[],
         )
 
+        # --- S3 Settings for payload offloading ---
+        self.TASK_FILES_DIR: str = getenv("TASK_FILES_DIR", "/tmp/payloads")
+        self.S3_ENDPOINT_URL: str | None = getenv("S3_ENDPOINT_URL")
+        self.S3_ACCESS_KEY: str | None = getenv("S3_ACCESS_KEY")
+        self.S3_SECRET_KEY: str | None = getenv("S3_SECRET_KEY")
+        self.S3_DEFAULT_BUCKET: str = getenv("S3_DEFAULT_BUCKET", "avtomatika-payloads")
+
         # --- Tuning parameters ---
-        self.heartbeat_interval: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
-        self.result_max_retries: int = int(getenv("RESULT_MAX_RETRIES", "5"))
-        self.result_retry_initial_delay: float = float(
+        self.HEARTBEAT_INTERVAL: float = float(getenv("HEARTBEAT_INTERVAL", "15"))
+        self.RESULT_MAX_RETRIES: int = int(getenv("RESULT_MAX_RETRIES", "5"))
+        self.RESULT_RETRY_INITIAL_DELAY: float = float(
             getenv("RESULT_RETRY_INITIAL_DELAY", "1.0"),
         )
-        self.heartbeat_debounce_delay: float = float(getenv("WORKER_HEARTBEAT_DEBOUNCE_DELAY", 0.1))
-        self.task_poll_timeout: float = float(getenv("TASK_POLL_TIMEOUT", "30"))
-        self.task_poll_error_delay: float = float(
+        self.HEARTBEAT_DEBOUNCE_DELAY: float = float(getenv("WORKER_HEARTBEAT_DEBOUNCE_DELAY", 0.1))
+        self.TASK_POLL_TIMEOUT: float = float(getenv("TASK_POLL_TIMEOUT", "30"))
+        self.TASK_POLL_ERROR_DELAY: float = float(
            getenv("TASK_POLL_ERROR_DELAY", "5.0"),
         )
-        self.idle_poll_delay: float = float(getenv("IDLE_POLL_DELAY", "0.01"))
-        self.enable_websockets: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
-        self.multi_orchestrator_mode: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
+        self.IDLE_POLL_DELAY: float = float(getenv("IDLE_POLL_DELAY", "0.01"))
+        self.ENABLE_WEBSOCKETS: bool = getenv("WORKER_ENABLE_WEBSOCKETS", "false").lower() == "true"
+        self.MULTI_ORCHESTRATOR_MODE: str = getenv("MULTI_ORCHESTRATOR_MODE", "FAILOVER")
 
     def _get_orchestrators_config(self) -> list[dict[str, Any]]:
         """
         Loads orchestrator configuration from the ORCHESTRATORS_CONFIG environment variable.
         For backward compatibility, if it is not set, it uses ORCHESTRATOR_URL.
         """
-        orchestrators_json = getenv("ORCHESTRATORS_CONFIG")
-        if orchestrators_json:
+        if orchestrators_json := getenv("ORCHESTRATORS_CONFIG"):
             try:
                 orchestrators = loads(orchestrators_json)
+                if getenv("ORCHESTRATOR_URL"):
+                    print("Info: Both ORCHESTRATORS_CONFIG and ORCHESTRATOR_URL are set. Using ORCHESTRATORS_CONFIG.")
                 for o in orchestrators:
                     if "priority" not in o:
                         o["priority"] = 10
+                    if "weight" not in o:
+                        o["weight"] = 1
                 orchestrators.sort(key=lambda x: (x.get("priority", 10), x.get("url")))
                 return orchestrators
             except JSONDecodeError:
                 print("Warning: Could not decode JSON from ORCHESTRATORS_CONFIG. Falling back to default.")
 
         orchestrator_url = getenv("ORCHESTRATOR_URL", "http://localhost:8080")
-        return [{"url": orchestrator_url, "priority": 1}]
+        return [{"url": orchestrator_url, "priority": 1, "weight": 1}]
 
-    def _get_gpu_info(self) -> dict[str, Any] | None:
+    @staticmethod
+    def _get_gpu_info() -> dict[str, Any] | None:
         """Collects GPU information from environment variables.
         Returns None if GPU is not configured.
         """
-        gpu_model = getenv("GPU_MODEL")
-        if not gpu_model:
+        if gpu_model := getenv("GPU_MODEL"):
+            return {
+                "model": gpu_model,
+                "vram_gb": int(getenv("GPU_VRAM_GB", "0")),
+            }
+        else:
             return None
 
-        return {
-            "model": gpu_model,
-            "vram_gb": int(getenv("GPU_VRAM_GB", "0")),
-        }
-
-    def _load_json_from_env(self, key: str, default: Any) -> Any:
+    @staticmethod
+    def _load_json_from_env(key: str, default: Any) -> Any:
         """Safely loads a JSON string from an environment variable."""
-        value = getenv(key)
-        if value:
+        if value := getenv(key):
             try:
                 return loads(value)
             except JSONDecodeError:
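
Two behavioral changes ride along with the upper-cased attribute names: per-skill cost is now read as a JSON map (replacing the flat WORKER_COST_PER_SECOND value), and each orchestrator entry is normalized to carry both "priority" and "weight". A minimal sketch of how the new variables might be populated, assuming the config module is importable as avtomatika_worker.config (all values are illustrative):

from os import environ

from avtomatika_worker.config import WorkerConfig  # import path assumed

# COST_PER_SKILL replaces WORKER_COST_PER_SECOND and is parsed via _load_json_from_env.
environ["COST_PER_SKILL"] = '{"transcode": 0.02, "ocr": 0.005}'

# Entries missing "priority" default to 10 and missing "weight" default to 1.
environ["ORCHESTRATORS_CONFIG"] = (
    '[{"url": "http://orch-a:8080", "priority": 1, "weight": 3},'
    ' {"url": "http://orch-b:8080"}]'
)

# Payload offloading stays disabled unless S3_ENDPOINT_URL is set.
environ["S3_ENDPOINT_URL"] = "http://minio:9000"
environ["S3_ACCESS_KEY"] = "minio"
environ["S3_SECRET_KEY"] = "minio-secret"

config = WorkerConfig()
print(config.COST_PER_SKILL)   # {'transcode': 0.02, 'ocr': 0.005}
print(config.ORCHESTRATORS)    # orch-a first (priority 1); orch-b normalized to priority=10, weight=1
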
@@ -0,0 +1,141 @@
+from asyncio import gather, to_thread
+from os import walk
+from os.path import basename, dirname, join, relpath
+from shutil import rmtree
+from typing import Any
+from urllib.parse import urlparse
+
+from aioboto3 import Session
+from aiofiles.os import makedirs
+from aiofiles.ospath import exists, isdir
+from botocore.client import Config
+
+from .config import WorkerConfig
+
+
+class S3Manager:
+    """Handles S3 payload offloading."""
+
+    def __init__(self, config: WorkerConfig):
+        self._config = config
+        self._session = Session()
+
+    def _get_client_args(self) -> dict[str, Any]:
+        """Returns standard arguments for S3 client creation."""
+        return {
+            "service_name": "s3",
+            "endpoint_url": self._config.S3_ENDPOINT_URL,
+            "aws_access_key_id": self._config.S3_ACCESS_KEY,
+            "aws_secret_access_key": self._config.S3_SECRET_KEY,
+            "config": Config(signature_version="s3v4"),
+        }
+
+    async def cleanup(self, task_id: str):
+        """Removes the task-specific payload directory."""
+        task_dir = join(self._config.TASK_FILES_DIR, task_id)
+        if await exists(task_dir):
+            await to_thread(lambda: rmtree(task_dir, ignore_errors=True))
+
+    async def _process_s3_uri(self, uri: str, task_id: str) -> str:
+        """Downloads a file or a folder (if uri ends with /) from S3 and returns the local path."""
+        parsed_url = urlparse(uri)
+        bucket_name = parsed_url.netloc
+        object_key = parsed_url.path.lstrip("/")
+
+        # Use task-specific directory for isolation
+        local_dir_root = join(self._config.TASK_FILES_DIR, task_id)
+        await makedirs(local_dir_root, exist_ok=True)
+
+        async with self._session.client(**self._get_client_args()) as s3:
+            # Handle folder download (prefix)
+            if uri.endswith("/"):
+                folder_name = object_key.rstrip("/").split("/")[-1]
+                local_folder_path = join(local_dir_root, folder_name)
+
+                paginator = s3.get_paginator("list_objects_v2")
+                tasks = []
+                async for page in paginator.paginate(Bucket=bucket_name, Prefix=object_key):
+                    for obj in page.get("Contents", []):
+                        key = obj["Key"]
+                        if key.endswith("/"):
+                            continue
+
+                        # Calculate relative path inside the folder
+                        rel_path = key[len(object_key) :]
+                        local_file_path = join(local_folder_path, rel_path)
+
+                        await makedirs(dirname(local_file_path), exist_ok=True)
+                        tasks.append(s3.download_file(bucket_name, key, local_file_path))
+
+                if tasks:
+                    await gather(*tasks)
+                return local_folder_path
+
+            # Handle single file download
+            local_path = join(local_dir_root, basename(object_key))
+            await s3.download_file(bucket_name, object_key, local_path)
+            return local_path
+
+    async def _upload_to_s3(self, local_path: str) -> str:
+        """Uploads a file or a folder to S3 and returns the S3 URI."""
+        bucket_name = self._config.S3_DEFAULT_BUCKET
+
+        async with self._session.client(**self._get_client_args()) as s3:
+            # Handle folder upload
+            if await isdir(local_path):
+                folder_name = basename(local_path.rstrip("/"))
+                s3_prefix = f"{folder_name}/"
+                tasks = []
+
+                # Use to_thread to avoid blocking event loop during file walk
+                def _get_files_to_upload():
+                    files_to_upload = []
+                    for root, _, files in walk(local_path):
+                        for file in files:
+                            f_path = join(root, file)
+                            rel = relpath(f_path, local_path)
+                            files_to_upload.append((f_path, f"{s3_prefix}{rel}"))
+                    return files_to_upload
+
+                files_list = await to_thread(_get_files_to_upload)
+
+                for full_path, key in files_list:
+                    tasks.append(s3.upload_file(full_path, bucket_name, key))
+
+                if tasks:
+                    await gather(*tasks)
+
+                return f"s3://{bucket_name}/{s3_prefix}"
+
+            # Handle single file upload
+            object_key = basename(local_path)
+            await s3.upload_file(local_path, bucket_name, object_key)
+            return f"s3://{bucket_name}/{object_key}"
+
+    async def process_params(self, params: dict[str, Any], task_id: str) -> dict[str, Any]:
+        """Recursively searches for S3 URIs in params and downloads the files."""
+        if not self._config.S3_ENDPOINT_URL:
+            return params
+
+        async def _process(item: Any) -> Any:
+            if isinstance(item, str) and item.startswith("s3://"):
+                return await self._process_s3_uri(item, task_id)
+            if isinstance(item, dict):
+                return {k: await _process(v) for k, v in item.items()}
+            return [await _process(i) for i in item] if isinstance(item, list) else item
+
+        return await _process(params)
+
+    async def process_result(self, result: dict[str, Any]) -> dict[str, Any]:
+        """Recursively searches for local file paths in the result and uploads them to S3."""
+        if not self._config.S3_ENDPOINT_URL:
+            return result
+
+        async def _process(item: Any) -> Any:
+            if isinstance(item, str) and item.startswith(self._config.TASK_FILES_DIR):
+                return await self._upload_to_s3(item) if await exists(item) else item
+            if isinstance(item, dict):
+                return {k: await _process(v) for k, v in item.items()}
+            return [await _process(i) for i in item] if isinstance(item, list) else item
+
+        return await _process(result)
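
S3Manager is not re-exported in __all__, and the diff does not show how Worker wires it into task handling, so the following round-trip sketch is illustrative only; the module path, bucket names, and keys are assumptions:

from asyncio import run

from avtomatika_worker.config import WorkerConfig   # import path assumed
from avtomatika_worker.s3_manager import S3Manager  # module name assumed


async def main() -> None:
    s3 = S3Manager(WorkerConfig())  # requires S3_ENDPOINT_URL etc. in the environment

    # "s3://bucket/key" strings become local paths under TASK_FILES_DIR/<task_id>/;
    # a trailing "/" downloads every object under that prefix.
    params = {"video": "s3://incoming/clip.mp4", "frames": "s3://incoming/frames/"}
    local = await s3.process_params(params, task_id="task-123")

    # Result strings that point inside TASK_FILES_DIR are uploaded to
    # S3_DEFAULT_BUCKET and replaced with s3:// URIs.
    uploaded = await s3.process_result({"output": local["video"]})

    await s3.cleanup("task-123")  # removes TASK_FILES_DIR/task-123
    print(uploaded)


run(main())
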
@@ -0,0 +1,97 @@
+from contextlib import asynccontextmanager
+from os.path import dirname, join
+from typing import AsyncGenerator
+
+from aiofiles import open as aiopen
+from aiofiles.os import listdir, makedirs
+from aiofiles.ospath import exists as aio_exists
+
+
+class TaskFiles:
+    """
+    A helper class for managing task-specific files.
+    Provides asynchronous lazy directory creation and high-level file operations
+    within an isolated workspace for each task.
+    """
+
+    def __init__(self, task_dir: str):
+        """
+        Initializes TaskFiles with a specific task directory.
+        The directory is not created until needed.
+        """
+        self._task_dir = task_dir
+
+    async def get_root(self) -> str:
+        """
+        Asynchronously returns the root directory for the task.
+        Creates the directory on disk if it doesn't exist.
+        """
+        await makedirs(self._task_dir, exist_ok=True)
+        return self._task_dir
+
+    async def path_to(self, filename: str) -> str:
+        """
+        Asynchronously returns an absolute path for a file within the task directory.
+        Guarantees that the task root directory exists.
+        """
+        root = await self.get_root()
+        return join(root, filename)
+
+    @asynccontextmanager
+    async def open(self, filename: str, mode: str = "r") -> AsyncGenerator:
+        """
+        An asynchronous context manager to open a file within the task directory.
+        Automatically creates the task root and any necessary subdirectories.
+
+        Args:
+            filename: Name or relative path of the file.
+            mode: File opening mode (e.g., 'r', 'w', 'a', 'rb', 'wb').
+        """
+        path = await self.path_to(filename)
+        # Ensure directory for the file itself exists if filename contains subdirectories
+        file_dir = dirname(path)
+        if file_dir != self._task_dir:
+            await makedirs(file_dir, exist_ok=True)
+
+        async with aiopen(path, mode) as f:
+            yield f
+
+    async def read(self, filename: str, mode: str = "r") -> str | bytes:
+        """
+        Asynchronously reads the entire content of a file.
+
+        Args:
+            filename: Name of the file to read.
+            mode: Mode to open the file in (defaults to 'r').
+        """
+        async with self.open(filename, mode) as f:
+            return await f.read()
+
+    async def write(self, filename: str, data: str | bytes, mode: str = "w") -> None:
+        """
+        Asynchronously writes data to a file. Creates or overwrites the file by default.
+
+        Args:
+            filename: Name of the file to write.
+            data: Content to write (string or bytes).
+            mode: Mode to open the file in (defaults to 'w').
+        """
+        async with self.open(filename, mode) as f:
+            await f.write(data)
+
+    async def list(self) -> list[str]:
+        """
+        Asynchronously lists all file and directory names within the task root.
+        """
+        root = await self.get_root()
+        return await listdir(root)
+
+    async def exists(self, filename: str) -> bool:
+        """
+        Asynchronously checks if a specific file or directory exists in the task root.
+        """
+        path = join(self._task_dir, filename)
+        return await aio_exists(path)
+
+    def __repr__(self):
+        return f"<TaskFiles root='{self._task_dir}'>"
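
A short usage sketch for TaskFiles; the diff does not show who constructs the instance inside the worker, so the task directory here is chosen by hand and the import package name is assumed:

from asyncio import run

from avtomatika_worker import TaskFiles  # package name assumed


async def main() -> None:
    files = TaskFiles("/tmp/payloads/task-123")  # nothing is created on disk yet

    # write() lazily creates the root plus any subdirectories in the name
    await files.write("report/summary.txt", "42 items processed")

    if await files.exists("report/summary.txt"):
        print(await files.read("report/summary.txt"))  # "42 items processed"
        print(await files.list())                      # ["report"]


run(main())
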
@@ -2,3 +2,7 @@
 TRANSIENT_ERROR = "TRANSIENT_ERROR"
 PERMANENT_ERROR = "PERMANENT_ERROR"
 INVALID_INPUT_ERROR = "INVALID_INPUT_ERROR"
+
+
+class ParamValidationError(Exception):
+    """Custom exception for parameter validation errors."""