avtomatika 1.0b7__py3-none-any.whl → 1.0b8__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
avtomatika/s3.py ADDED
@@ -0,0 +1,323 @@
+from asyncio import Semaphore, gather, to_thread
+from logging import getLogger
+from os import sep, walk
+from pathlib import Path
+from shutil import rmtree
+from typing import Any, Tuple
+
+from orjson import dumps, loads
+
+from .config import Config
+from .history.base import HistoryStorageBase
+
+logger = getLogger(__name__)
+
+try:
+    # Optional S3 dependencies, provided by the 'avtomatika[s3]' extra.
+    # They are imported inside the try block so a missing extra is caught
+    # as ImportError here instead of failing at module import time.
+    from aiofiles import open as aiopen
+    from obstore import delete_async, get_async, put_async
+    from obstore import list as obstore_list
+    from obstore.store import S3Store
+
+    HAS_S3_LIBS = True
+except ImportError:
+    HAS_S3_LIBS = False
+    S3Store = Any
+
+
+class TaskFiles:
+    """
+    Manages files for a specific job, ensuring full compatibility with avtomatika-worker.
+    Supports recursive directory download/upload and non-blocking I/O.
+    """
+
+    def __init__(
+        self,
+        store: "S3Store",
+        bucket: str,
+        job_id: str,
+        base_local_dir: str | Path,
+        semaphore: Semaphore,
+        history: HistoryStorageBase | None = None,
+    ):
+        self._store = store
+        self._bucket = bucket
+        self._job_id = job_id
+        self._history = history
+        self._s3_prefix = f"jobs/{job_id}/"
+        self.local_dir = Path(base_local_dir) / job_id
+        self._semaphore = semaphore
+
+    def _ensure_local_dir(self) -> None:
+        if not self.local_dir.exists():
+            self.local_dir.mkdir(parents=True, exist_ok=True)
+
+    def path(self, filename: str) -> Path:
+        """Returns local path for a filename, ensuring the directory exists."""
+        self._ensure_local_dir()
+        clean_name = filename.split("/")[-1] if "://" in filename else filename.lstrip("/")
+        return self.local_dir / clean_name
+
+    def _parse_s3_uri(self, uri: str) -> Tuple[str, str, bool]:
+        """
+        Parses s3://bucket/key into (bucket, key, is_directory).
+        is_directory is True if uri ends with '/'.
+        """
+        is_dir = uri.endswith("/")
+
+        if not uri.startswith("s3://"):
+            key = f"{self._s3_prefix}{uri.lstrip('/')}"
+            return self._bucket, key, is_dir
+
+        parts = uri[5:].split("/", 1)
+        bucket = parts[0]
+        key = parts[1] if len(parts) > 1 else ""
+        return bucket, key, is_dir
+
+    async def _download_single_file(self, key: str, local_path: Path) -> None:
+        """Downloads a single file safely using semaphore and streaming to avoid OOM."""
+        if not local_path.parent.exists():
+            await to_thread(local_path.parent.mkdir, parents=True, exist_ok=True)
+
+        async with self._semaphore:
+            response = await get_async(self._store, key)
+            stream = response.stream()
+            async with aiopen(local_path, "wb") as f:
+                async for chunk in stream:
+                    await f.write(chunk)
+
+    async def download(self, name_or_uri: str, local_name: str | None = None) -> Path:
+        """
+        Downloads a file or directory (recursively).
+        If URI ends with '/', it treats it as a directory.
+        """
+        bucket, key, is_dir = self._parse_s3_uri(name_or_uri)
+
+        if local_name:
+            target_path = self.path(local_name)
+        else:
+            suffix = key.replace(self._s3_prefix, "", 1) if key.startswith(self._s3_prefix) else key.split("/")[-1]
+            target_path = self.local_dir / suffix
+
+        if is_dir:
+            logger.info(f"Recursive download: s3://{bucket}/{key} -> {target_path}")
+            entries = await to_thread(lambda: list(obstore_list(self._store, prefix=key)))
+
+            tasks = []
+            for entry in entries:
+                s3_key = entry["path"]
+                rel_path = s3_key[len(key) :]
+                if not rel_path:
+                    continue
+
+                local_file_path = target_path / rel_path
+                tasks.append(self._download_single_file(s3_key, local_file_path))
+
+            if tasks:
+                await gather(*tasks)
+
+            await self._log_event("download_dir", f"s3://{bucket}/{key}", str(target_path))
+            return target_path
+        else:
+            logger.debug(f"Downloading s3://{bucket}/{key} -> {target_path}")
+            await self._download_single_file(key, target_path)
+            await self._log_event("download", f"s3://{bucket}/{key}", str(target_path))
+            return target_path
+
+    async def _upload_single_file(self, local_path: Path, s3_key: str) -> None:
+        """Uploads a single file safely using semaphore."""
+        async with self._semaphore:
+            async with aiopen(local_path, "rb") as f:
+                content = await f.read()
+            await put_async(self._store, s3_key, content)
+
+    async def upload(self, local_name: str, remote_name: str | None = None) -> str:
+        """
+        Uploads a file or directory recursively.
+        If local_name points to a directory, it uploads all contents.
+        """
+        local_path = self.path(local_name)
+
+        if local_path.is_dir():
+            base_remote = (remote_name or local_name).lstrip("/")
+            if not base_remote.endswith("/"):
+                base_remote += "/"
+
+            target_prefix = f"{self._s3_prefix}{base_remote}"
+            logger.info(f"Recursive upload: {local_path} -> s3://{self._bucket}/{target_prefix}")
+
+            def collect_files():
+                files_to_upload = []
+                for root, _, files in walk(local_path):
+                    for file in files:
+                        abs_path = Path(root) / file
+                        rel_path = abs_path.relative_to(local_path)
+                        s3_key = f"{target_prefix}{str(rel_path).replace(sep, '/')}"
+                        files_to_upload.append((abs_path, s3_key))
+                return files_to_upload
+
+            files_map = await to_thread(collect_files)
+
+            tasks = [self._upload_single_file(lp, k) for lp, k in files_map]
+            if tasks:
+                await gather(*tasks)
+
+            uri = f"s3://{self._bucket}/{target_prefix}"
+            await self._log_event("upload_dir", uri, str(local_path))
+            return uri
+
+        elif local_path.exists():
+            target_key = f"{self._s3_prefix}{(remote_name or local_name).lstrip('/')}"
+            logger.debug(f"Uploading {local_path} -> s3://{self._bucket}/{target_key}")
+
+            await self._upload_single_file(local_path, target_key)
+
+            uri = f"s3://{self._bucket}/{target_key}"
+            await self._log_event("upload", uri, str(local_path))
+            return uri
+        else:
+            raise FileNotFoundError(f"Local file/dir not found: {local_path}")
+
+    async def read_text(self, name_or_uri: str) -> str:
+        bucket, key, _ = self._parse_s3_uri(name_or_uri)
+        filename = key.split("/")[-1]
+        local_path = self.path(filename)
+
+        if not local_path.exists():
+            await self.download(name_or_uri)
+
+        async with aiopen(local_path, "r", encoding="utf-8") as f:
+            return await f.read()
+
+    async def read_json(self, name_or_uri: str) -> Any:
+        bucket, key, _ = self._parse_s3_uri(name_or_uri)
+        filename = key.split("/")[-1]
+        local_path = self.path(filename)
+
+        if not local_path.exists():
+            await self.download(name_or_uri)
+
+        async with aiopen(local_path, "rb") as f:
+            content = await f.read()
+            return loads(content)
+
+    async def write_json(self, filename: str, data: Any, upload: bool = True) -> str:
+        """Writes JSON locally (binary mode) and optionally uploads to S3."""
+        local_path = self.path(filename)
+        json_bytes = dumps(data)
+
+        async with aiopen(local_path, "wb") as f:
+            await f.write(json_bytes)
+
+        if upload:
+            return await self.upload(filename)
+        return f"file://{local_path}"
+
+    async def write_text(self, filename: str, text: str, upload: bool = True) -> Path:
+        local_path = self.path(filename)
+        async with aiopen(local_path, "w", encoding="utf-8") as f:
+            await f.write(text)
+
+        if upload:
+            await self.upload(filename)
+
+        return local_path
+
+    async def cleanup(self) -> None:
+        """Full cleanup of S3 prefix and local job directory."""
+        logger.info(f"Cleanup for job {self._job_id}...")
+        try:
+            entries = await to_thread(lambda: list(obstore_list(self._store, prefix=self._s3_prefix)))
+            paths_to_delete = [entry["path"] for entry in entries]
+            if paths_to_delete:
+                await delete_async(self._store, paths_to_delete)
+        except Exception as e:
+            logger.error(f"S3 cleanup error: {e}")
+
+        if self.local_dir.exists():
+            await to_thread(rmtree, self.local_dir)
+
+    async def _log_event(self, operation: str, file_uri: str, local_path: str) -> None:
+        if not self._history:
+            return
+
+        try:
+            await self._history.log_job_event(
+                {
+                    "job_id": self._job_id,
+                    "event_type": "s3_operation",
+                    "state": "running",
+                    "context_snapshot": {
+                        "operation": operation,
+                        "s3_uri": file_uri,
+                        "local_path": str(local_path),
+                    },
+                }
+            )
+        except Exception as e:
+            logger.warning(f"Failed to log S3 event: {e}")
+
+
+class S3Service:
+    """
+    Central service for S3 operations.
+    Initializes the Store and provides TaskFiles instances.
+    """
+
+    def __init__(self, config: Config, history: HistoryStorageBase | None = None):
+        self.config = config
+        self._history = history
+        self._store: S3Store | None = None
+        self._semaphore: Semaphore | None = None
+
+        self._config_present = bool(config.S3_ENDPOINT_URL and config.S3_ACCESS_KEY and config.S3_SECRET_KEY)
+
+        if self._config_present:
+            if HAS_S3_LIBS:
+                self._enabled = True
+                self._initialize_store()
+            else:
+                logger.error(
+                    "S3 configuration found, but 'avtomatika[s3]' extra dependencies are not installed. "
+                    "S3 support will be disabled. Install with: pip install 'avtomatika[s3]'"
+                )
+                self._enabled = False
+        else:
+            self._enabled = False
+            if any([config.S3_ENDPOINT_URL, config.S3_ACCESS_KEY, config.S3_SECRET_KEY]):
+                logger.warning("Partial S3 configuration found. S3 support disabled.")
+
+    def _initialize_store(self) -> None:
+        try:
+            self._store = S3Store(
+                bucket=self.config.S3_DEFAULT_BUCKET,
+                access_key_id=self.config.S3_ACCESS_KEY,
+                secret_access_key=self.config.S3_SECRET_KEY,
+                region=self.config.S3_REGION,
+                endpoint=self.config.S3_ENDPOINT_URL,
+                allow_http="http://" in self.config.S3_ENDPOINT_URL,
+                force_path_style=True,
+            )
+            self._semaphore = Semaphore(self.config.S3_MAX_CONCURRENCY)
+            logger.info(
+                f"S3Service initialized (Endpoint: {self.config.S3_ENDPOINT_URL}, "
+                f"Bucket: {self.config.S3_DEFAULT_BUCKET}, "
+                f"Max Concurrency: {self.config.S3_MAX_CONCURRENCY})"
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize S3 Store: {e}")
+            self._enabled = False
+
+    def get_task_files(self, job_id: str) -> TaskFiles | None:
+        if not self._enabled or not self._store or not self._semaphore:
+            return None
+
+        return TaskFiles(
+            self._store,
+            self.config.S3_DEFAULT_BUCKET,
+            job_id,
+            self.config.TASK_FILES_DIR,
+            self._semaphore,
+            self._history,
+        )
+
+    async def close(self) -> None:
+        pass
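
Usage note (not part of the diff): a minimal sketch of how the new S3Service and TaskFiles classes could be wired together. It assumes a Config object carrying the S3_* and TASK_FILES_DIR settings referenced above, an installed 'avtomatika[s3]' extra, and that Config() can be built from environment settings; the job ID and file names are illustrative.

from asyncio import run

from avtomatika.config import Config
from avtomatika.s3 import S3Service


async def main() -> None:
    # Config() is assumed to pick up S3_* settings from the environment.
    service = S3Service(Config())
    files = service.get_task_files("job-123")
    if files is None:
        # get_task_files() returns None when S3 is disabled or misconfigured.
        return

    # Write a result locally and upload it under s3://<bucket>/jobs/job-123/.
    uri = await files.write_json("result.json", {"status": "ok"})
    print(uri)

    # A trailing '/' marks the remote name as a directory (recursive download).
    await files.download("inputs/")

    # Remove both the S3 prefix and the local job directory.
    await files.cleanup()


run(main())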
@@ -142,6 +142,37 @@ class StorageBackend(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    async def get_active_worker_ids(self) -> list[str]:
+        """Returns a list of IDs for all currently active workers.
+
+        :return: A list of worker ID strings.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def cleanup_expired_workers(self) -> None:
+        """Maintenance task to clean up internal indexes from expired worker entries."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
+        """Bulk retrieves worker info for a list of IDs.
+
+        :param worker_ids: List of worker identifiers.
+        :return: List of worker info dictionaries.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def find_workers_for_task(self, task_type: str) -> list[str]:
+        """Finds idle workers that support the given task.
+
+        :param task_type: The type of task to find workers for.
+        :return: A list of worker IDs that are idle and support the task.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
         """Add a job to the list for timeout tracking.
@@ -152,9 +183,10 @@ class StorageBackend(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    async def get_timed_out_jobs(self) -> list[str]:
+    async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
         """Get a list of job IDs that are overdue and remove them from the tracking list.
 
+        :param limit: Maximum number of jobs to retrieve.
         :return: A list of overdue job IDs.
         """
         raise NotImplementedError
@@ -165,9 +197,10 @@ class StorageBackend(ABC):
         raise NotImplementedError
 
    @abstractmethod
-    async def dequeue_job(self) -> tuple[str, str] | None:
+    async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
         """Retrieve a job ID and its message ID from the execution queue.
 
+        :param block: Milliseconds to block if no message is available. None for non-blocking.
         :return: A tuple of (job_id, message_id) or None if the timeout has expired.
         """
         raise NotImplementedError
@@ -250,7 +283,7 @@ class StorageBackend(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    async def set_worker_token(self, worker_id: str, token: str):
+    async def set_worker_token(self, worker_id: str, token: str) -> None:
         """Saves an individual token for a specific worker."""
         raise NotImplementedError
 
@@ -265,7 +298,7 @@ class StorageBackend(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    async def flush_all(self):
+    async def flush_all(self) -> None:
         """Completely clears the storage. Used mainly for tests."""
         raise NotImplementedError
 
@@ -312,3 +345,11 @@ class StorageBackend(ABC):
         :return: True if the lock was successfully released, False otherwise.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    async def ping(self) -> bool:
+        """Checks connection to the storage backend.
+
+        :return: True if storage is accessible, False otherwise.
+        """
+        raise NotImplementedError
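
For context (not part of the diff), here is an illustrative consumer of the new abstract methods: a dispatcher-style helper that checks connectivity with ping(), looks up idle workers for a task type, and bulk-fetches their info. The helper name and task type are hypothetical, and `storage` is any StorageBackend implementation (the module path is not shown in this diff, so the parameter is left untyped).

from typing import Any


async def pick_worker(storage, task_type: str) -> dict[str, Any] | None:
    # Bail out early if the backend is unreachable.
    if not await storage.ping():
        return None

    # Idle workers that advertise support for this task type.
    worker_ids = await storage.find_workers_for_task(task_type)
    if not worker_ids:
        return None

    # Bulk fetch of worker info dicts; pick the first candidate.
    workers = await storage.get_workers(worker_ids)
    return workers[0] if workers else None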
@@ -154,6 +154,33 @@ class MemoryStorage(StorageBackend):
                 )
             return active_workers
 
+    async def get_active_worker_ids(self) -> list[str]:
+        async with self._lock:
+            now = monotonic()
+            return [worker_id for worker_id, ttl in self._worker_ttls.items() if ttl > now]
+
+    async def cleanup_expired_workers(self) -> None:
+        async with self._lock:
+            await self._clean_expired()
+
+    async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
+        async with self._lock:
+            return [self._workers[wid] for wid in worker_ids if wid in self._workers]
+
+    async def find_workers_for_task(self, task_type: str) -> list[str]:
+        """Finds idle workers supporting the task (O(N) for memory storage)."""
+        async with self._lock:
+            now = monotonic()
+            candidates = []
+            for worker_id, info in self._workers.items():
+                if self._worker_ttls.get(worker_id, 0) <= now:
+                    continue
+                if info.get("status", "idle") != "idle":
+                    continue
+                if task_type in info.get("supported_tasks", []):
+                    candidates.append(worker_id)
+            return candidates
+
     async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
         async with self._lock:
             self._watched_jobs[job_id] = timeout_at
@@ -173,13 +200,21 @@ class MemoryStorage(StorageBackend):
     async def enqueue_job(self, job_id: str) -> None:
         await self._job_queue.put(job_id)
 
-    async def dequeue_job(self) -> tuple[str, str] | None:
-        """Waits indefinitely for a job ID from the queue and returns it.
-        Returns a tuple of (job_id, message_id). In MemoryStorage, message_id is dummy.
+    async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
+        """Waits for a job ID from the queue.
+        If block is None, waits indefinitely.
+        If block is int, waits for that many milliseconds.
         """
-        job_id = await self._job_queue.get()
-        self._job_queue.task_done()
-        return job_id, "memory-msg-id"
+        try:
+            if block is None:
+                job_id = await self._job_queue.get()
+            else:
+                job_id = await wait_for(self._job_queue.get(), timeout=block / 1000.0)
+
+            self._job_queue.task_done()
+            return job_id, "memory-msg-id"
+        except AsyncTimeoutError:
+            return None
 
     async def ack_job(self, message_id: str) -> None:
         """No-op for MemoryStorage as it doesn't support persistent streams."""
@@ -334,3 +369,6 @@ class MemoryStorage(StorageBackend):
                 del self._locks[key]
                 return True
             return False
+
+    async def ping(self) -> bool:
+        return True
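
A small check of the new MemoryStorage behaviour (illustrative, not from the package): with an item queued, dequeue_job returns immediately; on an empty queue with block=100 it returns None after roughly 100 ms. Pass in a MemoryStorage instance; its import path is not shown in this diff, so it is left as a parameter.

async def demo(storage) -> None:
    # Health check added in this release.
    assert await storage.ping() is True

    # Non-empty queue: returned immediately with the dummy message ID.
    await storage.enqueue_job("job-1")
    print(await storage.dequeue_job())          # ('job-1', 'memory-msg-id')

    # Empty queue, block for 100 ms: times out and returns None.
    print(await storage.dequeue_job(block=100))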