horsies 0.1.0a4__py3-none-any.whl → 0.1.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -236,7 +236,9 @@ class WorkflowTaskModel(Base):
236
236
 
237
237
  # Unique constraint: one task per index per workflow
238
238
  __table_args__ = (
239
- UniqueConstraint('workflow_id', 'task_index', name='uq_horsies_workflow_task_index'),
239
+ UniqueConstraint(
240
+ 'workflow_id', 'task_index', name='uq_horsies_workflow_task_index'
241
+ ),
240
242
  )
241
243
 
242
244
 
@@ -21,6 +21,10 @@ from horsies.core.worker.worker import import_by_path
21
21
 
22
22
  logger = get_logger('scheduler')
23
23
 
24
+ SCHEDULE_ADVISORY_LOCK_SQL = text(
25
+ """SELECT pg_advisory_xact_lock(CAST(:key AS BIGINT))"""
26
+ )
27
+
24
28
 
25
29
  class Scheduler:
26
30
  """
@@ -257,7 +261,7 @@ class Scheduler:
257
261
  async with self.broker.session_factory() as session:
258
262
  # Acquire transaction-scoped advisory lock for this specific schedule
259
263
  await session.execute(
260
- text('SELECT pg_advisory_xact_lock(CAST(:key AS BIGINT))'),
264
+ SCHEDULE_ADVISORY_LOCK_SQL,
261
265
  {'key': lock_key},
262
266
  )
263
267
 
@@ -9,6 +9,39 @@ from horsies.core.logging import get_logger
9
9
 
10
10
  logger = get_logger('scheduler.state')
11
11
 
12
+ UPDATE_SCHEDULE_AFTER_RUN_SQL = text("""
13
+ UPDATE horsies_schedule_state
14
+ SET last_run_at = :executed_at,
15
+ next_run_at = :next_run_at,
16
+ last_task_id = :task_id,
17
+ run_count = run_count + 1,
18
+ updated_at = :now
19
+ WHERE schedule_name = :schedule_name
20
+ """)
21
+
22
+ UPDATE_SCHEDULE_NEXT_RUN_WITH_HASH_SQL = text("""
23
+ UPDATE horsies_schedule_state
24
+ SET next_run_at = :next_run_at,
25
+ config_hash = :config_hash,
26
+ updated_at = :now
27
+ WHERE schedule_name = :schedule_name
28
+ """)
29
+
30
+ UPDATE_SCHEDULE_NEXT_RUN_SQL = text("""
31
+ UPDATE horsies_schedule_state
32
+ SET next_run_at = :next_run_at,
33
+ updated_at = :now
34
+ WHERE schedule_name = :schedule_name
35
+ """)
36
+
37
+ DELETE_SCHEDULE_STATE_SQL = text("""
38
+ DELETE FROM horsies_schedule_state WHERE schedule_name = :schedule_name
39
+ """)
40
+
41
+ GET_ALL_SCHEDULE_STATES_SQL = text("""
42
+ SELECT * FROM horsies_schedule_state ORDER BY schedule_name
43
+ """)
44
+
12
45
 
13
46
  class ScheduleStateManager:
14
47
  """
@@ -127,15 +160,7 @@ class ScheduleStateManager:
127
160
  async with self.session_factory() as session:
128
161
  # Use raw SQL for atomic update with increment
