avtomatika-1.0b7-py3-none-any.whl → avtomatika-1.0b9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
avtomatika/engine.py CHANGED
@@ -19,10 +19,12 @@ from .app_keys import (
     HTTP_SESSION_KEY,
     REPUTATION_CALCULATOR_KEY,
     REPUTATION_CALCULATOR_TASK_KEY,
+    S3_SERVICE_KEY,
     SCHEDULER_KEY,
     SCHEDULER_TASK_KEY,
     WATCHER_KEY,
     WATCHER_TASK_KEY,
+    WORKER_SERVICE_KEY,
     WS_MANAGER_KEY,
 )
 from .blueprint import StateMachineBlueprint
@@ -37,7 +39,9 @@ from .history.base import HistoryStorageBase
 from .history.noop import NoOpHistoryStorage
 from .logging_config import setup_logging
 from .reputation import ReputationCalculator
+from .s3 import S3Service
 from .scheduler import Scheduler
+from .services.worker_service import WorkerService
 from .storage.base import StorageBackend
 from .telemetry import setup_telemetry
 from .utils.webhook_sender import WebhookPayload, WebhookSender
@@ -54,7 +58,7 @@ def json_dumps(obj: Any) -> str:
     return dumps(obj).decode("utf-8")


-def json_response(data: Any, **kwargs: Any) -> web.Response:
+def json_response(data, **kwargs: Any) -> web.Response:
     return web.json_response(data, dumps=json_dumps, **kwargs)


@@ -69,8 +73,13 @@ class OrchestratorEngine:
         self.ws_manager = WebSocketManager()
         self.app = web.Application(middlewares=[compression_middleware])
         self.app[ENGINE_KEY] = self
+        self.worker_service = None
         self._setup_done = False

+        from rxon import HttpListener
+
+        self.rxon_listener = HttpListener(self.app)
+
     def register_blueprint(self, blueprint: StateMachineBlueprint) -> None:
         if self._setup_done:
             raise RuntimeError("Cannot register blueprints after engine setup.")
@@ -140,7 +149,75 @@ class OrchestratorEngine:
             )
             self.history_storage = NoOpHistoryStorage()

+    async def handle_rxon_message(self, message_type: str, payload: Any, context: dict) -> Any:
+        """Core handler for RXON protocol messages via any listener."""
+        from rxon.security import extract_cert_identity
+
+        from .security import verify_worker_auth
+
+        request = context.get("raw_request")
+        token = context.get("token")
+        cert_identity = extract_cert_identity(request) if request else None
+
+        worker_id_hint = context.get("worker_id_hint")
+
+        if not worker_id_hint:
+            if message_type == "poll" and isinstance(payload, str):
+                worker_id_hint = payload
+            elif isinstance(payload, dict) and "worker_id" in payload:
+                worker_id_hint = payload["worker_id"]
+            elif hasattr(payload, "worker_id"):
+                worker_id_hint = payload.worker_id
+
+        try:
+            auth_worker_id = await verify_worker_auth(self.storage, self.config, token, cert_identity, worker_id_hint)
+        except PermissionError as e:
+            raise web.HTTPUnauthorized(text=str(e)) from e
+        except ValueError as e:
+            raise web.HTTPBadRequest(text=str(e)) from e
+
+        if message_type == "register":
+            return await self.worker_service.register_worker(payload)
+
+        elif message_type == "poll":
+            return await self.worker_service.get_next_task(auth_worker_id)
+
+        elif message_type == "result":
+            return await self.worker_service.process_task_result(payload, auth_worker_id)
+
+        elif message_type == "heartbeat":
+            return await self.worker_service.update_worker_heartbeat(auth_worker_id, payload)
+
+        elif message_type == "sts_token":
+            if cert_identity is None:
+                raise web.HTTPForbidden(text="Unauthorized: mTLS certificate required to issue access token.")
+            return await self.worker_service.issue_access_token(auth_worker_id)
+
+        elif message_type == "websocket":
+            ws = payload
+            await self.ws_manager.register(auth_worker_id, ws)
+            try:
+                from aiohttp import WSMsgType
+
+                async for msg in ws:
+                    if msg.type == WSMsgType.TEXT:
+                        try:
+                            data = msg.json()
+                            await self.ws_manager.handle_message(auth_worker_id, data)
+                        except Exception as e:
+                            logger.error(f"Error processing WebSocket message from {auth_worker_id}: {e}")
+                    elif msg.type == WSMsgType.ERROR:
+                        break
+            finally:
+                await self.ws_manager.unregister(auth_worker_id)
+            return None
+
     async def on_startup(self, app: web.Application) -> None:
+        # Fail Fast: Check Storage Connection
+        if not await self.storage.ping():
+            logger.critical("Failed to connect to Storage Backend (Redis). Exiting.")
+            raise RuntimeError("Storage Backend is unavailable.")
+
         try:
             from opentelemetry.instrumentation.aiohttp_client import (
                 AioHttpClientInstrumentor,
@@ -152,6 +229,8 @@ class OrchestratorEngine:
                 "opentelemetry-instrumentation-aiohttp-client not found. AIOHTTP client instrumentation is disabled."
             )
         await self._setup_history_storage()
+        # Start history background worker
+        await self.history_storage.start()

         # Load client configs if the path is provided
         if self.config.CLIENTS_CONFIG_PATH:
@@ -188,6 +267,7 @@ class OrchestratorEngine:

         app[HTTP_SESSION_KEY] = ClientSession()
         self.webhook_sender = WebhookSender(app[HTTP_SESSION_KEY])
+        self.webhook_sender.start()
         self.dispatcher = Dispatcher(self.storage, self.config)
         app[DISPATCHER_KEY] = self.dispatcher
         app[EXECUTOR_KEY] = JobExecutor(self, self.history_storage)
@@ -196,6 +276,10 @@ class OrchestratorEngine:
         app[HEALTH_CHECKER_KEY] = HealthChecker(self)
         app[SCHEDULER_KEY] = Scheduler(self)
         app[WS_MANAGER_KEY] = self.ws_manager
+        app[S3_SERVICE_KEY] = S3Service(self.config, self.history_storage)
+
+        self.worker_service = WorkerService(self.storage, self.history_storage, self.config, self)
+        app[WORKER_SERVICE_KEY] = self.worker_service

         app[EXECUTOR_TASK_KEY] = create_task(app[EXECUTOR_KEY].run())
         app[WATCHER_TASK_KEY] = create_task(app[WATCHER_KEY].run())
@@ -203,8 +287,12 @@ class OrchestratorEngine:
         app[HEALTH_CHECKER_TASK_KEY] = create_task(app[HEALTH_CHECKER_KEY].run())
         app[SCHEDULER_TASK_KEY] = create_task(app[SCHEDULER_KEY].run())

+        await self.rxon_listener.start(self.handle_rxon_message)
+
     async def on_shutdown(self, app: web.Application) -> None:
         logger.info("Shutdown sequence started.")
+        await self.rxon_listener.stop()
+
         app[EXECUTOR_KEY].stop()
         app[WATCHER_KEY].stop()
         app[REPUTATION_CALCULATOR_KEY].stop()
@@ -220,6 +308,13 @@ class OrchestratorEngine:
         logger.info("Closing WebSocket connections...")
         await self.ws_manager.close_all()

+        logger.info("Stopping WebhookSender...")
+        await self.webhook_sender.stop()
+
+        if S3_SERVICE_KEY in app:
+            logger.info("Closing S3 Service...")
+            await app[S3_SERVICE_KEY].close()
+
         logger.info("Cancelling background tasks...")
         app[HEALTH_CHECKER_TASK_KEY].cancel()
         app[WATCHER_TASK_KEY].cancel()
@@ -256,6 +351,7 @@ class OrchestratorEngine:
         blueprint_name: str,
         initial_data: dict[str, Any],
         source: str = "internal",
+        tracing_context: dict[str, str] | None = None,
     ) -> str:
         """Creates a job directly, bypassing the HTTP API layer.
         Useful for internal schedulers and triggers.
@@ -279,7 +375,7 @@ class OrchestratorEngine:
             "initial_data": initial_data,
             "state_history": {},
             "status": JOB_STATUS_PENDING,
-            "tracing_context": {},
+            "tracing_context": tracing_context or {},
             "client_config": client_config,
         }
         await self.storage.save_job_state(job_id, job_state)
@@ -352,23 +448,48 @@ class OrchestratorEngine:
         )

         # Run in background to not block the main flow
-        create_task(self.webhook_sender.send(webhook_url, payload))
+        await self.webhook_sender.send(webhook_url, payload)

     def run(self) -> None:
         self.setup()
+        ssl_context = None
+        if self.config.TLS_ENABLED:
+            from rxon.security import create_server_ssl_context
+
+            ssl_context = create_server_ssl_context(
+                cert_path=self.config.TLS_CERT_PATH,
+                key_path=self.config.TLS_KEY_PATH,
+                ca_path=self.config.TLS_CA_PATH,
+                require_client_cert=self.config.TLS_REQUIRE_CLIENT_CERT,
+            )
+            print(f"TLS enabled. mTLS required: {self.config.TLS_REQUIRE_CLIENT_CERT}")
+
         print(
             f"Starting OrchestratorEngine API server on {self.config.API_HOST}:{self.config.API_PORT} in blocking mode."
         )
-        web.run_app(self.app, host=self.config.API_HOST, port=self.config.API_PORT)
+        web.run_app(self.app, host=self.config.API_HOST, port=self.config.API_PORT, ssl_context=ssl_context)

     async def start(self):
         """Starts the orchestrator engine non-blockingly."""
         self.setup()
         self.runner = web.AppRunner(self.app)
         await self.runner.setup()
-        self.site = web.TCPSite(self.runner, self.config.API_HOST, self.config.API_PORT)
+
+        ssl_context = None
+        if self.config.TLS_ENABLED:
+            from rxon.security import create_server_ssl_context
+
+            ssl_context = create_server_ssl_context(
+                cert_path=self.config.TLS_CERT_PATH,
+                key_path=self.config.TLS_KEY_PATH,
+                ca_path=self.config.TLS_CA_PATH,
+                require_client_cert=self.config.TLS_REQUIRE_CLIENT_CERT,
+            )
+
+        self.site = web.TCPSite(self.runner, self.config.API_HOST, self.config.API_PORT, ssl_context=ssl_context)
         await self.site.start()
-        print(f"OrchestratorEngine API server running on http://{self.config.API_HOST}:{self.config.API_PORT}")
+        protocol = "https" if self.config.TLS_ENABLED else "http"
+        print(f"OrchestratorEngine API server running on {protocol}://{self.config.API_HOST}:{self.config.API_PORT}")

     async def stop(self):
         """Stops the orchestrator engine."""
avtomatika/executor.py CHANGED
@@ -47,6 +47,17 @@ except ImportError:
     inject = NoOpPropagate().inject
     TraceContextTextMapPropagator = NoOpTraceContextTextMapPropagator()  # Instantiate the class

+from .app_keys import S3_SERVICE_KEY
+from .constants import (
+    JOB_STATUS_ERROR,
+    JOB_STATUS_FAILED,
+    JOB_STATUS_FINISHED,
+    JOB_STATUS_PENDING,
+    JOB_STATUS_QUARANTINED,
+    JOB_STATUS_RUNNING,
+    JOB_STATUS_WAITING_FOR_PARALLEL,
+    JOB_STATUS_WAITING_FOR_WORKER,
+)
 from .context import ActionFactory
 from .data_types import ClientConfig, JobContext
 from .history.base import HistoryStorageBase
@@ -58,7 +69,7 @@ logger = getLogger(
     __name__
 )  # Re-declare logger after potential redefinition in except block if opentelemetry was missing

-TERMINAL_STATES = {"finished", "failed", "error", "quarantined"}
+TERMINAL_STATES = {JOB_STATUS_FINISHED, JOB_STATUS_FAILED, JOB_STATUS_ERROR, JOB_STATUS_QUARANTINED}


 class JobExecutor:
@@ -74,7 +85,7 @@ class JobExecutor:
         self._running = False
         self._processing_messages: set[str] = set()

-    async def _process_job(self, job_id: str, message_id: str):
+    async def _process_job(self, job_id: str, message_id: str) -> None:
         """The core logic for processing a single job dequeued from storage."""
         if message_id in self._processing_messages:
             return
@@ -143,6 +154,11 @@ class JobExecutor:
             plan=client_config_dict.get("plan", "unknown"),
             params=client_config_dict.get("params", {}),
         )
+
+        # Get TaskFiles if S3 service is available
+        s3_service = self.engine.app.get(S3_SERVICE_KEY)
+        task_files = s3_service.get_task_files(job_id) if s3_service else None
+
         context = JobContext(
             job_id=job_id,
             current_state=job_state["current_state"],
@@ -153,6 +169,7 @@ class JobExecutor:
             data_stores=SimpleNamespace(**blueprint.data_stores),
             tracing_context=tracing_context,
             aggregation_results=job_state.get("aggregation_results"),
+            task_files=task_files,
         )

         try:
@@ -173,12 +190,17 @@ class JobExecutor:
                 params_to_inject["context"] = context
             if "actions" in param_names:
                 params_to_inject["actions"] = action_factory
+            if "task_files" in param_names:
+                params_to_inject["task_files"] = task_files
         else:
             # New injection logic with prioritized lookup.
             context_as_dict = context._asdict()
             for param_name in param_names:
+                # Direct injection of task_files
+                if param_name == "task_files":
+                    params_to_inject[param_name] = task_files
                 # Look in JobContext fields first.
-                if param_name in context_as_dict:
+                elif param_name in context_as_dict:
                     params_to_inject[param_name] = context_as_dict[param_name]
                 # Then look in state_history (data from previous steps/workers).
                 elif param_name in context.state_history:
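The injection change above means a handler can receive the new task_files object just by naming it as a parameter: "task_files" is injected directly, names matching JobContext fields resolve next, and remaining names fall back to state_history entries. A hypothetical handler sketch (the function name, the other parameters, and the return value are illustrative only; the handler contract itself is outside this diff):

async def process_upload(job_id, task_files, previous_result):
    # "job_id" resolves from JobContext fields; "previous_result" would come
    # from state_history if an earlier step stored a value under that name.
    if task_files is None:
        # No S3 service is configured: the executor injects None in that case.
        return
    # ... work with the staged files here ...
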
@@ -251,19 +273,24 @@ class JobExecutor:
             # When transitioning to a new state, reset the retry counter.
             job_state["retry_count"] = 0
             job_state["current_state"] = next_state
-            job_state["status"] = "running"
+            job_state["status"] = JOB_STATUS_RUNNING
             await self.storage.save_job_state(job_id, job_state)

             if next_state not in TERMINAL_STATES:
                 await self.storage.enqueue_job(job_id)
             else:
                 logger.info(f"Job {job_id} reached terminal state {next_state}")
+
+                # Clean up S3 files if service is available
+                s3_service = self.engine.app.get(S3_SERVICE_KEY)
+                if s3_service:
+                    task_files = s3_service.get_task_files(job_id)
+                    if task_files:
+                        # Run cleanup in background to not block response
+                        create_task(task_files.cleanup())
+
                 await self._check_and_resume_parent(job_state)
-                # Send webhook for finished/failed jobs
-                event_type = "job_finished" if next_state == "finished" else "job_failed"
-                # Since _check_and_resume_parent is for sub-jobs, we only send webhook if it's a top-level job
-                # or if the user explicitly requested it for sub-jobs (by providing webhook_url).
-                # The current logic stores webhook_url in job_state, so we just check it.
+                event_type = "job_finished" if next_state == JOB_STATUS_FINISHED else "job_failed"
                 await self.engine.send_job_webhook(job_state, event_type)

     async def _handle_dispatch(
@@ -292,21 +319,15 @@ class JobExecutor:
             logger.info(f"Job {job_id} is now paused, awaiting human approval.")
         else:
             logger.info(f"Job {job_id} dispatching task: {task_info}")
-
             now = monotonic()
-            # Safely get timeout, falling back to the global config if not provided in the task.
-            # This prevents TypeErrors if 'timeout_seconds' is missing.
             timeout_seconds = task_info.get("timeout_seconds") or self.engine.config.WORKER_TIMEOUT_SECONDS
             timeout_at = now + timeout_seconds
-
-            # Set status to waiting and add to watch list *before* dispatching
-            job_state["status"] = "waiting_for_worker"
+            job_state["status"] = JOB_STATUS_WAITING_FOR_WORKER
             job_state["task_dispatched_at"] = now
             job_state["current_task_info"] = task_info  # Save for retries
             job_state["current_task_transitions"] = task_info.get("transitions", {})
             await self.storage.save_job_state(job_id, job_state)
             await self.storage.add_job_to_watch(job_id, timeout_at)
-
             await self.dispatcher.dispatch(job_state, task_info)

     async def _handle_run_blueprint(
@@ -334,7 +355,7 @@ class JobExecutor:
             "blueprint_name": sub_blueprint_info["blueprint_name"],
             "current_state": "start",
             "initial_data": sub_blueprint_info["initial_data"],
-            "status": "pending",
+            "status": JOB_STATUS_PENDING,
             "parent_job_id": parent_job_id,
         }
         await self.storage.save_job_state(child_job_id, child_job_state)
@@ -367,7 +388,7 @@ class JobExecutor:
         branch_task_ids = [str(uuid4()) for _ in tasks_to_dispatch]

         # Update job state for parallel execution
-        job_state["status"] = "waiting_for_parallel_tasks"
+        job_state["status"] = JOB_STATUS_WAITING_FOR_PARALLEL
         job_state["aggregation_target"] = aggregate_into
         job_state["active_branches"] = branch_task_ids
         job_state["aggregation_results"] = {}
@@ -445,7 +466,7 @@ class JobExecutor:
             logger.critical(
                 f"Job {job_id} has failed handler execution {max_retries + 1} times. Moving to quarantine.",
             )
-            job_state["status"] = "quarantined"
+            job_state["status"] = JOB_STATUS_QUARANTINED
             job_state["error_message"] = str(error)
             await self.storage.save_job_state(job_id, job_state)
             await self.storage.quarantine_job(job_id)
@@ -478,7 +499,7 @@ class JobExecutor:
             return

         # Determine the outcome of the child job to select the correct transition.
-        child_outcome = "success" if child_job_state["current_state"] == "finished" else "failure"
+        child_outcome = "success" if child_job_state["current_state"] == JOB_STATUS_FINISHED else "failure"
         transitions = parent_job_state.get("current_task_transitions", {})
         next_state = transitions.get(child_outcome, "failed")

@@ -493,7 +514,7 @@ class JobExecutor:

         # Update the parent job to its new state and re-enqueue it.
         parent_job_state["current_state"] = next_state
-        parent_job_state["status"] = "running"
+        parent_job_state["status"] = JOB_STATUS_RUNNING
         await self.storage.save_job_state(parent_job_id, parent_job_state)
         await self.storage.enqueue_job(parent_job_id)

@@ -522,7 +543,10 @@ class JobExecutor:
                 # Wait for an available slot before fetching a new job
                 await semaphore.acquire()

-                result = await self.storage.dequeue_job()
+                # Block for a configured time waiting for a job
+                block_time = self.engine.config.REDIS_STREAM_BLOCK_MS
+                result = await self.storage.dequeue_job(block=block_time if block_time > 0 else None)
+
                 if result:
                     job_id, message_id = result
                     task = create_task(self._process_job(job_id, message_id))
@@ -530,14 +554,18 @@ class JobExecutor:
                     # Release the semaphore slot when the task is done
                     task.add_done_callback(lambda _: semaphore.release())
                 else:
-                    # No job found, release the slot and wait a bit
+                    # Timeout reached, release slot and loop again
                     semaphore.release()
-                    # Prevent busy loop if storage returns None immediately
-                    await sleep(0.1)
+                    # Prevent busy loop if blocking is disabled (e.g. in tests) or failed
+                    if block_time <= 0:
+                        await sleep(0.1)
+
             except CancelledError:
                 break
             except Exception:
                 logger.exception("Error in JobExecutor main loop.")
+                # If an error occurred (e.g. Redis connection lost), sleep briefly to avoid log spam
+                semaphore.release()
                 await sleep(1)
         logger.info("JobExecutor stopped.")

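The loop above now relies on dequeue_job(block=...) taking a timeout in milliseconds, with None meaning an immediate, non-blocking read. The actual StorageBackend implementation is not part of this diff; a sketch of what that contract could look like on Redis Streams with redis-py's asyncio client (stream, group, and consumer names are placeholders):

from redis.asyncio import Redis

async def dequeue_job(redis: Redis, block: int | None) -> tuple[str, str] | None:
    entries = await redis.xreadgroup(
        groupname="executors",
        consumername="executor-1",
        streams={"jobs": ">"},
        count=1,
        block=block,  # milliseconds; None returns immediately when the stream is empty
    )
    if not entries:
        return None  # nothing arrived before the timeout expired
    _stream, messages = entries[0]
    message_id, fields = messages[0]
    return fields[b"job_id"].decode(), message_id.decode()
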
@@ -20,19 +20,37 @@ logger = getLogger(__name__)


 class HealthChecker:
-    def __init__(self, engine: "OrchestratorEngine"):
+    def __init__(self, engine: "OrchestratorEngine", interval_seconds: int = 600):
+        self.engine = engine
+        self.storage = engine.storage
+        self.interval_seconds = interval_seconds
         self._running = False
+        from uuid import uuid4
+
+        self._instance_id = str(uuid4())

     async def run(self):
-        logger.info("HealthChecker is now passive and will not perform active checks.")
+        logger.info(f"HealthChecker started (Active Index Cleanup, Instance ID: {self._instance_id}).")
         self._running = True
         while self._running:
             try:
-                # Sleep for a long time, as this checker is passive.
-                # The loop exists to allow for a clean shutdown.
-                await sleep(3600)
+                # Use distributed lock to ensure only one instance cleans up
+                if await self.storage.acquire_lock(
+                    "global_health_check_lock", self._instance_id, self.interval_seconds - 5
+                ):
+                    try:
+                        await self.storage.cleanup_expired_workers()
+                    finally:
+                        # We don't release the lock immediately to prevent other instances from
+                        # running the same task if the interval is small.
+                        pass
+
+                await sleep(self.interval_seconds)
             except CancelledError:
                 break
+            except Exception:
+                logger.exception("Error in HealthChecker main loop.")
+                await sleep(60)
         logger.info("HealthChecker stopped.")

     def stop(self):
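The sweep above runs under a distributed lock so that, when several engine instances share one storage backend, only one performs the cleanup per interval. acquire_lock's implementation is not shown in this diff; a minimal sketch of the Redis primitive such a lock usually reduces to (SET with NX and EX is an atomic take-if-free with a TTL, so the lock expires on its own and a crashed holder cannot wedge the cleanup):

from redis.asyncio import Redis

async def acquire_lock(redis: Redis, name: str, owner_id: str, ttl_seconds: int) -> bool:
    # SET NX succeeds only if the key does not already exist; EX makes the lock
    # self-expire, matching the "no explicit release" note in the diff above.
    return bool(await redis.set(name, owner_id, nx=True, ex=ttl_seconds))
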
avtomatika/history/base.py CHANGED
@@ -1,25 +1,79 @@
+import asyncio
+import contextlib
 from abc import ABC, abstractmethod
+from logging import getLogger
 from typing import Any

+logger = getLogger(__name__)
+

 class HistoryStorageBase(ABC):
     """Abstract base class for a history store.
-    Defines the interface for logging job and worker events.
+    Implements buffered asynchronous logging to avoid blocking the main loop.
     """

+    def __init__(self):
+        self._queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue(maxsize=5000)
+        self._worker_task: asyncio.Task | None = None
+
+    async def start(self) -> None:
+        """Starts the background worker for writing logs."""
+        if not self._worker_task:
+            self._worker_task = asyncio.create_task(self._worker())
+            logger.info("HistoryStorage background worker started.")
+
+    async def close(self) -> None:
+        """Stops the background worker and closes resources."""
+        if self._worker_task:
+            self._worker_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await self._worker_task
+            self._worker_task = None
+            logger.info("HistoryStorage background worker stopped.")
+
     @abstractmethod
-    async def initialize(self):
+    async def initialize(self) -> None:
         """Performs initialization, e.g., creating tables in the DB."""
         raise NotImplementedError

+    async def log_job_event(self, event_data: dict[str, Any]) -> None:
+        """Queues a job event for logging."""
+        try:
+            self._queue.put_nowait(("job", event_data))
+        except asyncio.QueueFull:
+            logger.warning("History queue full! Dropping job event.")
+
+    async def log_worker_event(self, event_data: dict[str, Any]) -> None:
+        """Queues a worker event for logging."""
+        try:
+            self._queue.put_nowait(("worker", event_data))
+        except asyncio.QueueFull:
+            logger.warning("History queue full! Dropping worker event.")
+
+    async def _worker(self) -> None:
+        while True:
+            try:
+                kind, data = await self._queue.get()
+                try:
+                    if kind == "job":
+                        await self._persist_job_event(data)
+                    elif kind == "worker":
+                        await self._persist_worker_event(data)
+                except Exception as e:
+                    logger.error(f"Error persisting history event: {e}")
+                finally:
+                    self._queue.task_done()
+            except asyncio.CancelledError:
+                break
+
     @abstractmethod
-    async def log_job_event(self, event_data: dict[str, Any]):
-        """Logs an event related to the job lifecycle."""
+    async def _persist_job_event(self, event_data: dict[str, Any]) -> None:
+        """Actual implementation of writing a job event to storage."""
         raise NotImplementedError

     @abstractmethod
-    async def log_worker_event(self, event_data: dict[str, Any]):
-        """Logs an event related to the worker lifecycle."""
+    async def _persist_worker_event(self, event_data: dict[str, Any]) -> None:
+        """Actual implementation of writing a worker event to storage."""
         raise NotImplementedError

     @abstractmethod
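Under the new contract, log_job_event and log_worker_event only enqueue (non-blocking, and lossy once the 5000-entry buffer is full), while concrete backends implement the _persist_* hooks. A toy subclass sketch (PrintHistoryStorage is purely illustrative, and a real subclass must also implement the remaining abstract query methods such as get_job_history, which this diff only shows in part):

from typing import Any

from avtomatika.history.base import HistoryStorageBase

class PrintHistoryStorage(HistoryStorageBase):
    async def initialize(self) -> None:
        pass  # nothing to set up for stdout

    async def _persist_job_event(self, event_data: dict[str, Any]) -> None:
        print("job event:", event_data)

    async def _persist_worker_event(self, event_data: dict[str, Any]) -> None:
        print("worker event:", event_data)
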
avtomatika/history/noop.py CHANGED
@@ -8,20 +8,31 @@ class NoOpHistoryStorage(HistoryStorageBase):
     Used when history storage is not configured.
     """

-    async def initialize(self):
-        # Do nothing
+    def __init__(self):
+        super().__init__()
+
+    async def start(self) -> None:
+        pass
+
+    async def close(self) -> None:
+        pass
+
+    async def initialize(self) -> None:
+        pass
+
+    async def log_job_event(self, event_data: dict[str, Any]) -> None:
+        pass
+
+    async def log_worker_event(self, event_data: dict[str, Any]) -> None:
         pass

-    async def log_job_event(self, event_data: dict[str, Any]):
-        # Do nothing
+    async def _persist_job_event(self, event_data: dict[str, Any]) -> None:
         pass

-    async def log_worker_event(self, event_data: dict[str, Any]):
-        # Do nothing
+    async def _persist_worker_event(self, event_data: dict[str, Any]) -> None:
         pass

     async def get_job_history(self, job_id: str) -> list[dict[str, Any]]:
-        # Always return an empty list
         return []

     async def get_jobs(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
@@ -46,19 +46,20 @@ class PostgresHistoryStorage(HistoryStorageBase, ABC):
     """Implementation of the history store based on asyncpg for PostgreSQL."""

     def __init__(self, dsn: str, tz_name: str = "UTC"):
+        super().__init__()
         self._dsn = dsn
         self._pool: Pool | None = None
         self.tz_name = tz_name
         self.tz = ZoneInfo(tz_name)

-    async def _setup_connection(self, conn: Connection):
+    async def _setup_connection(self, conn: Connection) -> None:
         """Configures the connection session with the correct timezone."""
         try:
             await conn.execute(f"SET TIME ZONE '{self.tz_name}'")
         except PostgresError as e:
             logger.error(f"Failed to set timezone '{self.tz_name}' for PG connection: {e}")

-    async def initialize(self):
+    async def initialize(self) -> None:
         """Initializes the connection pool to PostgreSQL and creates tables."""
         try:
             # We use init parameter to configure each new connection in the pool
@@ -75,13 +76,14 @@ class PostgresHistoryStorage(HistoryStorageBase, ABC):
             logger.error(f"Failed to initialize PostgreSQL history storage: {e}")
             raise

-    async def close(self):
-        """Closes the connection pool."""
+    async def close(self) -> None:
+        """Closes the connection pool and background worker."""
+        await super().close()
         if self._pool:
             await self._pool.close()
             logger.info("PostgreSQL history storage connection pool closed.")

-    async def log_job_event(self, event_data: dict[str, Any]):
+    async def _persist_job_event(self, event_data: dict[str, Any]) -> None:
         """Logs a job lifecycle event to PostgreSQL."""
         if not self._pool:
             raise RuntimeError("History storage is not initialized.")
@@ -117,7 +119,7 @@ class PostgresHistoryStorage(HistoryStorageBase, ABC):
         except PostgresError as e:
             logger.error(f"Failed to log job event to PostgreSQL: {e}")

-    async def log_worker_event(self, event_data: dict[str, Any]):
+    async def _persist_worker_event(self, event_data: dict[str, Any]) -> None:
         """Logs a worker lifecycle event to PostgreSQL."""
         if not self._pool:
             raise RuntimeError("History storage is not initialized.")
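Putting the engine and storage changes together, the lifecycle a caller sees is: initialize() creates the pool, start() launches the buffered writer, log_job_event() merely enqueues, and close() cancels the writer and then closes the pool. A usage sketch (the module path and the DSN are assumptions, and the event payload shape is whatever the persistence SQL expects, which this diff truncates):

import asyncio

from avtomatika.history.postgres import PostgresHistoryStorage  # module path assumed

async def main() -> None:
    history = PostgresHistoryStorage("postgresql://user:pass@localhost/avtomatika")
    await history.initialize()  # create the asyncpg pool and tables
    await history.start()       # start the background persist worker
    await history.log_job_event({"job_id": "demo"})  # enqueued, written asynchronously
    await history.close()       # cancel the worker, then close the pool

asyncio.run(main())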