avtomatika-1.0b4-py3-none-any.whl → avtomatika-1.0b6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
avtomatika/executor.py CHANGED
@@ -3,7 +3,7 @@ from inspect import signature
 from logging import getLogger
 from time import monotonic
 from types import SimpleNamespace
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
 # Conditional import for OpenTelemetry
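The import hunk above sets the pattern that recurs through the whole release: `typing.Dict` and `typing.List` give way to the built-in `dict` and `list` generics standardized by PEP 585, which are subscriptable at runtime from Python 3.9 onward. A minimal sketch of the equivalence — the function below is illustrative only, not part of the package:

```python
from typing import Any

# PEP 585 (Python 3.9+): built-in containers accept subscripts directly,
# so typing.Dict / typing.List imports are no longer needed for annotations.
def count_events(events: list[dict[str, Any]]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for event in events:
        key = event.get("event_type", "unknown")
        counts[key] = counts.get(key, 0) + 1
    return counts

print(count_events([{"event_type": "state_started"}]))  # {'state_started': 1}
```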
@@ -73,158 +73,163 @@ class JobExecutor:
         self.history_storage = history_storage
         self.dispatcher = engine.dispatcher
         self._running = False
+        self._processing_messages: set[str] = set()
 
-    async def _process_job(self, job_id: str):
-        """The core logic for processing a single job dequeued from storage.
-
-        This function orchestrates finding the correct blueprint and handler,
-        executing the handler, and then processing the action (e.g., transition,
-        dispatch) that the handler requested.
-        """
-        start_time = monotonic()
-        job_state = await self.storage.get_job_state(job_id)
-        if not job_state:
-            logger.error(f"Job {job_id} not found in storage, cannot process.")
+    async def _process_job(self, job_id: str, message_id: str):
+        """The core logic for processing a single job dequeued from storage."""
+        if message_id in self._processing_messages:
             return
 
-        if job_state.get("status") in TERMINAL_STATES:
-            logger.warning(f"Job {job_id} is already in a terminal state '{job_state['status']}', skipping.")
-            return
-
-        # Ensure retry_count is initialized.
-        if "retry_count" not in job_state:
-            job_state["retry_count"] = 0
-
-        await self.history_storage.log_job_event(
-            {
-                "job_id": job_id,
-                "state": job_state.get("current_state"),
-                "event_type": "state_started",
-                "attempt_number": job_state.get("retry_count", 0) + 1,
-                "context_snapshot": job_state,
-            },
-        )
-
-        # Set up distributed tracing context.
-        parent_context = TraceContextTextMapPropagator().extract(
-            carrier=job_state.get("tracing_context", {}),
-        )
+        self._processing_messages.add(message_id)
+        try:
+            start_time = monotonic()
+            job_state = await self.storage.get_job_state(job_id)
+            if not job_state:
+                logger.error(f"Job {job_id} not found in storage, cannot process.")
+                return
 
-        with tracer.start_as_current_span(
-            f"JobExecutor:{job_state['blueprint_name']}:{job_state['current_state']}",
-            context=parent_context,
-        ) as span:
-            span.set_attribute("job.id", job_id)
-            span.set_attribute("job.current_state", job_state["current_state"])
-
-            # Inject the current tracing context back into the job state for propagation.
-            tracing_context: Dict[str, str] = {}
-            inject(tracing_context)
-            job_state["tracing_context"] = tracing_context
-
-            blueprint = self.engine.blueprints.get(job_state["blueprint_name"])
-            if not blueprint:
-                # This is a critical, non-retriable error.
-                duration_ms = int((monotonic() - start_time) * 1000)
-                await self._handle_failure(
-                    job_state,
-                    RuntimeError(
-                        f"Blueprint '{job_state['blueprint_name']}' not found",
-                    ),
-                    duration_ms,
-                )
+            if job_state.get("status") in TERMINAL_STATES:
+                logger.warning(f"Job {job_id} is already in a terminal state '{job_state['status']}', skipping.")
                 return
 
-            # Prepare the context and action factory for the handler.
-            action_factory = ActionFactory(job_id)
-            client_config_dict = job_state.get("client_config", {})
-            client_config = ClientConfig(
-                token=client_config_dict.get("token", ""),
-                plan=client_config_dict.get("plan", "unknown"),
-                params=client_config_dict.get("params", {}),
-            )
-            context = JobContext(
-                job_id=job_id,
-                current_state=job_state["current_state"],
-                initial_data=job_state["initial_data"],
-                state_history=job_state.get("state_history", {}),
-                client=client_config,
-                actions=action_factory,
-                data_stores=SimpleNamespace(**blueprint.data_stores),
-                tracing_context=tracing_context,
-                aggregation_results=job_state.get("aggregation_results"),
+            # Ensure retry_count is initialized.
+            if "retry_count" not in job_state:
+                job_state["retry_count"] = 0
+
+            await self.history_storage.log_job_event(
+                {
+                    "job_id": job_id,
+                    "state": job_state.get("current_state"),
+                    "event_type": "state_started",
+                    "attempt_number": job_state.get("retry_count", 0) + 1,
+                    "context_snapshot": job_state,
+                },
             )
 
-            try:
-                # Find and execute the appropriate handler for the current state.
-                # It's important to check for aggregator handlers first for states
-                # that are targets of parallel execution.
-                is_aggregator_state = job_state.get("aggregation_target") == job_state.get("current_state")
-                if is_aggregator_state and job_state.get("current_state") in blueprint.aggregator_handlers:
-                    handler = blueprint.aggregator_handlers[job_state["current_state"]]
-                else:
-                    handler = blueprint.find_handler(context.current_state, context)
-
-                # Build arguments for the handler dynamically.
-                handler_signature = signature(handler)
-                params_to_inject = {}
+            # Set up distributed tracing context.
+            parent_context = TraceContextTextMapPropagator().extract(
+                carrier=job_state.get("tracing_context", {}),
+            )
 
-                if "context" in handler_signature.parameters:
-                    params_to_inject["context"] = context
-                if "actions" in handler_signature.parameters:
-                    params_to_inject["actions"] = action_factory
-                else:
-                    # New injection logic with prioritized lookup.
-                    context_as_dict = context._asdict()
-                    for param_name in handler_signature.parameters:
-                        # Look in JobContext fields first.
-                        if param_name in context_as_dict:
-                            params_to_inject[param_name] = context_as_dict[param_name]
-                        # Then look in state_history (data from previous steps/workers).
-                        elif param_name in context.state_history:
-                            params_to_inject[param_name] = context.state_history[param_name]
-                        # Finally, look in the initial data the job was created with.
-                        elif param_name in context.initial_data:
-                            params_to_inject[param_name] = context.initial_data[param_name]
-
-                await handler(**params_to_inject)
-
-                duration_ms = int((monotonic() - start_time) * 1000)
-
-                # Process the single action requested by the handler.
-                if action_factory.next_state:
-                    await self._handle_transition(
-                        job_state,
-                        action_factory.next_state,
-                        duration_ms,
-                    )
-                elif action_factory.task_to_dispatch:
-                    await self._handle_dispatch(
-                        job_state,
-                        action_factory.task_to_dispatch,
-                        duration_ms,
-                    )
-                elif action_factory.parallel_tasks_to_dispatch:
-                    await self._handle_parallel_dispatch(
+            with tracer.start_as_current_span(
+                f"JobExecutor:{job_state['blueprint_name']}:{job_state['current_state']}",
+                context=parent_context,
+            ) as span:
+                span.set_attribute("job.id", job_id)
+                span.set_attribute("job.current_state", job_state["current_state"])
+
+                # Inject the current tracing context back into the job state for propagation.
+                tracing_context: dict[str, str] = {}
+                inject(tracing_context)
+                job_state["tracing_context"] = tracing_context
+
+                blueprint = self.engine.blueprints.get(job_state["blueprint_name"])
+                if not blueprint:
+                    # This is a critical, non-retriable error.
+                    duration_ms = int((monotonic() - start_time) * 1000)
+                    await self._handle_failure(
                         job_state,
-                        action_factory.parallel_tasks_to_dispatch,
-                        duration_ms,
-                    )
-                elif action_factory.sub_blueprint_to_run:
-                    await self._handle_run_blueprint(
-                        job_state,
-                        action_factory.sub_blueprint_to_run,
+                        RuntimeError(
+                            f"Blueprint '{job_state['blueprint_name']}' not found",
+                        ),
                         duration_ms,
                     )
+                    return
+
+                # Prepare the context and action factory for the handler.
+                action_factory = ActionFactory(job_id)
+                client_config_dict = job_state.get("client_config", {})
+                client_config = ClientConfig(
+                    token=client_config_dict.get("token", ""),
+                    plan=client_config_dict.get("plan", "unknown"),
+                    params=client_config_dict.get("params", {}),
+                )
+                context = JobContext(
+                    job_id=job_id,
+                    current_state=job_state["current_state"],
+                    initial_data=job_state["initial_data"],
+                    state_history=job_state.get("state_history", {}),
+                    client=client_config,
+                    actions=action_factory,
+                    data_stores=SimpleNamespace(**blueprint.data_stores),
+                    tracing_context=tracing_context,
+                    aggregation_results=job_state.get("aggregation_results"),
+                )
 
-            except Exception as e:
-                # This catches errors within the handler's execution.
-                duration_ms = int((monotonic() - start_time) * 1000)
-                await self._handle_failure(job_state, e, duration_ms)
+                try:
+                    # Find and execute the appropriate handler for the current state.
+                    # It's important to check for aggregator handlers first for states
+                    # that are targets of parallel execution.
+                    is_aggregator_state = job_state.get("aggregation_target") == job_state.get("current_state")
+                    if is_aggregator_state and job_state.get("current_state") in blueprint.aggregator_handlers:
+                        handler = blueprint.aggregator_handlers[job_state["current_state"]]
+                    else:
+                        handler = blueprint.find_handler(context.current_state, context)
+
+                    # Build arguments for the handler dynamically.
+                    handler_signature = signature(handler)
+                    params_to_inject = {}
+
+                    if "context" in handler_signature.parameters:
+                        params_to_inject["context"] = context
+                    if "actions" in handler_signature.parameters:
+                        params_to_inject["actions"] = action_factory
+                    else:
+                        # New injection logic with prioritized lookup.
+                        context_as_dict = context._asdict()
+                        for param_name in handler_signature.parameters:
+                            # Look in JobContext fields first.
+                            if param_name in context_as_dict:
+                                params_to_inject[param_name] = context_as_dict[param_name]
+                            # Then look in state_history (data from previous steps/workers).
+                            elif param_name in context.state_history:
+                                params_to_inject[param_name] = context.state_history[param_name]
+                            # Finally, look in the initial data the job was created with.
+                            elif param_name in context.initial_data:
+                                params_to_inject[param_name] = context.initial_data[param_name]
+
+                    await handler(**params_to_inject)
+
+                    duration_ms = int((monotonic() - start_time) * 1000)
+
+                    # Process the single action requested by the handler.
+                    if action_factory.next_state:
+                        await self._handle_transition(
+                            job_state,
+                            action_factory.next_state,
+                            duration_ms,
+                        )
+                    elif action_factory.task_to_dispatch:
+                        await self._handle_dispatch(
+                            job_state,
+                            action_factory.task_to_dispatch,
+                            duration_ms,
+                        )
+                    elif action_factory.parallel_tasks_to_dispatch:
+                        await self._handle_parallel_dispatch(
+                            job_state,
+                            action_factory.parallel_tasks_to_dispatch,
+                            duration_ms,
+                        )
+                    elif action_factory.sub_blueprint_to_run:
+                        await self._handle_run_blueprint(
+                            job_state,
+                            action_factory.sub_blueprint_to_run,
+                            duration_ms,
+                        )
+
+                except Exception as e:
+                    # This catches errors within the handler's execution.
+                    duration_ms = int((monotonic() - start_time) * 1000)
+                    await self._handle_failure(job_state, e, duration_ms)
+        finally:
+            await self.storage.ack_job(message_id)
+            if message_id in self._processing_messages:
+                self._processing_messages.remove(message_id)
 
     async def _handle_transition(
         self,
-        job_state: Dict[str, Any],
+        job_state: dict[str, Any],
         next_state: str,
         duration_ms: int,
     ):
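The reshaped `_process_job` above is the behavioral core of this release: it now receives the queue `message_id` alongside the `job_id`, skips messages already in flight in this executor, and acknowledges the message in a `finally` block so it is settled whether the handler succeeds, fails, or the job state is missing. A stripped-down sketch of that guard-and-ack shape, with a hypothetical `queue` object standing in for the storage backend (not avtomatika API):

```python
import asyncio


class DedupConsumer:
    """Minimal sketch of the pattern above; `queue` is a hypothetical backend."""

    def __init__(self, queue):
        self.queue = queue
        self._processing: set[str] = set()

    async def process(self, payload: dict, message_id: str) -> None:
        if message_id in self._processing:
            return  # duplicate delivery of an in-flight message: ignore it
        self._processing.add(message_id)
        try:
            await asyncio.sleep(0)  # the real handler work would go here
        finally:
            await self.queue.ack(message_id)  # settle the message either way
            self._processing.discard(message_id)  # same effect as the if/remove pair
```

Because the ack runs in `finally`, a handler exception does not leave the message unacked for queue redelivery; the surrounding machinery (`_handle_failure` and the stored `retry_count`) appears to drive retries through job state instead.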
@@ -258,8 +263,8 @@
 
     async def _handle_dispatch(
         self,
-        job_state: Dict[str, Any],
-        task_info: Dict[str, Any],
+        job_state: dict[str, Any],
+        task_info: dict[str, Any],
         duration_ms: int,
     ):
         job_id = job_state["id"]
@@ -302,8 +307,8 @@
 
     async def _handle_run_blueprint(
         self,
-        parent_job_state: Dict[str, Any],
-        sub_blueprint_info: Dict[str, Any],
+        parent_job_state: dict[str, Any],
+        sub_blueprint_info: dict[str, Any],
         duration_ms: int,
     ):
         parent_job_id = parent_job_state["id"]
@@ -342,8 +347,8 @@
 
     async def _handle_parallel_dispatch(
         self,
-        job_state: Dict[str, Any],
-        parallel_info: Dict[str, Any],
+        job_state: dict[str, Any],
+        parallel_info: dict[str, Any],
         duration_ms: int,
     ):
         job_id = job_state["id"]
@@ -390,7 +395,7 @@
 
     async def _handle_failure(
         self,
-        job_state: Dict[str, Any],
+        job_state: dict[str, Any],
         error: Exception,
         duration_ms: int,
     ):
@@ -448,7 +453,7 @@
             {metrics.LABEL_BLUEPRINT: job_state.get("blueprint_name", "unknown")},
         )
 
-    async def _check_and_resume_parent(self, child_job_state: Dict[str, Any]):
+    async def _check_and_resume_parent(self, child_job_state: dict[str, Any]):
         """Checks if a completed job was a sub-job. If so, it resumes the parent
         job, passing the success/failure outcome of the child.
         """
@@ -501,14 +506,29 @@
             logger.exception("Unhandled exception in job processing task")
 
     async def run(self):
+        import asyncio
+
         logger.info("JobExecutor started.")
         self._running = True
+        semaphore = asyncio.Semaphore(self.engine.config.EXECUTOR_MAX_CONCURRENT_JOBS)
+
         while self._running:
             try:
-                job_id = await self.storage.dequeue_job()
-                if job_id:
-                    task = create_task(self._process_job(job_id))
+                # Wait for an available slot before fetching a new job
+                await semaphore.acquire()
+
+                result = await self.storage.dequeue_job()
+                if result:
+                    job_id, message_id = result
+                    task = create_task(self._process_job(job_id, message_id))
                     task.add_done_callback(self._handle_task_completion)
+                    # Release the semaphore slot when the task is done
+                    task.add_done_callback(lambda _: semaphore.release())
+                else:
+                    # No job found, release the slot and wait a bit
+                    semaphore.release()
+                    # Prevent busy loop if storage returns None immediately
+                    await sleep(0.1)
             except CancelledError:
                 break
             except Exception:
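Two contracts change in `run()`: `storage.dequeue_job()` now returns a `(job_id, message_id)` pair (or a falsy value when the queue is empty), and in-flight work is capped by an `asyncio.Semaphore` sized from `EXECUTOR_MAX_CONCURRENT_JOBS`. A hypothetical in-memory backend satisfying the dequeue/ack side of that contract — `dequeue_job` and `ack_job` match the calls in the diff, everything else is invented for the sketch:

```python
from __future__ import annotations

import asyncio
from uuid import uuid4


class InMemoryJobQueue:
    """Illustrative only; a real backend would persist and redeliver messages."""

    def __init__(self) -> None:
        self._pending: asyncio.Queue[tuple[str, str]] = asyncio.Queue()
        self._unacked: dict[str, str] = {}  # message_id -> job_id

    async def enqueue_job(self, job_id: str) -> None:
        # A fresh message id per delivery lets the executor deduplicate.
        await self._pending.put((job_id, uuid4().hex))

    async def dequeue_job(self) -> tuple[str, str] | None:
        try:
            job_id, message_id = self._pending.get_nowait()
        except asyncio.QueueEmpty:
            return None  # the executor releases its slot and sleeps 0.1 s
        self._unacked[message_id] = job_id
        return job_id, message_id

    async def ack_job(self, message_id: str) -> None:
        self._unacked.pop(message_id, None)  # settle the delivery for good
```

Releasing the semaphore from `task.add_done_callback` rather than inside `_process_job` ties each slot to the task's lifetime, so even a cancelled or crashed task frees its slot.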
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any
 
 
 class HistoryStorageBase(ABC):
@@ -13,29 +13,29 @@ class HistoryStorageBase:
         raise NotImplementedError
 
     @abstractmethod
-    async def log_job_event(self, event_data: Dict[str, Any]):
+    async def log_job_event(self, event_data: dict[str, Any]):
         """Logs an event related to the job lifecycle."""
         raise NotImplementedError
 
     @abstractmethod
-    async def log_worker_event(self, event_data: Dict[str, Any]):
+    async def log_worker_event(self, event_data: dict[str, Any]):
         """Logs an event related to the worker lifecycle."""
         raise NotImplementedError
 
     @abstractmethod
-    async def get_job_history(self, job_id: str) -> List[Dict[str, Any]]:
+    async def get_job_history(self, job_id: str) -> list[dict[str, Any]]:
         """Gets the full history for the specified job."""
         raise NotImplementedError
 
     @abstractmethod
-    async def get_jobs(self, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
+    async def get_jobs(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
         """Gets a paginated list of recent jobs.
         Primarily returns the last event for each job.
         """
         raise NotImplementedError
 
     @abstractmethod
-    async def get_job_summary(self) -> Dict[str, int]:
+    async def get_job_summary(self) -> dict[str, int]:
         """Returns a summary of job statuses.
         Example: {'running': 10, 'completed': 50, 'failed': 5}
         """
@@ -46,6 +46,6 @@ class HistoryStorageBase:
         self,
         worker_id: str,
         since_days: int,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
         """Gets the event history for a specific worker for the last N days."""
         raise NotImplementedError
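The `HistoryStorageBase` hunks above (whose file header was lost in this diff view, though the sibling module below imports it as `.base`) touch only annotations, so existing subclasses stay source-compatible. For orientation, here is a minimal in-memory class with the updated signatures; it deliberately does not subclass `HistoryStorageBase`, since the diff shows neither the module path nor every abstract method the base defines:

```python
from collections import Counter
from typing import Any


class InMemoryHistoryStorage:
    """Illustrative stand-in with the signatures shown above; not package API."""

    def __init__(self) -> None:
        self._job_events: list[dict[str, Any]] = []
        self._worker_events: list[dict[str, Any]] = []

    async def log_job_event(self, event_data: dict[str, Any]) -> None:
        self._job_events.append(event_data)

    async def log_worker_event(self, event_data: dict[str, Any]) -> None:
        self._worker_events.append(event_data)

    async def get_job_history(self, job_id: str) -> list[dict[str, Any]]:
        return [e for e in self._job_events if e.get("job_id") == job_id]

    async def get_jobs(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
        last_event: dict[str, dict[str, Any]] = {}
        for event in self._job_events:  # keep only the latest event per job
            last_event[event["job_id"]] = event
        return list(last_event.values())[offset : offset + limit]

    async def get_job_summary(self) -> dict[str, int]:
        # Keyed by event type here; the real schema may differ.
        return dict(Counter(e.get("event_type", "unknown") for e in self._job_events))

    async def get_worker_history(self, worker_id: str, since_days: int) -> list[dict[str, Any]]:
        # since_days filtering omitted for brevity.
        return [e for e in self._worker_events if e.get("worker_id") == worker_id]
```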
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any
 
 from .base import HistoryStorageBase
 
@@ -12,27 +12,27 @@ class NoOpHistoryStorage(HistoryStorageBase):
         # Do nothing
         pass
 
-    async def log_job_event(self, event_data: Dict[str, Any]):
+    async def log_job_event(self, event_data: dict[str, Any]):
         # Do nothing
         pass
 
-    async def log_worker_event(self, event_data: Dict[str, Any]):
+    async def log_worker_event(self, event_data: dict[str, Any]):
         # Do nothing
         pass
 
-    async def get_job_history(self, job_id: str) -> List[Dict[str, Any]]:
+    async def get_job_history(self, job_id: str) -> list[dict[str, Any]]:
         # Always return an empty list
         return []
 
-    async def get_jobs(self, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
+    async def get_jobs(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
         return []
 
-    async def get_job_summary(self) -> Dict[str, int]:
+    async def get_job_summary(self) -> dict[str, int]:
         return {}
 
     async def get_worker_history(
         self,
         worker_id: str,
         since_days: int,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
         return []
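`NoOpHistoryStorage` is the classic null object: it honors the interface while discarding everything, so callers can log history unconditionally instead of branching on whether tracking is enabled. A self-contained check of that behavior — the class is inlined as a stand-in mirroring the diff above, because the real import path is not shown here:

```python
import asyncio
from typing import Any


class NoOpHistoryStorage:  # inlined stand-in mirroring the class above
    async def log_job_event(self, event_data: dict[str, Any]) -> None:
        pass

    async def get_job_history(self, job_id: str) -> list[dict[str, Any]]:
        return []


async def main() -> None:
    storage = NoOpHistoryStorage()
    await storage.log_job_event({"job_id": "j1", "event_type": "state_started"})
    assert await storage.get_job_history("j1") == []  # nothing was retained


asyncio.run(main())
```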