avtomatika 1.0b6__py3-none-any.whl → 1.0b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
avtomatika/app_keys.py ADDED
@@ -0,0 +1,32 @@
1
+ from asyncio import Task
2
+ from typing import TYPE_CHECKING
3
+
4
+ from aiohttp import ClientSession
5
+ from aiohttp.web import AppKey
6
+
7
if TYPE_CHECKING:
    # Reserved for type-checking-only imports; nothing is imported here yet.
    ...

# Typed aiohttp application keys under which the orchestrator's components are
# stored on the web Application. Component types are given as strings where
# importing the real class here could create a circular import at runtime;
# types that are safe to import (ClientSession, Task) are passed directly.

# --- Engine and shared HTTP client ---------------------------------------
ENGINE_KEY = AppKey("engine", "OrchestratorEngine")
HTTP_SESSION_KEY = AppKey("http_session", ClientSession)

# --- Core components -------------------------------------------------------
DISPATCHER_KEY = AppKey("dispatcher", "Dispatcher")
EXECUTOR_KEY = AppKey("executor", "JobExecutor")
WATCHER_KEY = AppKey("watcher", "Watcher")
REPUTATION_CALCULATOR_KEY = AppKey("reputation_calculator", "ReputationCalculator")
HEALTH_CHECKER_KEY = AppKey("health_checker", "HealthChecker")
SCHEDULER_KEY = AppKey("scheduler", "Scheduler")
WS_MANAGER_KEY = AppKey("ws_manager", "WebSocketManager")

# --- asyncio Tasks running the background components ------------------------
EXECUTOR_TASK_KEY = AppKey("executor_task", Task)
WATCHER_TASK_KEY = AppKey("watcher_task", Task)
REPUTATION_CALCULATOR_TASK_KEY = AppKey("reputation_calculator_task", Task)
HEALTH_CHECKER_TASK_KEY = AppKey("health_checker_task", Task)
SCHEDULER_TASK_KEY = AppKey("scheduler_task", Task)
avtomatika/blueprint.py CHANGED
@@ -52,7 +52,7 @@ def _parse_condition(condition_str: str) -> Condition:
52
52
 
53
53
 
54
54
  class ConditionalHandler:
55
- def __init__(self, blueprint, state: str, func: Callable, condition_str: str):
55
+ def __init__(self, blueprint: "StateMachineBlueprint", state: str, func: Callable, condition_str: str):
56
56
  self.blueprint = blueprint
57
57
  self.state = state
58
58
  self.func = func
@@ -115,7 +115,7 @@ class StateMachineBlueprint:
115
115
  name: str,
116
116
  api_endpoint: str | None = None,
117
117
  api_version: str | None = None,
118
- data_stores: Any = None,
118
+ data_stores: dict[str, Any] | None = None,
119
119
  ):