129
162
  result = await session.execute(
130
- text("""
131
- UPDATE horsies_schedule_state
132
- SET last_run_at = :executed_at,
133
- next_run_at = :next_run_at,
134
- last_task_id = :task_id,
135
- run_count = run_count + 1,
136
- updated_at = :now
137
- WHERE schedule_name = :schedule_name
138
- """),
163
+ UPDATE_SCHEDULE_AFTER_RUN_SQL,
139
164
  {
140
165
  'schedule_name': schedule_name,
141
166
  'executed_at': executed_at,
@@ -174,13 +199,7 @@ class ScheduleStateManager:
174
199
  async with self.session_factory() as session:
175
200
  # Build UPDATE query dynamically based on whether config_hash is provided
176
201
  if config_hash is not None:
177
- query = """
178
- UPDATE horsies_schedule_state
179
- SET next_run_at = :next_run_at,
180
- config_hash = :config_hash,
181
- updated_at = :now
182
- WHERE schedule_name = :schedule_name
183
- """
202
+ query = UPDATE_SCHEDULE_NEXT_RUN_WITH_HASH_SQL
184
203
  params = {
185
204
  'schedule_name': schedule_name,
186
205
  'next_run_at': next_run_at,
@@ -188,19 +207,14 @@ class ScheduleStateManager:
188
207
  'now': datetime.now(timezone.utc),
189
208
  }
190
209
  else:
191
- query = """
192
- UPDATE horsies_schedule_state
193
- SET next_run_at = :next_run_at,
194
- updated_at = :now
195
- WHERE schedule_name = :schedule_name
196
- """
210
+ query = UPDATE_SCHEDULE_NEXT_RUN_SQL
197
211
  params = {
198
212
  'schedule_name': schedule_name,
199
213
  'next_run_at': next_run_at,
200
214
  'now': datetime.now(timezone.utc),
201
215
  }
202
216
 
203
- result = await session.execute(text(query), params)
217
+ result = await session.execute(query, params)
204
218
  await session.commit()
205
219
 
206
220
  rows_updated = getattr(result, 'rowcount', 0)
@@ -223,7 +237,7 @@ class ScheduleStateManager:
223
237
  """
224
238
  async with self.session_factory() as session:
225
239
  result = await session.execute(
226
- text('DELETE FROM horsies_schedule_state WHERE schedule_name = :schedule_name'),
240
+ DELETE_SCHEDULE_STATE_SQL,
227
241
  {'schedule_name': schedule_name},
228
242
  )
229
243
  await session.commit()
@@ -244,9 +258,7 @@ class ScheduleStateManager:
244
258
  List of all ScheduleStateModel records
245
259
  """
246
260
  async with self.session_factory() as session:
247
- result = await session.execute(
248
- text('SELECT * FROM horsies_schedule_state ORDER BY schedule_name')
249
- )
261
+ result = await session.execute(GET_ALL_SCHEDULE_STATES_SQL)
250
262
  rows = result.fetchall()
251
263
  columns = result.keys()
252
264
 
@@ -6,7 +6,9 @@ from typing import (
6
6
  get_origin,
7
7
  get_type_hints,
8
8
  get_args,
9
+ Literal,
9
10
  ParamSpec,
11
+ Sequence,
10
12
  TypeVar,
11
13
  Generic,
12
14
  Protocol,
@@ -24,6 +26,11 @@ if TYPE_CHECKING:
24
26
  from horsies.core.models.tasks import TaskOptions
25
27
  from horsies.core.models.tasks import TaskError, TaskResult
26
28
  from horsies.core.models.tasks import TaskInfo
29
+ from horsies.core.models.workflow import (
30
+ TaskNode,
31
+ SubWorkflowNode,
32
+ WorkflowContext,
33
+ )
27
34
 
28
35
  from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
29
36
  from horsies.core.models.workflow import WorkflowContextMissingIdError
@@ -241,6 +248,87 @@ class TaskHandle(Generic[T]):
241
248
  self._result_fetched = True
242
249
 
243
250
 
251
+ class NodeFactory(Generic[P, T]):
252
+ """
253
+ Factory for creating TaskNode instances with typed arguments.
254
+
255
+ Returned by TaskFunction.node(). Call with the task's arguments
256
+ to create a TaskNode with full static type checking.
257
+
258
+ Example:
259
+ node = my_task.node(waits_for=[dep])(value='test')
260
+ # Type checker validates 'value' against my_task's signature
261
+ """
262
+
263
+ _fn: 'TaskFunction[P, T]'
264
+ _waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None
265
+ _workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None
266
+ _args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] | None
267
+ _queue: str | None
268
+ _priority: int | None
269
+ _allow_failed_deps: bool
270
+ _run_when: Callable[['WorkflowContext'], bool] | None
271
+ _skip_when: Callable[['WorkflowContext'], bool] | None
272
+ _join: Literal['all', 'any', 'quorum']
273
+ _min_success: int | None
274
+ _good_until: datetime | None
275
+ _node_id: str | None
276
+
277
+ def __init__(
278
+ self,
279
+ fn: 'TaskFunction[P, T]',
280
+ *,
281
+ waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None,
282
+ workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None,
283
+ args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] | None,
284
+ queue: str | None,
285
+ priority: int | None,
286
+ allow_failed_deps: bool,
287
+ run_when: Callable[['WorkflowContext'], bool] | None,
288
+ skip_when: Callable[['WorkflowContext'], bool] | None,
289
+ join: Literal['all', 'any', 'quorum'],
290
+ min_success: int | None,
291
+ good_until: datetime | None,
292
+ node_id: str | None,
293
+ ) -> None:
294
+ self._fn = fn
295
+ self._waits_for = waits_for
296
+ self._workflow_ctx_from = workflow_ctx_from
297
+ self._args_from = args_from
298
+ self._queue = queue
299
+ self._priority = priority
300
+ self._allow_failed_deps = allow_failed_deps
301
+ self._run_when = run_when
302
+ self._skip_when = skip_when
303
+ self._join = join
304
+ self._min_success = min_success
305
+ self._good_until = good_until
306
+ self._node_id = node_id
307
+
308
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> 'TaskNode[T]':
309
+ from horsies.core.models.workflow import TaskNode
310
+
311
+ return TaskNode(
312
+ fn=self._fn,
313
+ args=args,
314
+ kwargs=dict(kwargs),
315
+ waits_for=list(self._waits_for) if self._waits_for else [],
316
+ workflow_ctx_from=list(self._workflow_ctx_from)
317
+ if self._workflow_ctx_from
318
+ else None,
319
+ args_from=dict(self._args_from) if self._args_from else {},
320
+ queue=self._queue,
321
+ priority=self._priority,
322
+ allow_failed_deps=self._allow_failed_deps,
323
+ run_when=self._run_when,
324
+ skip_when=self._skip_when,
325
+ join=self._join,
326
+ min_success=self._min_success,
327
+ good_until=self._good_until,
328
+ node_id=self._node_id,
329
+ )
330
+
331
+
244
332
  class TaskFunction(Protocol[P, T]):
245
333
  """
246
334
  A TaskFunction is a function that gets a @task decorator applied to it.
@@ -281,6 +369,24 @@ class TaskFunction(Protocol[P, T]):
281
369
  **kwargs: P.kwargs,
282
370
  ) -> 'TaskHandle[T]': ...
283
371
 
372
+ @abstractmethod
373
+ def node(
374
+ self,
375
+ *,
376
+ waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None,
377
+ workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None,
378
+ args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] | None = None,
379
+ queue: str | None = None,
380
+ priority: int | None = None,
381
+ allow_failed_deps: bool = False,
382
+ run_when: Callable[['WorkflowContext'], bool] | None = None,
383
+ skip_when: Callable[['WorkflowContext'], bool] | None = None,
384
+ join: Literal['all', 'any', 'quorum'] = 'all',
385
+ min_success: int | None = None,
386
+ good_until: datetime | None = None,
387
+ node_id: str | None = None,
388
+ ) -> 'NodeFactory[P, T]': ...
389
+
284
390
 
285
391
  def create_task_wrapper(
286
392
  fn: Callable[P, TaskResult[T, TaskError]],
@@ -647,6 +753,38 @@ def create_task_wrapper(
647
753
  ) -> TaskHandle[T]:
648
754
  return schedule(delay, *args, **kwargs)
649
755
 
756
+ def node(
757
+ self,
758
+ *,
759
+ waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None,
760
+ workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None,
761
+ args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] | None = None,
762
+ queue: str | None = None,
763
+ priority: int | None = None,
764
+ allow_failed_deps: bool = False,
765
+ run_when: Callable[['WorkflowContext'], bool] | None = None,
766
+ skip_when: Callable[['WorkflowContext'], bool] | None = None,
767
+ join: Literal['all', 'any', 'quorum'] = 'all',
768
+ min_success: int | None = None,
769
+ good_until: datetime | None = None,
770
+ node_id: str | None = None,
771
+ ) -> NodeFactory[P, T]:
772
+ return NodeFactory(
773
+ fn=self, # type: ignore[arg-type]
774
+ waits_for=waits_for,
775
+ workflow_ctx_from=workflow_ctx_from,
776
+ args_from=args_from,
777
+ queue=queue,
778
+ priority=priority,
779
+ allow_failed_deps=allow_failed_deps,
780
+ run_when=run_when,
781
+ skip_when=skip_when,
782
+ join=join,
783
+ min_success=min_success,
784
+ good_until=good_until,
785
+ node_id=node_id,
786
+ )
787
+
650
788
  # Copy metadata
651
789
  def __getattr__(self, name: str) -> Any:
652
790
  return getattr(wrapped_function, name)
@@ -31,8 +31,10 @@ class TaskStatus(Enum):
31
31
  return self in TASK_TERMINAL_STATES
32
32
 
33
33
 
34
- TASK_TERMINAL_STATES: frozenset[TaskStatus] = frozenset({
35
- TaskStatus.COMPLETED,
36
- TaskStatus.FAILED,
37
- TaskStatus.CANCELLED,
38
- })
34
+ TASK_TERMINAL_STATES: frozenset[TaskStatus] = frozenset(
35
+ {
36
+ TaskStatus.COMPLETED,
37
+ TaskStatus.FAILED,
38
+ TaskStatus.CANCELLED,
39
+ }
40
+ )
@@ -20,7 +20,7 @@ from typing import Any
20
20
 
21
21
  from horsies.core.logging import get_logger
22
22
 
23
- logger = get_logger("imports")
23
+ logger = get_logger('imports')
24
24
 
25
25
 
26
26
  def find_project_root(start_dir: str) -> str | None:
@@ -34,7 +34,7 @@ def find_project_root(start_dir: str) -> str | None:
34
34
  Returns start_dir if it contains a marker file, None otherwise.
35
35
  """
36
36
  start_dir = os.path.abspath(start_dir)
37
- for marker in ("pyproject.toml", "setup.cfg", "setup.py"):
37
+ for marker in ('pyproject.toml', 'setup.cfg', 'setup.py'):
38
38
  if os.path.exists(os.path.join(start_dir, marker)):
39
39
  return start_dir
40
40
  return None
@@ -54,7 +54,7 @@ def setup_sys_path_from_cwd() -> str | None:
54
54
  cwd = os.getcwd()
55
55
  if find_project_root(cwd) and cwd not in sys.path:
56
56
  sys.path.insert(0, cwd)
57
- logger.debug(f"Added cwd to sys.path: {cwd}")
57
+ logger.debug(f'Added cwd to sys.path: {cwd}')
58
58
  return cwd
59
59
  return None
60
60
 
@@ -90,7 +90,7 @@ def _compute_synthetic_module_name(path: str) -> str:
90
90
  """
91
91
  realpath = os.path.realpath(path)
92
92
  hash_prefix = hashlib.sha256(realpath.encode()).hexdigest()[:12]
93
- return f"horsies._dynamic.{hash_prefix}"
93
+ return f'horsies._dynamic.{hash_prefix}'
94
94
 
95
95
 
96
96
  def import_file_path(
@@ -121,11 +121,11 @@ def import_file_path(
121
121
  file_path = os.path.realpath(file_path)
122
122
 
123
123
  if not os.path.exists(file_path):
124
- raise FileNotFoundError(f"Module file not found: {file_path}")
124
+ raise FileNotFoundError(f'Module file not found: {file_path}')
125
125
 
126
126
  # Check if already loaded
127
127
  for name, mod in list(sys.modules.items()):
128
- mod_file = getattr(mod, "__file__", None)
128
+ mod_file = getattr(mod, '__file__', None)
129
129
  if mod_file and os.path.realpath(mod_file) == file_path:
130
130
  return mod
131
131
 
@@ -142,7 +142,7 @@ def import_file_path(
142
142
  # Load the module
143
143
  spec = importlib.util.spec_from_file_location(module_name, file_path)
144
144
  if spec is None or spec.loader is None:
145
- raise ImportError(f"Could not load module from path: {file_path}")
145
+ raise ImportError(f'Could not load module from path: {file_path}')
146
146
 
147
147
  mod = importlib.util.module_from_spec(spec)
148
148
  sys.modules[module_name] = mod
@@ -164,7 +164,7 @@ def import_by_path(path: str, module_name: str | None = None) -> Any:
164
164
  For file paths: delegates to import_file_path()
165
165
  For module paths: delegates to import_module_path()
166
166
  """
167
- if path.endswith(".py") or os.path.sep in path:
167
+ if path.endswith('.py') or os.path.sep in path:
168
168
  return import_file_path(path, module_name)
169
169
  else:
170
170
  return import_module_path(path)
@@ -187,7 +187,7 @@ def compute_package_path_from_fs(file_path: str) -> tuple[str | None, str | None
187
187
 
188
188
  components = [module_name]
189
189
  while True:
190
- init_path = os.path.join(current_dir, "__init__.py")
190
+ init_path = os.path.join(current_dir, '__init__.py')
191
191
  if not os.path.exists(init_path):
192
192
  break
193
193
  package_name = os.path.basename(current_dir)
@@ -198,6 +198,6 @@ def compute_package_path_from_fs(file_path: str) -> tuple[str | None, str | None
198
198
  return (None, None)
199
199
 
200
200
  components.reverse()
201
- dotted_name = ".".join(components)
201
+ dotted_name = '.'.join(components)
202
202
  package_root = current_dir
203
203
  return (dotted_name, package_root)