avtomatika-1.0b8-py3-none-any.whl → avtomatika-1.0b10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
avtomatika/engine.py CHANGED
@@ -1,7 +1,7 @@
  from asyncio import TimeoutError as AsyncTimeoutError
  from asyncio import create_task, gather, get_running_loop, wait_for
  from logging import getLogger
- from typing import Any
+ from typing import Any, Optional
  from uuid import uuid4

  from aiohttp import ClientSession, web
@@ -24,6 +24,7 @@ from .app_keys import (
  SCHEDULER_TASK_KEY,
  WATCHER_KEY,
  WATCHER_TASK_KEY,
+ WORKER_SERVICE_KEY,
  WS_MANAGER_KEY,
  )
  from .blueprint import StateMachineBlueprint
@@ -40,6 +41,7 @@ from .logging_config import setup_logging
  from .reputation import ReputationCalculator
  from .s3 import S3Service
  from .scheduler import Scheduler
+ from .services.worker_service import WorkerService
  from .storage.base import StorageBackend
  from .telemetry import setup_telemetry
  from .utils.webhook_sender import WebhookPayload, WebhookSender
@@ -68,10 +70,19 @@ class OrchestratorEngine:
  self.config = config
  self.blueprints: dict[str, StateMachineBlueprint] = {}
  self.history_storage: HistoryStorageBase = NoOpHistoryStorage()
- self.ws_manager = WebSocketManager()
+ self.ws_manager = WebSocketManager(self.storage)
  self.app = web.Application(middlewares=[compression_middleware])
  self.app[ENGINE_KEY] = self
+ self.worker_service: Optional[WorkerService] = None
  self._setup_done = False
+ self.webhook_sender: WebhookSender
+ self.dispatcher: Dispatcher
+ self.runner: web.AppRunner
+ self.site: web.TCPSite
+
+ from rxon import HttpListener
+
+ self.rxon_listener = HttpListener(self.app)

  def register_blueprint(self, blueprint: StateMachineBlueprint) -> None:
  if self._setup_done:
@@ -142,8 +153,74 @@ class OrchestratorEngine:
  )
  self.history_storage = NoOpHistoryStorage()

+ async def handle_rxon_message(self, message_type: str, payload: Any, context: dict) -> Any:
+ """Core handler for RXON protocol messages via any listener."""
+ from rxon.security import extract_cert_identity
+
+ from .security import verify_worker_auth
+
+ request = context.get("raw_request")
+ token = context.get("token")
+ cert_identity = extract_cert_identity(request) if request else None
+
+ worker_id_hint = context.get("worker_id_hint")
+
+ if not worker_id_hint:
+ if message_type == "poll" and isinstance(payload, str):
+ worker_id_hint = payload
+ elif isinstance(payload, dict) and "worker_id" in payload:
+ worker_id_hint = payload["worker_id"]
+ elif hasattr(payload, "worker_id"):
+ worker_id_hint = payload.worker_id
+
+ try:
+ auth_worker_id = await verify_worker_auth(self.storage, self.config, token, cert_identity, worker_id_hint)
+ except PermissionError as e:
+ raise web.HTTPUnauthorized(text=str(e)) from e
+ except ValueError as e:
+ raise web.HTTPBadRequest(text=str(e)) from e
+
+ if self.worker_service is None:
+ raise web.HTTPInternalServerError(text="WorkerService is not initialized.")
+
+ if message_type == "register":
+ return await self.worker_service.register_worker(payload)
+
+ elif message_type == "poll":
+ return await self.worker_service.get_next_task(auth_worker_id)
+
+ elif message_type == "result":
+ return await self.worker_service.process_task_result(payload, auth_worker_id)
+
+ elif message_type == "heartbeat":
+ return await self.worker_service.update_worker_heartbeat(auth_worker_id, payload)
+
+ elif message_type == "sts_token":
+ if cert_identity is None:
+ raise web.HTTPForbidden(text="Unauthorized: mTLS certificate required to issue access token.")
+ return await self.worker_service.issue_access_token(auth_worker_id)
+
+ elif message_type == "websocket":
+ ws = payload
+ await self.ws_manager.register(auth_worker_id, ws)
+ try:
+ from aiohttp import WSMsgType
+
+ async for msg in ws:
+ if msg.type == WSMsgType.TEXT:
+ try:
+ data = msg.json()
+ await self.ws_manager.handle_message(auth_worker_id, data)
+ except Exception as e:
+ logger.error(f"Error processing WebSocket message from {auth_worker_id}: {e}")
+ elif msg.type == WSMsgType.ERROR:
+ break
+ finally:
+ await self.ws_manager.unregister(auth_worker_id)
+ return None
+
  async def on_startup(self, app: web.Application) -> None:
- # 1. Fail Fast: Check Storage Connection
+ # Fail Fast: Check Storage Connection
  if not await self.storage.ping():
  logger.critical("Failed to connect to Storage Backend (Redis). Exiting.")
  raise RuntimeError("Storage Backend is unavailable.")
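The new `handle_rxon_message` authenticates the worker once via `verify_worker_auth` and then routes on `message_type` (`register`, `poll`, `result`, `heartbeat`, `sts_token`, `websocket`) to the matching `WorkerService` call. A minimal sketch of driving the handler directly, e.g. from a test; the worker id, token, and storage setup are assumptions, only the message and context shapes come from the hunk above:

```python
# Hypothetical sketch: calling the RXON entrypoint without an HTTP listener.
# "worker-1" and "secret-token" are made-up values; the context keys mirror what
# handle_rxon_message reads ("raw_request", "token", "worker_id_hint").
import asyncio

async def poll_once(engine) -> None:
    # For "poll" messages the payload is the worker id as a plain string.
    task = await engine.handle_rxon_message(
        "poll",
        "worker-1",
        {"raw_request": None, "token": "secret-token", "worker_id_hint": "worker-1"},
    )
    print("next task:", task)

# asyncio.run(poll_once(engine))  # engine: an OrchestratorEngine after setup()
```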
@@ -208,14 +285,21 @@ class OrchestratorEngine:
  app[WS_MANAGER_KEY] = self.ws_manager
  app[S3_SERVICE_KEY] = S3Service(self.config, self.history_storage)

+ self.worker_service = WorkerService(self.storage, self.history_storage, self.config, self)
+ app[WORKER_SERVICE_KEY] = self.worker_service
+
  app[EXECUTOR_TASK_KEY] = create_task(app[EXECUTOR_KEY].run())
  app[WATCHER_TASK_KEY] = create_task(app[WATCHER_KEY].run())
  app[REPUTATION_CALCULATOR_TASK_KEY] = create_task(app[REPUTATION_CALCULATOR_KEY].run())
  app[HEALTH_CHECKER_TASK_KEY] = create_task(app[HEALTH_CHECKER_KEY].run())
  app[SCHEDULER_TASK_KEY] = create_task(app[SCHEDULER_KEY].run())

+ await self.rxon_listener.start(self.handle_rxon_message)
+
  async def on_shutdown(self, app: web.Application) -> None:
  logger.info("Shutdown sequence started.")
+ await self.rxon_listener.stop()
+
  app[EXECUTOR_KEY].stop()
  app[WATCHER_KEY].stop()
  app[REPUTATION_CALCULATOR_KEY].stop()
@@ -274,6 +358,8 @@ class OrchestratorEngine:
  blueprint_name: str,
  initial_data: dict[str, Any],
  source: str = "internal",
+ tracing_context: dict[str, str] | None = None,
+ data_metadata: dict[str, Any] | None = None,
  ) -> str:
  """Creates a job directly, bypassing the HTTP API layer.
  Useful for internal schedulers and triggers.
@@ -297,8 +383,9 @@ class OrchestratorEngine:
  "initial_data": initial_data,
  "state_history": {},
  "status": JOB_STATUS_PENDING,
- "tracing_context": {},
+ "tracing_context": tracing_context or {},
  "client_config": client_config,
+ "data_metadata": data_metadata or {},
  }
  await self.storage.save_job_state(job_id, job_state)
  await self.storage.enqueue_job(job_id)
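`create_job` now threads optional `tracing_context` and `data_metadata` into the stored job state instead of always writing empty dicts. A hedged usage sketch; the blueprint name, payload, and metadata shape are placeholders, only the parameter names come from the diff:

```python
# Hypothetical internal caller of the extended create_job signature.
async def schedule_example(engine) -> str:
    job_id = await engine.create_job(
        blueprint_name="video_pipeline",                      # placeholder blueprint
        initial_data={"source_uri": "s3://bucket/in/clip.mp4"},
        source="internal",
        # W3C trace context propagated into job_state["tracing_context"]
        tracing_context={"traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"},
        # per-file metadata stored under job_state["data_metadata"]; shape is assumed
        data_metadata={"clip.mp4": {"size": 1024, "hash": "etag-value"}},
    )
    return job_id
```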
@@ -374,19 +461,44 @@ class OrchestratorEngine:

  def run(self) -> None:
  self.setup()
+ ssl_context = None
+ if self.config.TLS_ENABLED:
+ from rxon.security import create_server_ssl_context
+
+ ssl_context = create_server_ssl_context(
+ cert_path=self.config.TLS_CERT_PATH,
+ key_path=self.config.TLS_KEY_PATH,
+ ca_path=self.config.TLS_CA_PATH,
+ require_client_cert=self.config.TLS_REQUIRE_CLIENT_CERT,
+ )
+ print(f"TLS enabled. mTLS required: {self.config.TLS_REQUIRE_CLIENT_CERT}")
+
  print(
  f"Starting OrchestratorEngine API server on {self.config.API_HOST}:{self.config.API_PORT} in blocking mode."
  )
- web.run_app(self.app, host=self.config.API_HOST, port=self.config.API_PORT)
+ web.run_app(self.app, host=self.config.API_HOST, port=self.config.API_PORT, ssl_context=ssl_context)

  async def start(self):
  """Starts the orchestrator engine non-blockingly."""
  self.setup()
  self.runner = web.AppRunner(self.app)
  await self.runner.setup()
- self.site = web.TCPSite(self.runner, self.config.API_HOST, self.config.API_PORT)
+
+ ssl_context = None
+ if self.config.TLS_ENABLED:
+ from rxon.security import create_server_ssl_context
+
+ ssl_context = create_server_ssl_context(
+ cert_path=self.config.TLS_CERT_PATH,
+ key_path=self.config.TLS_KEY_PATH,
+ ca_path=self.config.TLS_CA_PATH,
+ require_client_cert=self.config.TLS_REQUIRE_CLIENT_CERT,
+ )
+
+ self.site = web.TCPSite(self.runner, self.config.API_HOST, self.config.API_PORT, ssl_context=ssl_context)
  await self.site.start()
- print(f"OrchestratorEngine API server running on http://{self.config.API_HOST}:{self.config.API_PORT}")
+ protocol = "https" if self.config.TLS_ENABLED else "http"
+ print(f"OrchestratorEngine API server running on {protocol}://{self.config.API_HOST}:{self.config.API_PORT}")

  async def stop(self):
  """Stops the orchestrator engine."""
avtomatika/executor.py CHANGED
@@ -48,6 +48,16 @@ except ImportError:
  TraceContextTextMapPropagator = NoOpTraceContextTextMapPropagator() # Instantiate the class

  from .app_keys import S3_SERVICE_KEY
+ from .constants import (
+ JOB_STATUS_ERROR,
+ JOB_STATUS_FAILED,
+ JOB_STATUS_FINISHED,
+ JOB_STATUS_PENDING,
+ JOB_STATUS_QUARANTINED,
+ JOB_STATUS_RUNNING,
+ JOB_STATUS_WAITING_FOR_PARALLEL,
+ JOB_STATUS_WAITING_FOR_WORKER,
+ )
  from .context import ActionFactory
  from .data_types import ClientConfig, JobContext
  from .history.base import HistoryStorageBase
@@ -59,7 +69,7 @@ logger = getLogger(
  __name__
  ) # Re-declare logger after potential redefinition in except block if opentelemetry was missing

- TERMINAL_STATES = {"finished", "failed", "error", "quarantined"}
+ TERMINAL_STATES = {JOB_STATUS_FINISHED, JOB_STATUS_FAILED, JOB_STATUS_ERROR, JOB_STATUS_QUARANTINED}


  class JobExecutor:
@@ -263,7 +273,7 @@ class JobExecutor:
  # When transitioning to a new state, reset the retry counter.
  job_state["retry_count"] = 0
  job_state["current_state"] = next_state
- job_state["status"] = "running"
+ job_state["status"] = JOB_STATUS_RUNNING
  await self.storage.save_job_state(job_id, job_state)

  if next_state not in TERMINAL_STATES:
@@ -280,11 +290,7 @@ class JobExecutor:
  create_task(task_files.cleanup())

  await self._check_and_resume_parent(job_state)
- # Send webhook for finished/failed jobs
- event_type = "job_finished" if next_state == "finished" else "job_failed"
- # Since _check_and_resume_parent is for sub-jobs, we only send webhook if it's a top-level job
- # or if the user explicitly requested it for sub-jobs (by providing webhook_url).
- # The current logic stores webhook_url in job_state, so we just check it.
+ event_type = "job_finished" if next_state == JOB_STATUS_FINISHED else "job_failed"
  await self.engine.send_job_webhook(job_state, event_type)

  async def _handle_dispatch(
@@ -313,21 +319,15 @@ class JobExecutor:
  logger.info(f"Job {job_id} is now paused, awaiting human approval.")
  else:
  logger.info(f"Job {job_id} dispatching task: {task_info}")
-
  now = monotonic()
- # Safely get timeout, falling back to the global config if not provided in the task.
- # This prevents TypeErrors if 'timeout_seconds' is missing.
  timeout_seconds = task_info.get("timeout_seconds") or self.engine.config.WORKER_TIMEOUT_SECONDS
  timeout_at = now + timeout_seconds
-
- # Set status to waiting and add to watch list *before* dispatching
- job_state["status"] = "waiting_for_worker"
+ job_state["status"] = JOB_STATUS_WAITING_FOR_WORKER
  job_state["task_dispatched_at"] = now
  job_state["current_task_info"] = task_info # Save for retries
  job_state["current_task_transitions"] = task_info.get("transitions", {})
  await self.storage.save_job_state(job_id, job_state)
  await self.storage.add_job_to_watch(job_id, timeout_at)
-
  await self.dispatcher.dispatch(job_state, task_info)

  async def _handle_run_blueprint(
@@ -355,7 +355,7 @@ class JobExecutor:
  "blueprint_name": sub_blueprint_info["blueprint_name"],
  "current_state": "start",
  "initial_data": sub_blueprint_info["initial_data"],
- "status": "pending",
+ "status": JOB_STATUS_PENDING,
  "parent_job_id": parent_job_id,
  }
  await self.storage.save_job_state(child_job_id, child_job_state)
@@ -388,7 +388,7 @@ class JobExecutor:
  branch_task_ids = [str(uuid4()) for _ in tasks_to_dispatch]

  # Update job state for parallel execution
- job_state["status"] = "waiting_for_parallel_tasks"
+ job_state["status"] = JOB_STATUS_WAITING_FOR_PARALLEL
  job_state["aggregation_target"] = aggregate_into
  job_state["active_branches"] = branch_task_ids
  job_state["aggregation_results"] = {}
@@ -466,7 +466,7 @@ class JobExecutor:
  logger.critical(
  f"Job {job_id} has failed handler execution {max_retries + 1} times. Moving to quarantine.",
  )
- job_state["status"] = "quarantined"
+ job_state["status"] = JOB_STATUS_QUARANTINED
  job_state["error_message"] = str(error)
  await self.storage.save_job_state(job_id, job_state)
  await self.storage.quarantine_job(job_id)
@@ -499,7 +499,7 @@ class JobExecutor:
  return

  # Determine the outcome of the child job to select the correct transition.
- child_outcome = "success" if child_job_state["current_state"] == "finished" else "failure"
+ child_outcome = "success" if child_job_state["current_state"] == JOB_STATUS_FINISHED else "failure"
  transitions = parent_job_state.get("current_task_transitions", {})
  next_state = transitions.get(child_outcome, "failed")

@@ -514,7 +514,7 @@ class JobExecutor:

  # Update the parent job to its new state and re-enqueue it.
  parent_job_state["current_state"] = next_state
- parent_job_state["status"] = "running"
+ parent_job_state["status"] = JOB_STATUS_RUNNING
  await self.storage.save_job_state(parent_job_id, parent_job_state)
  await self.storage.enqueue_job(parent_job_id)

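Throughout the executor, bare status strings are replaced with the `JOB_STATUS_*` constants from `avtomatika.constants`. A small sketch of the terminal-state check after the refactor; the constant values are assumed to equal the literals they replaced:

```python
# Sketch: the executor's terminal-state check using the new constants. Only the
# constant names come from the diff; their values are assumed to match the old
# literals ("finished", "failed", "error", "quarantined").
from avtomatika.constants import (
    JOB_STATUS_ERROR,
    JOB_STATUS_FAILED,
    JOB_STATUS_FINISHED,
    JOB_STATUS_QUARANTINED,
)

TERMINAL_STATES = {JOB_STATUS_FINISHED, JOB_STATUS_FAILED, JOB_STATUS_ERROR, JOB_STATUS_QUARANTINED}

def is_terminal(state: str) -> bool:
    # Jobs that reach a terminal state are not re-enqueued by the executor.
    return state in TERMINAL_STATES
```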
avtomatika/logging_config.py CHANGED
@@ -1,6 +1,7 @@
  from datetime import datetime
  from logging import DEBUG, Formatter, StreamHandler, getLogger
  from sys import stdout
+ from typing import Any, Literal, Optional
  from zoneinfo import ZoneInfo

  from pythonjsonlogger import json
@@ -9,14 +10,22 @@ from pythonjsonlogger import json
  class TimezoneFormatter(Formatter):
  """Formatter that respects a custom timezone."""

- def __init__(self, fmt=None, datefmt=None, style="%", validate=True, *, tz_name="UTC"):
+ def __init__(
+ self,
+ fmt: Optional[str] = None,
+ datefmt: Optional[str] = None,
+ style: Literal["%", "{", "$"] = "%",
+ validate: bool = True,
+ *,
+ tz_name: str = "UTC",
+ ) -> None:
  super().__init__(fmt, datefmt, style, validate)
  self.tz = ZoneInfo(tz_name)

- def converter(self, timestamp):
+ def converter(self, timestamp: float) -> datetime: # type: ignore[override]
  return datetime.fromtimestamp(timestamp, self.tz)

- def formatTime(self, record, datefmt=None):
+ def formatTime(self, record: Any, datefmt: Optional[str] = None) -> str:
  dt = self.converter(record.created)
  if datefmt:
  s = dt.strftime(datefmt)
@@ -28,14 +37,14 @@ class TimezoneFormatter(Formatter):
  return s


- class TimezoneJsonFormatter(json.JsonFormatter):
+ class TimezoneJsonFormatter(json.JsonFormatter): # type: ignore[name-defined]
  """JSON Formatter that respects a custom timezone."""

- def __init__(self, *args, tz_name="UTC", **kwargs):
+ def __init__(self, *args: Any, tz_name: str = "UTC", **kwargs: Any) -> None:
  super().__init__(*args, **kwargs)
  self.tz = ZoneInfo(tz_name)

- def formatTime(self, record, datefmt=None):
+ def formatTime(self, record: Any, datefmt: Optional[str] = None) -> str:
  # Override formatTime to use timezone-aware datetime
  dt = datetime.fromtimestamp(record.created, self.tz)
  if datefmt:
@@ -44,7 +53,7 @@ class TimezoneJsonFormatter(json.JsonFormatter):
  return dt.isoformat()


- def setup_logging(log_level: str = "INFO", log_format: str = "json", tz_name: str = "UTC"):
+ def setup_logging(log_level: str = "INFO", log_format: str = "json", tz_name: str = "UTC") -> None:
  """Configures structured logging for the entire application."""
  logger = getLogger("avtomatika")
  logger.setLevel(log_level)
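The logging helpers gain full type annotations but keep the same runtime behaviour. A brief usage sketch, assuming these names live in `avtomatika.logging_config` (the module `engine.py` imports `setup_logging` from); the format string and timezone are placeholders:

```python
# Hedged usage sketch of the timezone-aware logging setup.
from logging import getLogger

from avtomatika.logging_config import TimezoneJsonFormatter, setup_logging  # module path assumed

setup_logging(log_level="INFO", log_format="json", tz_name="Europe/Berlin")
getLogger("avtomatika").info("engine starting")

# A formatter can also be attached to a custom handler directly:
formatter = TimezoneJsonFormatter("%(asctime)s %(levelname)s %(message)s", tz_name="UTC")
```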
avtomatika/s3.py CHANGED
@@ -3,13 +3,15 @@ from logging import getLogger
  from os import sep, walk
  from pathlib import Path
  from shutil import rmtree
- from typing import Any, Tuple
+ from typing import Any

  from aiofiles import open as aiopen
  from obstore import delete_async, get_async, put_async
  from obstore import list as obstore_list
  from obstore.store import S3Store
  from orjson import dumps, loads
+ from rxon.blob import calculate_config_hash, parse_uri
+ from rxon.exceptions import IntegrityError

  from .config import Config
  from .history.base import HistoryStorageBase
@@ -56,40 +58,50 @@ class TaskFiles:
  clean_name = filename.split("/")[-1] if "://" in filename else filename.lstrip("/")
  return self.local_dir / clean_name

- def _parse_s3_uri(self, uri: str) -> Tuple[str, str, bool]:
- """
- Parses s3://bucket/key into (bucket, key, is_directory).
- is_directory is True if uri ends with '/'.
+ async def _download_single_file(
+ self,
+ key: str,
+ local_path: Path,
+ expected_size: int | None = None,
+ expected_hash: str | None = None,
+ ) -> dict[str, Any]:
+ """Downloads a single file safely using semaphore and streaming.
+ Returns metadata (size, etag).
  """
- is_dir = uri.endswith("/")
-
- if not uri.startswith("s3://"):
- key = f"{self._s3_prefix}{uri.lstrip('/')}"
- return self._bucket, key, is_dir
-
- parts = uri[5:].split("/", 1)
- bucket = parts[0]
- key = parts[1] if len(parts) > 1 else ""
- return bucket, key, is_dir
-
- async def _download_single_file(self, key: str, local_path: Path) -> None:
- """Downloads a single file safely using semaphore and streaming to avoid OOM."""
  if not local_path.parent.exists():
  await to_thread(local_path.parent.mkdir, parents=True, exist_ok=True)

  async with self._semaphore:
  response = await get_async(self._store, key)
+ meta = response.meta
+ file_size = meta.size
+ etag = meta.e_tag.strip('"') if meta.e_tag else None
+
+ if expected_size is not None and file_size != expected_size:
+ raise IntegrityError(f"File size mismatch for {key}: expected {expected_size}, got {file_size}")
+
+ if expected_hash is not None and etag and expected_hash != etag:
+ raise IntegrityError(f"Integrity mismatch for {key}: expected ETag {expected_hash}, got {etag}")
+
  stream = response.stream()
  async with aiopen(local_path, "wb") as f:
  async for chunk in stream:
  await f.write(chunk)

- async def download(self, name_or_uri: str, local_name: str | None = None) -> Path:
+ return {"size": file_size, "etag": etag}
+
+ async def download(
+ self,
+ name_or_uri: str,
+ local_name: str | None = None,
+ verify_meta: dict[str, Any] | None = None,
+ ) -> Path:
  """
  Downloads a file or directory (recursively).
  If URI ends with '/', it treats it as a directory.
  """
- bucket, key, is_dir = self._parse_s3_uri(name_or_uri)
+ bucket, key, is_dir = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
+ verify_meta = verify_meta or {}

  if local_name:
  target_path = self.path(local_name)
@@ -112,22 +124,42 @@ class TaskFiles:
  tasks.append(self._download_single_file(s3_key, local_file_path))

  if tasks:
- await gather(*tasks)
-
- await self._log_event("download_dir", f"s3://{bucket}/{key}", str(target_path))
+ results = await gather(*tasks)
+ total_size = sum(r["size"] for r in results)
+ await self._log_event(
+ "download_dir",
+ f"s3://{bucket}/{key}",
+ str(target_path),
+ metadata={"total_size": total_size, "file_count": len(results)},
+ )
+ else:
+ await self._log_event(
+ "download_dir",
+ f"s3://{bucket}/{key}",
+ str(target_path),
+ metadata={"total_size": 0, "file_count": 0},
+ )
  return target_path
  else:
  logger.debug(f"Downloading s3://{bucket}/{key} -> {target_path}")
- await self._download_single_file(key, target_path)
- await self._log_event("download", f"s3://{bucket}/{key}", str(target_path))
+ meta = await self._download_single_file(
+ key,
+ target_path,
+ expected_size=verify_meta.get("size"),
+ expected_hash=verify_meta.get("hash"),
+ )
+ await self._log_event("download", f"s3://{bucket}/{key}", str(target_path), metadata=meta)
  return target_path

- async def _upload_single_file(self, local_path: Path, s3_key: str) -> None:
- """Uploads a single file safely using semaphore."""
+ async def _upload_single_file(self, local_path: Path, s3_key: str) -> dict[str, Any]:
+ """Uploads a single file safely using semaphore. Returns S3 metadata."""
  async with self._semaphore:
+ file_size = local_path.stat().st_size
  async with aiopen(local_path, "rb") as f:
  content = await f.read()
- await put_async(self._store, s3_key, content)
+ result = await put_async(self._store, s3_key, content)
+ etag = result.e_tag.strip('"') if result.e_tag else None
+ return {"size": file_size, "etag": etag}

  async def upload(self, local_name: str, remote_name: str | None = None) -> str:
  """
@@ -158,26 +190,30 @@ class TaskFiles:

  tasks = [self._upload_single_file(lp, k) for lp, k in files_map]
  if tasks:
- await gather(*tasks)
+ results = await gather(*tasks)
+ total_size = sum(r["size"] for r in results)
+ metadata = {"total_size": total_size, "file_count": len(results)}
+ else:
+ metadata = {"total_size": 0, "file_count": 0}

  uri = f"s3://{self._bucket}/{target_prefix}"
- await self._log_event("upload_dir", uri, str(local_path))
+ await self._log_event("upload_dir", uri, str(local_path), metadata=metadata)
  return uri

  elif local_path.exists():
  target_key = f"{self._s3_prefix}{(remote_name or local_name).lstrip('/')}"
  logger.debug(f"Uploading {local_path} -> s3://{self._bucket}/{target_key}")

- await self._upload_single_file(local_path, target_key)
+ meta = await self._upload_single_file(local_path, target_key)

  uri = f"s3://{self._bucket}/{target_key}"
- await self._log_event("upload", uri, str(local_path))
+ await self._log_event("upload", uri, str(local_path), metadata=meta)
  return uri
  else:
  raise FileNotFoundError(f"Local file/dir not found: {local_path}")

  async def read_text(self, name_or_uri: str) -> str:
- bucket, key, _ = self._parse_s3_uri(name_or_uri)
+ bucket, key, _ = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
  filename = key.split("/")[-1]
  local_path = self.path(filename)

@@ -188,7 +224,7 @@ class TaskFiles:
  return await f.read()

  async def read_json(self, name_or_uri: str) -> Any:
- bucket, key, _ = self._parse_s3_uri(name_or_uri)
+ bucket, key, _ = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
  filename = key.split("/")[-1]
  local_path = self.path(filename)

@@ -235,21 +271,31 @@ class TaskFiles:
  if self.local_dir.exists():
  await to_thread(rmtree, self.local_dir)

- async def _log_event(self, operation: str, file_uri: str, local_path: str) -> None:
+ async def _log_event(
+ self,
+ operation: str,
+ file_uri: str,
+ local_path: str,
+ metadata: dict[str, Any] | None = None,
+ ) -> None:
  if not self._history:
  return

  try:
+ context_snapshot = {
+ "operation": operation,
+ "s3_uri": file_uri,
+ "local_path": str(local_path),
+ }
+ if metadata:
+ context_snapshot.update(metadata)
+
  await self._history.log_job_event(
  {
  "job_id": self._job_id,
  "event_type": "s3_operation",
  "state": "running",
- "context_snapshot": {
- "operation": operation,
- "s3_uri": file_uri,
- "local_path": str(local_path),
- },
+ "context_snapshot": context_snapshot,
  }
  )
  except Exception as e:
@@ -306,6 +352,16 @@ class S3Service:
  logger.error(f"Failed to initialize S3 Store: {e}")
  self._enabled = False

+ def get_config_hash(self) -> str | None:
+ """Returns a hash of the current S3 configuration for consistency checks."""
+ if not self._enabled:
+ return None
+ return calculate_config_hash(
+ self.config.S3_ENDPOINT_URL,
+ self.config.S3_ACCESS_KEY,
+ self.config.S3_DEFAULT_BUCKET,
+ )
+
  def get_task_files(self, job_id: str) -> TaskFiles | None:
  if not self._enabled or not self._store or not self._semaphore:
  return None
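`TaskFiles.download` can now verify the object's size and ETag against caller-supplied metadata and raises `rxon.exceptions.IntegrityError` on mismatch. A hedged sketch from a handler's point of view; the URI and metadata values are placeholders, the `verify_meta` keys (`size`, `hash`) come from the diff:

```python
# Hypothetical handler-side usage of the new verify_meta argument.
from rxon.exceptions import IntegrityError

async def fetch_input(files) -> None:  # files: a TaskFiles instance for the current job
    try:
        path = await files.download(
            "s3://my-bucket/inputs/data.json",                       # placeholder URI
            verify_meta={"size": 2048, "hash": "9a0364b9e99bb480dd25e1f0284c8555"},
        )
    except IntegrityError as exc:
        # The downloaded object's size or ETag did not match the recorded metadata.
        raise RuntimeError(f"corrupt download: {exc}") from exc
    print("verified file at", path)
```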
@@ -22,14 +22,17 @@ def load_schedules_from_file(file_path: str) -> list[ScheduledJobConfig]:

  schedules = []
  for name, config in data.items():
- # Skip sections that might be metadata (though TOML structure usually implies all top-level keys are jobs)
  if not isinstance(config, dict):
  continue

+ blueprint = config.get("blueprint")
+ if not isinstance(blueprint, str):
+ raise ValueError(f"Schedule '{name}' is missing a 'blueprint' name.")
+
  schedules.append(
  ScheduledJobConfig(
  name=name,
- blueprint=config.get("blueprint"),
+ blueprint=blueprint,
  input_data=config.get("input_data", {}),
  interval_seconds=config.get("interval_seconds"),
  daily_at=config.get("daily_at"),
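The schedule loader now rejects entries without a `blueprint` key instead of silently creating a `ScheduledJobConfig` with `blueprint=None`. A sketch of that behaviour, assuming `load_schedules_from_file` is importable from `avtomatika.scheduler` and reads a TOML file whose top-level tables are schedule names; the file contents are placeholders:

```python
# Hedged sketch of the stricter schedule validation; import path is an assumption.
from pathlib import Path

from avtomatika.scheduler import load_schedules_from_file  # module path assumed

Path("schedules.toml").write_text(
    """
[nightly_report]
blueprint = "report_pipeline"
daily_at = "02:30"

[broken_job]
interval_seconds = 60   # no 'blueprint' key -> should now raise
"""
)

try:
    load_schedules_from_file("schedules.toml")
except ValueError as exc:
    print(exc)  # Schedule 'broken_job' is missing a 'blueprint' name.
```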