120
120
  """Initializes a new blueprint.
121
121
 
@@ -136,8 +136,9 @@ class StateMachineBlueprint:
136
136
  self.conditional_handlers: list[ConditionalHandler] = []
137
137
  self.start_state: str | None = None
138
138
  self.end_states: set[str] = set()
139
+ self._handler_params: dict[Callable, tuple[str, ...]] = {}
139
140
 
140
- def add_data_store(self, name: str, initial_data: dict[str, Any]):
141
+ def add_data_store(self, name: str, initial_data: dict[str, Any]) -> None:
141
142
  """Adds a named data store to the blueprint."""
142
143
  if name in self.data_stores:
143
144
  raise ValueError(f"Data store with name '{name}' already exists.")
@@ -157,10 +158,116 @@ class StateMachineBlueprint:
157
158
 
158
159
  return decorator
159
160
 
160
def validate(self) -> None:
    """Check that the blueprint is configured correctly.

    Requires a start state, caches every handler's parameter names, and then
    runs the graph integrity checks.

    Raises:
        ValueError: if no start state has been set, or if the integrity
            checks fail.
    """
    has_start = self.start_state is not None
    if not has_start:
        raise ValueError(f"Blueprint '{self.name}' must have exactly one start state.")

    # Signature caching happens before the graph checks, matching the order
    # the executor-facing cache is expected to be ready in.
    self._analyze_handlers()
    self.validate_integrity()
167
+
168
def validate_integrity(self) -> None:
    """Check the state graph for dangling transitions and unreachable states.

    Transitions are taken from ``self._get_all_transitions()``. A state counts
    as "defined" when it has a plain, aggregator, or conditional handler.

    Raises:
        ValueError: if some handler can transition to a state that has no
            handler ("dangling transition"), or if a defined state cannot be
            reached from ``self.start_state``. States are examined and reported
            in sorted order so the error message is identical from run to run
            (plain ``set`` iteration order depends on string-hash randomization).
    """
    transitions = self._get_all_transitions()
    defined_states = (
        set(self.handlers.keys())
        | set(self.aggregator_handlers.keys())
        | {ch.state for ch in self.conditional_handlers}
    )

    # 1. Dangling transitions. Sorted iteration makes the *first* reported
    #    problem deterministic when several exist.
    for source_state in sorted(transitions, key=str):
        for target_state in sorted(transitions[source_state], key=str):
            if target_state not in defined_states:
                raise ValueError(
                    f"Blueprint '{self.name}' has a dangling transition: "
                    f"state '{source_state}' leads to non-existent state '{target_state}'."
                )

    # 2. Unreachable states: depth-first traversal from the start state.
    if self.start_state:
        reachable = {self.start_state}
        stack = [self.start_state]
        while stack:
            current = stack.pop()
            for target in transitions.get(current, set()):
                if target not in reachable:
                    reachable.add(target)
                    stack.append(target)

        unreachable = defined_states - reachable
        if unreachable:
            raise ValueError(
                f"Blueprint '{self.name}' has unreachable states: "
                f"{', '.join(sorted(unreachable, key=str))}. "
                "All states must be reachable from the start state."
            )
203
+
204
+ def _get_all_transitions(self) -> dict[str, set[str]]:
205
+ """Parses handler source code to find all possible transitions."""
206
+ import ast
207
+ import inspect
208
+ import logging
209
+ import textwrap
210
+
211
+ logger = logging.getLogger(__name__)
212
+ transitions: dict[str, set[str]] = {}
213
+
214
+ all_handlers = (
215
+ list(self.handlers.items())
216
+ + list(self.aggregator_handlers.items())
217
+ + [(ch.state, ch.func) for ch in self.conditional_handlers]
218
+ )
219
+
220
+ for state, func in all_handlers:
221
+ if state not in transitions:
222
+ transitions[state] = set()
223
+ try:
224
+ source = textwrap.dedent(inspect.getsource(func))
225
+ tree = ast.parse(source)
226
+ for node in ast.walk(tree):
227
+ if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
228
+ continue
229
+
230
+ # Handle actions.transition_to("state")
231
+ if node.func.attr == "transition_to" and node.args and isinstance(node.args[0], ast.Constant):
232
+ transitions[state].add(str(node.args[0].value))
233
+
234
+ # Handle actions.dispatch_task(..., transitions={"status": "state"})
235
+ # Also handles await_human_approval, run_blueprint which use the same 'transitions' kwarg
236
+ elif node.func.attr in ("dispatch_task", "await_human_approval", "run_blueprint"):
237
+ for keyword in node.keywords:
238
+ if keyword.arg == "transitions" and isinstance(keyword.value, ast.Dict):
239
+ for value_node in keyword.value.values:
240
+ if isinstance(value_node, ast.Constant):
241
+ transitions[state].add(str(value_node.value))
242
+
243
+ # Handle actions.dispatch_parallel(..., aggregate_into="state")
244
+ elif node.func.attr == "dispatch_parallel":
245
+ for keyword in node.keywords:
246
+ if keyword.arg == "aggregate_into" and isinstance(keyword.value, ast.Constant):
247
+ transitions[state].add(str(keyword.value.value))
248
+
249
+ except (TypeError, OSError, SyntaxError) as e:
250
+ logger.warning(f"Could not parse handler for state '{state}': {e}")
251
+
252
+ return transitions
253
+
254
+ def _analyze_handlers(self) -> None:
255
+ """Analyzes and caches parameters for all registered handlers."""
256
+ import inspect
257
+
258
+ all_funcs = (
259
+ list(self.handlers.values())
260
+ + list(self.aggregator_handlers.values())
261
+ + [ch.func for ch in self.conditional_handlers]
262
+ )
263
+
264
+ for func in all_funcs:
265
+ sig = inspect.signature(func)
266
+ self._handler_params[func] = tuple(sig.parameters.keys())
267
+
268
def get_handler_params(self, func: Callable) -> tuple[str, ...]:
    """Look up the cached parameter names for ``func``.

    Returns an empty tuple when ``func`` has not been cached by
    ``_analyze_handlers``.
    """
    cache = self._handler_params
    if func in cache:
        return cache[func]
    return ()
164
271
 
165
272
  def find_handler(self, state: str, context: Any) -> Callable:
166
273
  for handler in self.conditional_handlers:
@@ -173,60 +280,24 @@ class StateMachineBlueprint:
173
280
  )
174
281
 
175
282
  def render_graph(self, output_filename: str | None = None, output_format: str = "png"):
176
- import ast
177
- import inspect
178
- import logging
179
- import textwrap
180
-
181
283
  from graphviz import Digraph # type: ignore[import]
182
284
 
183
- logger = logging.getLogger(__name__)
184
-
185
285
  dot = Digraph(comment=f"State Machine for {self.name}")
186
286
  dot.attr("node", shape="box", style="rounded")
187
- all_handlers = list(self.handlers.items()) + [(ch.state, ch.func) for ch in self.conditional_handlers]
188
- states = set(self.handlers.keys())
189
- for handler_state, handler_func in all_handlers:
190
- try:
191
- source = textwrap.dedent(inspect.getsource(handler_func))
192
- tree = ast.parse(source)
193
- for node in ast.walk(tree):
194
- if isinstance(node, ast.Call) and isinstance(
195
- node.func,
196
- ast.Attribute,
197
- ):
198
- if node.func.attr == "transition_to" and node.args and isinstance(node.args[0], ast.Constant):
199
- target_state = str(node.args[0].value)
200
- states.add(target_state)
201
- dot.edge(handler_state, target_state, label="transition")
202
- elif node.func.attr == "dispatch_task":
203
- for keyword in node.keywords:
204
- if keyword.arg == "transitions" and isinstance(
205
- keyword.value,
206
- ast.Dict,
207
- ):
208
- for key_node, value_node in zip(
209
- keyword.value.keys,
210
- keyword.value.values,
211
- strict=False,
212
- ):
213
- if isinstance(
214
- key_node,
215
- ast.Constant,
216
- ) and isinstance(value_node, ast.Constant):
217
- key = str(key_node.value)
218
- value = str(value_node.value)
219
- states.add(value)
220
- dot.edge(
221
- handler_state,
222
- value,
223
- label=f"on {key}",
224
- )
225
- except (TypeError, OSError) as e:
226
- logger.warning(
227
- f"Could not parse handler '{handler_func.__name__}' for state '{handler_state}'. "
228
- f"Graph may be incomplete. Error: {e}"
229
- )
287
+
288
+ transitions = self._get_all_transitions()
289
+ defined_states = (
290
+ set(self.handlers.keys())
291
+ | set(self.aggregator_handlers.keys())
292
+ | {ch.state for ch in self.conditional_handlers}
293
+ )
294
+ states = defined_states.copy()
295
+
296
+ for source, targets in transitions.items():
297
+ for target in targets:
298
+ states.add(target)
299
+ dot.edge(source, target)
300
+
230
301
  for state in states:
231
302
  dot.node(state, state)
232
303
 
avtomatika/context.py CHANGED
@@ -11,7 +11,7 @@ class ActionFactory:
11
11
  self._sub_blueprint_to_run_val: dict[str, Any] | None = None
12
12
  self._parallel_tasks_to_dispatch_val: dict[str, Any] | None = None
13
13
 
14
- def _check_for_existing_action(self):
14
+ def _check_for_existing_action(self) -> None:
15
15
  """
