avtomatika 1.0b8-py3-none-any.whl → 1.0b10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika/api/handlers.py +5 -257
- avtomatika/api/routes.py +42 -63
- avtomatika/api.html +1 -1
- avtomatika/app_keys.py +1 -0
- avtomatika/blueprint.py +3 -2
- avtomatika/config.py +8 -0
- avtomatika/constants.py +75 -25
- avtomatika/data_types.py +2 -22
- avtomatika/dispatcher.py +4 -0
- avtomatika/engine.py +119 -7
- avtomatika/executor.py +19 -19
- avtomatika/logging_config.py +16 -7
- avtomatika/s3.py +96 -40
- avtomatika/scheduler_config_loader.py +5 -2
- avtomatika/security.py +56 -74
- avtomatika/services/__init__.py +0 -0
- avtomatika/services/worker_service.py +267 -0
- avtomatika/storage/base.py +10 -0
- avtomatika/storage/memory.py +15 -4
- avtomatika/storage/redis.py +42 -11
- avtomatika/telemetry.py +8 -7
- avtomatika/utils/webhook_sender.py +3 -3
- avtomatika/watcher.py +4 -2
- avtomatika/ws_manager.py +16 -8
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/METADATA +47 -15
- avtomatika-1.0b10.dist-info/RECORD +48 -0
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/WHEEL +1 -1
- avtomatika-1.0b8.dist-info/RECORD +0 -46
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/licenses/LICENSE +0 -0
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/top_level.txt +0 -0
avtomatika/engine.py
CHANGED

@@ -1,7 +1,7 @@
 from asyncio import TimeoutError as AsyncTimeoutError
 from asyncio import create_task, gather, get_running_loop, wait_for
 from logging import getLogger
-from typing import Any
+from typing import Any, Optional
 from uuid import uuid4
 
 from aiohttp import ClientSession, web
@@ -24,6 +24,7 @@ from .app_keys import (
     SCHEDULER_TASK_KEY,
     WATCHER_KEY,
     WATCHER_TASK_KEY,
+    WORKER_SERVICE_KEY,
     WS_MANAGER_KEY,
 )
 from .blueprint import StateMachineBlueprint
@@ -40,6 +41,7 @@ from .logging_config import setup_logging
 from .reputation import ReputationCalculator
 from .s3 import S3Service
 from .scheduler import Scheduler
+from .services.worker_service import WorkerService
 from .storage.base import StorageBackend
 from .telemetry import setup_telemetry
 from .utils.webhook_sender import WebhookPayload, WebhookSender
@@ -68,10 +70,19 @@ class OrchestratorEngine:
         self.config = config
         self.blueprints: dict[str, StateMachineBlueprint] = {}
         self.history_storage: HistoryStorageBase = NoOpHistoryStorage()
-        self.ws_manager = WebSocketManager()
+        self.ws_manager = WebSocketManager(self.storage)
         self.app = web.Application(middlewares=[compression_middleware])
         self.app[ENGINE_KEY] = self
+        self.worker_service: Optional[WorkerService] = None
         self._setup_done = False
+        self.webhook_sender: WebhookSender
+        self.dispatcher: Dispatcher
+        self.runner: web.AppRunner
+        self.site: web.TCPSite
+
+        from rxon import HttpListener
+
+        self.rxon_listener = HttpListener(self.app)
 
     def register_blueprint(self, blueprint: StateMachineBlueprint) -> None:
         if self._setup_done:
@@ -142,8 +153,74 @@ class OrchestratorEngine:
         )
         self.history_storage = NoOpHistoryStorage()
 
+    async def handle_rxon_message(self, message_type: str, payload: Any, context: dict) -> Any:
+        """Core handler for RXON protocol messages via any listener."""
+        from rxon.security import extract_cert_identity
+
+        from .security import verify_worker_auth
+
+        request = context.get("raw_request")
+        token = context.get("token")
+        cert_identity = extract_cert_identity(request) if request else None
+
+        worker_id_hint = context.get("worker_id_hint")
+
+        if not worker_id_hint:
+            if message_type == "poll" and isinstance(payload, str):
+                worker_id_hint = payload
+            elif isinstance(payload, dict) and "worker_id" in payload:
+                worker_id_hint = payload["worker_id"]
+            elif hasattr(payload, "worker_id"):
+                worker_id_hint = payload.worker_id
+
+        try:
+            auth_worker_id = await verify_worker_auth(self.storage, self.config, token, cert_identity, worker_id_hint)
+        except PermissionError as e:
+            raise web.HTTPUnauthorized(text=str(e)) from e
+        except ValueError as e:
+            raise web.HTTPBadRequest(text=str(e)) from e
+
+        if self.worker_service is None:
+            raise web.HTTPInternalServerError(text="WorkerService is not initialized.")
+
+        if message_type == "register":
+            return await self.worker_service.register_worker(payload)
+
+        elif message_type == "poll":
+            return await self.worker_service.get_next_task(auth_worker_id)
+
+        elif message_type == "result":
+            return await self.worker_service.process_task_result(payload, auth_worker_id)
+
+        elif message_type == "heartbeat":
+            return await self.worker_service.update_worker_heartbeat(auth_worker_id, payload)
+
+        elif message_type == "sts_token":
+            if cert_identity is None:
+                raise web.HTTPForbidden(text="Unauthorized: mTLS certificate required to issue access token.")
+            return await self.worker_service.issue_access_token(auth_worker_id)
+
+        elif message_type == "websocket":
+            ws = payload
+            await self.ws_manager.register(auth_worker_id, ws)
+            try:
+                from aiohttp import WSMsgType
+
+                async for msg in ws:
+                    if msg.type == WSMsgType.TEXT:
+                        try:
+                            data = msg.json()
+                            await self.ws_manager.handle_message(auth_worker_id, data)
+                        except Exception as e:
+                            logger.error(f"Error processing WebSocket message from {auth_worker_id}: {e}")
+                    elif msg.type == WSMsgType.ERROR:
+                        break
+            finally:
+                await self.ws_manager.unregister(auth_worker_id)
+            return None
+
     async def on_startup(self, app: web.Application) -> None:
-        #
+        # Fail Fast: Check Storage Connection
         if not await self.storage.ping():
             logger.critical("Failed to connect to Storage Backend (Redis). Exiting.")
             raise RuntimeError("Storage Backend is unavailable.")
@@ -208,14 +285,21 @@ class OrchestratorEngine:
         app[WS_MANAGER_KEY] = self.ws_manager
         app[S3_SERVICE_KEY] = S3Service(self.config, self.history_storage)
 
+        self.worker_service = WorkerService(self.storage, self.history_storage, self.config, self)
+        app[WORKER_SERVICE_KEY] = self.worker_service
+
         app[EXECUTOR_TASK_KEY] = create_task(app[EXECUTOR_KEY].run())
         app[WATCHER_TASK_KEY] = create_task(app[WATCHER_KEY].run())
         app[REPUTATION_CALCULATOR_TASK_KEY] = create_task(app[REPUTATION_CALCULATOR_KEY].run())
         app[HEALTH_CHECKER_TASK_KEY] = create_task(app[HEALTH_CHECKER_KEY].run())
         app[SCHEDULER_TASK_KEY] = create_task(app[SCHEDULER_KEY].run())
 
+        await self.rxon_listener.start(self.handle_rxon_message)
+
     async def on_shutdown(self, app: web.Application) -> None:
         logger.info("Shutdown sequence started.")
+        await self.rxon_listener.stop()
+
         app[EXECUTOR_KEY].stop()
         app[WATCHER_KEY].stop()
         app[REPUTATION_CALCULATOR_KEY].stop()
@@ -274,6 +358,8 @@ class OrchestratorEngine:
         blueprint_name: str,
         initial_data: dict[str, Any],
         source: str = "internal",
+        tracing_context: dict[str, str] | None = None,
+        data_metadata: dict[str, Any] | None = None,
     ) -> str:
         """Creates a job directly, bypassing the HTTP API layer.
         Useful for internal schedulers and triggers.
@@ -297,8 +383,9 @@ class OrchestratorEngine:
             "initial_data": initial_data,
             "state_history": {},
             "status": JOB_STATUS_PENDING,
-            "tracing_context": {},
+            "tracing_context": tracing_context or {},
             "client_config": client_config,
+            "data_metadata": data_metadata or {},
         }
         await self.storage.save_job_state(job_id, job_state)
         await self.storage.enqueue_job(job_id)
@@ -374,19 +461,44 @@ class OrchestratorEngine:
 
     def run(self) -> None:
         self.setup()
+        ssl_context = None
+        if self.config.TLS_ENABLED:
+            from rxon.security import create_server_ssl_context
+
+            ssl_context = create_server_ssl_context(
+                cert_path=self.config.TLS_CERT_PATH,
+                key_path=self.config.TLS_KEY_PATH,
+                ca_path=self.config.TLS_CA_PATH,
+                require_client_cert=self.config.TLS_REQUIRE_CLIENT_CERT,
+            )
+            print(f"TLS enabled. mTLS required: {self.config.TLS_REQUIRE_CLIENT_CERT}")
+
         print(
             f"Starting OrchestratorEngine API server on {self.config.API_HOST}:{self.config.API_PORT} in blocking mode."
         )
-        web.run_app(self.app, host=self.config.API_HOST, port=self.config.API_PORT)
+        web.run_app(self.app, host=self.config.API_HOST, port=self.config.API_PORT, ssl_context=ssl_context)
 
     async def start(self):
         """Starts the orchestrator engine non-blockingly."""
         self.setup()
        self.runner = web.AppRunner(self.app)
         await self.runner.setup()
-
+
+        ssl_context = None
+        if self.config.TLS_ENABLED:
+            from rxon.security import create_server_ssl_context
+
+            ssl_context = create_server_ssl_context(
+                cert_path=self.config.TLS_CERT_PATH,
+                key_path=self.config.TLS_KEY_PATH,
+                ca_path=self.config.TLS_CA_PATH,
+                require_client_cert=self.config.TLS_REQUIRE_CLIENT_CERT,
+            )
+
+        self.site = web.TCPSite(self.runner, self.config.API_HOST, self.config.API_PORT, ssl_context=ssl_context)
         await self.site.start()
-
+        protocol = "https" if self.config.TLS_ENABLED else "http"
+        print(f"OrchestratorEngine API server running on {protocol}://{self.config.API_HOST}:{self.config.API_PORT}")
 
     async def stop(self):
         """Stops the orchestrator engine."""
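The engine changes add an rxon HTTP listener, optional TLS/mTLS on both `run()` and `start()`, and two new optional `create_job()` arguments (`tracing_context`, `data_metadata`). Below is a minimal sketch of the caller-facing part only; the blueprint name, payloads, and the `run_example` helper are made up, and engine construction (not shown in this diff) is assumed to have happened elsewhere.

```python
# Sketch only: "engine" stands for an already-constructed OrchestratorEngine with
# at least one blueprint registered; names and payloads below are illustrative.
async def run_example(engine) -> None:
    await engine.start()  # non-blocking; serves https:// when config.TLS_ENABLED is set
    try:
        # create_job() gains two optional keyword arguments in 1.0b10.
        job_id = await engine.create_job(
            blueprint_name="example",
            initial_data={"input": "value"},
            source="internal",
            tracing_context={"traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"},
            data_metadata={"files": []},
        )
        print(f"enqueued job {job_id}")
    finally:
        await engine.stop()
```

When `TLS_ENABLED` is set, the server builds its SSL context through `rxon.security.create_server_ssl_context`, as shown in the hunks above.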
avtomatika/executor.py
CHANGED

@@ -48,6 +48,16 @@ except ImportError:
     TraceContextTextMapPropagator = NoOpTraceContextTextMapPropagator()  # Instantiate the class
 
 from .app_keys import S3_SERVICE_KEY
+from .constants import (
+    JOB_STATUS_ERROR,
+    JOB_STATUS_FAILED,
+    JOB_STATUS_FINISHED,
+    JOB_STATUS_PENDING,
+    JOB_STATUS_QUARANTINED,
+    JOB_STATUS_RUNNING,
+    JOB_STATUS_WAITING_FOR_PARALLEL,
+    JOB_STATUS_WAITING_FOR_WORKER,
+)
 from .context import ActionFactory
 from .data_types import ClientConfig, JobContext
 from .history.base import HistoryStorageBase
@@ -59,7 +69,7 @@ logger = getLogger(
     __name__
 )  # Re-declare logger after potential redefinition in except block if opentelemetry was missing
 
-TERMINAL_STATES = {
+TERMINAL_STATES = {JOB_STATUS_FINISHED, JOB_STATUS_FAILED, JOB_STATUS_ERROR, JOB_STATUS_QUARANTINED}
 
 
 class JobExecutor:
@@ -263,7 +273,7 @@ class JobExecutor:
         # When transitioning to a new state, reset the retry counter.
         job_state["retry_count"] = 0
         job_state["current_state"] = next_state
-        job_state["status"] =
+        job_state["status"] = JOB_STATUS_RUNNING
         await self.storage.save_job_state(job_id, job_state)
 
         if next_state not in TERMINAL_STATES:
@@ -280,11 +290,7 @@ class JobExecutor:
                 create_task(task_files.cleanup())
 
             await self._check_and_resume_parent(job_state)
-
-            event_type = "job_finished" if next_state == "finished" else "job_failed"
-            # Since _check_and_resume_parent is for sub-jobs, we only send webhook if it's a top-level job
-            # or if the user explicitly requested it for sub-jobs (by providing webhook_url).
-            # The current logic stores webhook_url in job_state, so we just check it.
+            event_type = "job_finished" if next_state == JOB_STATUS_FINISHED else "job_failed"
             await self.engine.send_job_webhook(job_state, event_type)
 
     async def _handle_dispatch(
@@ -313,21 +319,15 @@ class JobExecutor:
             logger.info(f"Job {job_id} is now paused, awaiting human approval.")
         else:
             logger.info(f"Job {job_id} dispatching task: {task_info}")
-
             now = monotonic()
-            # Safely get timeout, falling back to the global config if not provided in the task.
-            # This prevents TypeErrors if 'timeout_seconds' is missing.
             timeout_seconds = task_info.get("timeout_seconds") or self.engine.config.WORKER_TIMEOUT_SECONDS
             timeout_at = now + timeout_seconds
-
-            # Set status to waiting and add to watch list *before* dispatching
-            job_state["status"] = "waiting_for_worker"
+            job_state["status"] = JOB_STATUS_WAITING_FOR_WORKER
             job_state["task_dispatched_at"] = now
             job_state["current_task_info"] = task_info  # Save for retries
             job_state["current_task_transitions"] = task_info.get("transitions", {})
             await self.storage.save_job_state(job_id, job_state)
             await self.storage.add_job_to_watch(job_id, timeout_at)
-
             await self.dispatcher.dispatch(job_state, task_info)
 
     async def _handle_run_blueprint(
@@ -355,7 +355,7 @@ class JobExecutor:
             "blueprint_name": sub_blueprint_info["blueprint_name"],
             "current_state": "start",
             "initial_data": sub_blueprint_info["initial_data"],
-            "status":
+            "status": JOB_STATUS_PENDING,
             "parent_job_id": parent_job_id,
         }
         await self.storage.save_job_state(child_job_id, child_job_state)
@@ -388,7 +388,7 @@ class JobExecutor:
         branch_task_ids = [str(uuid4()) for _ in tasks_to_dispatch]
 
         # Update job state for parallel execution
-        job_state["status"] =
+        job_state["status"] = JOB_STATUS_WAITING_FOR_PARALLEL
         job_state["aggregation_target"] = aggregate_into
         job_state["active_branches"] = branch_task_ids
         job_state["aggregation_results"] = {}
@@ -466,7 +466,7 @@ class JobExecutor:
             logger.critical(
                 f"Job {job_id} has failed handler execution {max_retries + 1} times. Moving to quarantine.",
             )
-            job_state["status"] =
+            job_state["status"] = JOB_STATUS_QUARANTINED
             job_state["error_message"] = str(error)
             await self.storage.save_job_state(job_id, job_state)
             await self.storage.quarantine_job(job_id)
@@ -499,7 +499,7 @@ class JobExecutor:
             return
 
         # Determine the outcome of the child job to select the correct transition.
-        child_outcome = "success" if child_job_state["current_state"] ==
+        child_outcome = "success" if child_job_state["current_state"] == JOB_STATUS_FINISHED else "failure"
         transitions = parent_job_state.get("current_task_transitions", {})
         next_state = transitions.get(child_outcome, "failed")
 
@@ -514,7 +514,7 @@ class JobExecutor:
 
         # Update the parent job to its new state and re-enqueue it.
         parent_job_state["current_state"] = next_state
-        parent_job_state["status"] =
+        parent_job_state["status"] = JOB_STATUS_RUNNING
         await self.storage.save_job_state(parent_job_id, parent_job_state)
         await self.storage.enqueue_job(parent_job_id)
 
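The executor now imports job-status constants from `avtomatika/constants.py` instead of writing bare string literals into `job_state["status"]`, and `TERMINAL_STATES` is built from the same constants. A small sketch of the pattern follows; the `is_terminal` helper is illustrative and not part of the package API, and it assumes the constants resolve to the status strings used elsewhere.

```python
# Illustrative only: mirrors the constant-based status handling introduced in 1.0b10.
from avtomatika.constants import (
    JOB_STATUS_ERROR,
    JOB_STATUS_FAILED,
    JOB_STATUS_FINISHED,
    JOB_STATUS_QUARANTINED,
    JOB_STATUS_RUNNING,
)

TERMINAL_STATES = {JOB_STATUS_FINISHED, JOB_STATUS_FAILED, JOB_STATUS_ERROR, JOB_STATUS_QUARANTINED}


def is_terminal(job_state: dict) -> bool:
    # Membership in one shared set replaces comparisons against hand-typed literals.
    return job_state.get("status") in TERMINAL_STATES


print(is_terminal({"status": JOB_STATUS_QUARANTINED}))  # True
print(is_terminal({"status": JOB_STATUS_RUNNING}))      # False
```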
avtomatika/logging_config.py
CHANGED

@@ -1,6 +1,7 @@
 from datetime import datetime
 from logging import DEBUG, Formatter, StreamHandler, getLogger
 from sys import stdout
+from typing import Any, Literal, Optional
 from zoneinfo import ZoneInfo
 
 from pythonjsonlogger import json
@@ -9,14 +10,22 @@ from pythonjsonlogger import json
 class TimezoneFormatter(Formatter):
     """Formatter that respects a custom timezone."""
 
-    def __init__(
+    def __init__(
+        self,
+        fmt: Optional[str] = None,
+        datefmt: Optional[str] = None,
+        style: Literal["%", "{", "$"] = "%",
+        validate: bool = True,
+        *,
+        tz_name: str = "UTC",
+    ) -> None:
         super().__init__(fmt, datefmt, style, validate)
         self.tz = ZoneInfo(tz_name)
 
-    def converter(self, timestamp):
+    def converter(self, timestamp: float) -> datetime:  # type: ignore[override]
         return datetime.fromtimestamp(timestamp, self.tz)
 
-    def formatTime(self, record, datefmt=None):
+    def formatTime(self, record: Any, datefmt: Optional[str] = None) -> str:
         dt = self.converter(record.created)
         if datefmt:
             s = dt.strftime(datefmt)
@@ -28,14 +37,14 @@ class TimezoneFormatter(Formatter):
         return s
 
 
-class TimezoneJsonFormatter(json.JsonFormatter):
+class TimezoneJsonFormatter(json.JsonFormatter):  # type: ignore[name-defined]
     """JSON Formatter that respects a custom timezone."""
 
-    def __init__(self, *args, tz_name="UTC", **kwargs):
+    def __init__(self, *args: Any, tz_name: str = "UTC", **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.tz = ZoneInfo(tz_name)
 
-    def formatTime(self, record, datefmt=None):
+    def formatTime(self, record: Any, datefmt: Optional[str] = None) -> str:
         # Override formatTime to use timezone-aware datetime
         dt = datetime.fromtimestamp(record.created, self.tz)
         if datefmt:
@@ -44,7 +53,7 @@ class TimezoneJsonFormatter(json.JsonFormatter):
         return dt.isoformat()
 
 
-def setup_logging(log_level: str = "INFO", log_format: str = "json", tz_name: str = "UTC"):
+def setup_logging(log_level: str = "INFO", log_format: str = "json", tz_name: str = "UTC") -> None:
     """Configures structured logging for the entire application."""
     logger = getLogger("avtomatika")
     logger.setLevel(log_level)
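The logging changes are mostly type annotations, but in the new `TimezoneFormatter.__init__` signature `tz_name` is keyword-only. A short usage sketch based on the signatures above; the timezone and date format are arbitrary examples.

```python
# Sketch only: exercises the signatures shown in the diff above.
from logging import StreamHandler, getLogger

from avtomatika.logging_config import TimezoneFormatter, setup_logging

# One-call setup with the documented defaults.
setup_logging(log_level="INFO", log_format="json", tz_name="UTC")

# Manual use of the plain-text formatter: tz_name must be passed by keyword.
handler = StreamHandler()
handler.setFormatter(TimezoneFormatter(datefmt="%Y-%m-%d %H:%M:%S %Z", tz_name="Europe/Berlin"))
getLogger("avtomatika").addHandler(handler)
```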
avtomatika/s3.py
CHANGED

@@ -3,13 +3,15 @@ from logging import getLogger
 from os import sep, walk
 from pathlib import Path
 from shutil import rmtree
-from typing import Any
+from typing import Any
 
 from aiofiles import open as aiopen
 from obstore import delete_async, get_async, put_async
 from obstore import list as obstore_list
 from obstore.store import S3Store
 from orjson import dumps, loads
+from rxon.blob import calculate_config_hash, parse_uri
+from rxon.exceptions import IntegrityError
 
 from .config import Config
 from .history.base import HistoryStorageBase
@@ -56,40 +58,50 @@ class TaskFiles:
         clean_name = filename.split("/")[-1] if "://" in filename else filename.lstrip("/")
         return self.local_dir / clean_name
 
-    def
-
-
-
+    async def _download_single_file(
+        self,
+        key: str,
+        local_path: Path,
+        expected_size: int | None = None,
+        expected_hash: str | None = None,
+    ) -> dict[str, Any]:
+        """Downloads a single file safely using semaphore and streaming.
+        Returns metadata (size, etag).
         """
-        is_dir = uri.endswith("/")
-
-        if not uri.startswith("s3://"):
-            key = f"{self._s3_prefix}{uri.lstrip('/')}"
-            return self._bucket, key, is_dir
-
-        parts = uri[5:].split("/", 1)
-        bucket = parts[0]
-        key = parts[1] if len(parts) > 1 else ""
-        return bucket, key, is_dir
-
-    async def _download_single_file(self, key: str, local_path: Path) -> None:
-        """Downloads a single file safely using semaphore and streaming to avoid OOM."""
         if not local_path.parent.exists():
             await to_thread(local_path.parent.mkdir, parents=True, exist_ok=True)
 
         async with self._semaphore:
             response = await get_async(self._store, key)
+            meta = response.meta
+            file_size = meta.size
+            etag = meta.e_tag.strip('"') if meta.e_tag else None
+
+            if expected_size is not None and file_size != expected_size:
+                raise IntegrityError(f"File size mismatch for {key}: expected {expected_size}, got {file_size}")
+
+            if expected_hash is not None and etag and expected_hash != etag:
+                raise IntegrityError(f"Integrity mismatch for {key}: expected ETag {expected_hash}, got {etag}")
+
             stream = response.stream()
             async with aiopen(local_path, "wb") as f:
                 async for chunk in stream:
                     await f.write(chunk)
 
-
+        return {"size": file_size, "etag": etag}
+
+    async def download(
+        self,
+        name_or_uri: str,
+        local_name: str | None = None,
+        verify_meta: dict[str, Any] | None = None,
+    ) -> Path:
         """
         Downloads a file or directory (recursively).
         If URI ends with '/', it treats it as a directory.
         """
-        bucket, key, is_dir = self.
+        bucket, key, is_dir = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
+        verify_meta = verify_meta or {}
 
         if local_name:
             target_path = self.path(local_name)
@@ -112,22 +124,42 @@ class TaskFiles:
                     tasks.append(self._download_single_file(s3_key, local_file_path))
 
             if tasks:
-                await gather(*tasks)
-
-
+                results = await gather(*tasks)
+                total_size = sum(r["size"] for r in results)
+                await self._log_event(
+                    "download_dir",
+                    f"s3://{bucket}/{key}",
+                    str(target_path),
+                    metadata={"total_size": total_size, "file_count": len(results)},
+                )
+            else:
+                await self._log_event(
+                    "download_dir",
+                    f"s3://{bucket}/{key}",
+                    str(target_path),
+                    metadata={"total_size": 0, "file_count": 0},
+                )
             return target_path
         else:
             logger.debug(f"Downloading s3://{bucket}/{key} -> {target_path}")
-            await self._download_single_file(
-
+            meta = await self._download_single_file(
+                key,
+                target_path,
+                expected_size=verify_meta.get("size"),
+                expected_hash=verify_meta.get("hash"),
+            )
+            await self._log_event("download", f"s3://{bucket}/{key}", str(target_path), metadata=meta)
             return target_path
 
-    async def _upload_single_file(self, local_path: Path, s3_key: str) ->
-        """Uploads a single file safely using semaphore."""
+    async def _upload_single_file(self, local_path: Path, s3_key: str) -> dict[str, Any]:
+        """Uploads a single file safely using semaphore. Returns S3 metadata."""
         async with self._semaphore:
+            file_size = local_path.stat().st_size
             async with aiopen(local_path, "rb") as f:
                 content = await f.read()
-            await put_async(self._store, s3_key, content)
+            result = await put_async(self._store, s3_key, content)
+            etag = result.e_tag.strip('"') if result.e_tag else None
+            return {"size": file_size, "etag": etag}
 
     async def upload(self, local_name: str, remote_name: str | None = None) -> str:
         """
@@ -158,26 +190,30 @@ class TaskFiles:
 
             tasks = [self._upload_single_file(lp, k) for lp, k in files_map]
             if tasks:
-                await gather(*tasks)
+                results = await gather(*tasks)
+                total_size = sum(r["size"] for r in results)
+                metadata = {"total_size": total_size, "file_count": len(results)}
+            else:
+                metadata = {"total_size": 0, "file_count": 0}
 
             uri = f"s3://{self._bucket}/{target_prefix}"
-            await self._log_event("upload_dir", uri, str(local_path))
+            await self._log_event("upload_dir", uri, str(local_path), metadata=metadata)
             return uri
 
         elif local_path.exists():
             target_key = f"{self._s3_prefix}{(remote_name or local_name).lstrip('/')}"
             logger.debug(f"Uploading {local_path} -> s3://{self._bucket}/{target_key}")
 
-            await self._upload_single_file(local_path, target_key)
+            meta = await self._upload_single_file(local_path, target_key)
 
             uri = f"s3://{self._bucket}/{target_key}"
-            await self._log_event("upload", uri, str(local_path))
+            await self._log_event("upload", uri, str(local_path), metadata=meta)
             return uri
         else:
             raise FileNotFoundError(f"Local file/dir not found: {local_path}")
 
     async def read_text(self, name_or_uri: str) -> str:
-        bucket, key, _ = self.
+        bucket, key, _ = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
         filename = key.split("/")[-1]
         local_path = self.path(filename)
@@ -188,7 +224,7 @@ class TaskFiles:
             return await f.read()
 
     async def read_json(self, name_or_uri: str) -> Any:
-        bucket, key, _ = self.
+        bucket, key, _ = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
         filename = key.split("/")[-1]
         local_path = self.path(filename)
@@ -235,21 +271,31 @@ class TaskFiles:
         if self.local_dir.exists():
             await to_thread(rmtree, self.local_dir)
 
-    async def _log_event(
+    async def _log_event(
+        self,
+        operation: str,
+        file_uri: str,
+        local_path: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
         if not self._history:
             return
 
         try:
+            context_snapshot = {
+                "operation": operation,
+                "s3_uri": file_uri,
+                "local_path": str(local_path),
+            }
+            if metadata:
+                context_snapshot.update(metadata)
+
             await self._history.log_job_event(
                 {
                     "job_id": self._job_id,
                     "event_type": "s3_operation",
                     "state": "running",
-                    "context_snapshot":
-                        "operation": operation,
-                        "s3_uri": file_uri,
-                        "local_path": str(local_path),
-                    },
+                    "context_snapshot": context_snapshot,
                 }
             )
         except Exception as e:
@@ -306,6 +352,16 @@ class S3Service:
             logger.error(f"Failed to initialize S3 Store: {e}")
             self._enabled = False
 
+    def get_config_hash(self) -> str | None:
+        """Returns a hash of the current S3 configuration for consistency checks."""
+        if not self._enabled:
+            return None
+        return calculate_config_hash(
+            self.config.S3_ENDPOINT_URL,
+            self.config.S3_ACCESS_KEY,
+            self.config.S3_DEFAULT_BUCKET,
+        )
+
     def get_task_files(self, job_id: str) -> TaskFiles | None:
         if not self._enabled or not self._store or not self._semaphore:
             return None
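`TaskFiles.download()` now accepts an optional `verify_meta` mapping whose `size` and `hash` entries are checked against the object's size and ETag during the streamed download, raising `rxon.exceptions.IntegrityError` on mismatch. A hedged sketch of a caller follows; `task_files` stands for whatever `S3Service.get_task_files(job_id)` returns (or `None` when S3 is disabled), and the `fetch_input` helper plus its metadata values are illustrative.

```python
# Sketch only: the helper and values below are illustrative, not part of avtomatika's API.
from pathlib import Path

from rxon.exceptions import IntegrityError


async def fetch_input(task_files, uri: str, expected: dict) -> Path | None:
    if task_files is None:  # S3Service.get_task_files() returns None when S3 is disabled
        return None
    try:
        # size/hash are compared against the object's size and ETag while streaming (new in 1.0b10).
        return await task_files.download(uri, local_name="input.bin", verify_meta=expected)
    except IntegrityError as exc:
        print(f"integrity check failed for {uri}: {exc}")
        return None
```

The upload path records the same kind of metadata (sizes, ETags, file counts) into job history via `_log_event`, and `S3Service.get_config_hash()` exposes a hash of the endpoint, access key, and default bucket for consistency checks.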
avtomatika/scheduler_config_loader.py
CHANGED

@@ -22,14 +22,17 @@ def load_schedules_from_file(file_path: str) -> list[ScheduledJobConfig]:
 
     schedules = []
     for name, config in data.items():
-        # Skip sections that might be metadata (though TOML structure usually implies all top-level keys are jobs)
         if not isinstance(config, dict):
             continue
 
+        blueprint = config.get("blueprint")
+        if not isinstance(blueprint, str):
+            raise ValueError(f"Schedule '{name}' is missing a 'blueprint' name.")
+
         schedules.append(
             ScheduledJobConfig(
                 name=name,
-                blueprint=
+                blueprint=blueprint,
                 input_data=config.get("input_data", {}),
                 interval_seconds=config.get("interval_seconds"),
                 daily_at=config.get("daily_at"),
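The loader now fails fast when a schedule entry lacks a string `blueprint` key instead of constructing a `ScheduledJobConfig` with a missing value. Below is a sketch of a schedules file that passes the new check, using only the keys visible in this hunk; the schedule names, blueprint names, and values are made up, and it assumes `ScheduledJobConfig` exposes its fields as attributes.

```python
# Sketch only: schedule names and values are invented for illustration.
from pathlib import Path

from avtomatika.scheduler_config_loader import load_schedules_from_file

SCHEDULES_TOML = """
[nightly_cleanup]
blueprint = "cleanup"                 # required: must be a string, otherwise ValueError
input_data = { retention_days = 7 }
daily_at = "03:00"

[metrics_rollup]
blueprint = "rollup"
interval_seconds = 600
"""

path = Path("schedules.toml")
path.write_text(SCHEDULES_TOML)

for schedule in load_schedules_from_file(str(path)):
    print(schedule.name, schedule.blueprint)
```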