16
16
  Helper to ensure only one action is set.
17
17
  Raises RuntimeError if any action value is already set.
@@ -45,7 +45,7 @@ class ActionFactory:
45
45
  def parallel_tasks_to_dispatch(self) -> dict[str, Any] | None:
46
46
  return self._parallel_tasks_to_dispatch_val
47
47
 
48
- def dispatch_parallel(self, tasks: dict[str, Any] | None, aggregate_into: str) -> None:
48
+ def dispatch_parallel(self, tasks: list[dict[str, Any]], aggregate_into: str) -> None:
49
49
  """
50
50
  Dispatches multiple tasks for parallel execution.
51
51
  """
avtomatika/data_types.py CHANGED
@@ -21,9 +21,10 @@ class JobContext(NamedTuple):
21
21
  state_history: dict[str, Any]
22
22
  client: ClientConfig
23
23
  actions: "ActionFactory"
24
- data_stores: Any = None
25
- tracing_context: dict[str, Any] = {}
24
+ data_stores: dict[str, Any] | None = None
25
+ tracing_context: dict[str, Any] | None = None
26
26
  aggregation_results: dict[str, Any] | None = None
27
+ webhook_url: str | None = None
27
28
 
28
29
 
29
30
  class GPUInfo(NamedTuple):
avtomatika/dispatcher.py CHANGED
@@ -128,7 +128,7 @@ class Dispatcher:
128
128
  """Selects the worker with the best price-quality (reputation) ratio."""
129
129
  return min(workers, key=self._get_best_value_score)
130
130
 
131
- async def dispatch(self, job_state: dict[str, Any], task_info: dict[str, Any]):
131
+ async def dispatch(self, job_state: dict[str, Any], task_info: dict[str, Any]) -> None:
132
132
  job_id = job_state["id"]
133
133
  task_type = task_info.get("type")
134
134
  if not task_type: