planar 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289):
  1. planar/.__init__.py.un~ +0 -0
  2. planar/._version.py.un~ +0 -0
  3. planar/.app.py.un~ +0 -0
  4. planar/.cli.py.un~ +0 -0
  5. planar/.config.py.un~ +0 -0
  6. planar/.context.py.un~ +0 -0
  7. planar/.db.py.un~ +0 -0
  8. planar/.di.py.un~ +0 -0
  9. planar/.engine.py.un~ +0 -0
  10. planar/.files.py.un~ +0 -0
  11. planar/.log_context.py.un~ +0 -0
  12. planar/.log_metadata.py.un~ +0 -0
  13. planar/.logging.py.un~ +0 -0
  14. planar/.object_registry.py.un~ +0 -0
  15. planar/.otel.py.un~ +0 -0
  16. planar/.server.py.un~ +0 -0
  17. planar/.session.py.un~ +0 -0
  18. planar/.sqlalchemy.py.un~ +0 -0
  19. planar/.task_local.py.un~ +0 -0
  20. planar/.test_app.py.un~ +0 -0
  21. planar/.test_config.py.un~ +0 -0
  22. planar/.test_object_config.py.un~ +0 -0
  23. planar/.test_sqlalchemy.py.un~ +0 -0
  24. planar/.test_utils.py.un~ +0 -0
  25. planar/.util.py.un~ +0 -0
  26. planar/.utils.py.un~ +0 -0
  27. planar/__init__.py +26 -0
  28. planar/_version.py +1 -0
  29. planar/ai/.__init__.py.un~ +0 -0
  30. planar/ai/._models.py.un~ +0 -0
  31. planar/ai/.agent.py.un~ +0 -0
  32. planar/ai/.agent_utils.py.un~ +0 -0
  33. planar/ai/.events.py.un~ +0 -0
  34. planar/ai/.files.py.un~ +0 -0
  35. planar/ai/.models.py.un~ +0 -0
  36. planar/ai/.providers.py.un~ +0 -0
  37. planar/ai/.pydantic_ai.py.un~ +0 -0
  38. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  39. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  40. planar/ai/.step.py.un~ +0 -0
  41. planar/ai/.test_agent.py.un~ +0 -0
  42. planar/ai/.test_agent_serialization.py.un~ +0 -0
  43. planar/ai/.test_providers.py.un~ +0 -0
  44. planar/ai/.utils.py.un~ +0 -0
  45. planar/ai/__init__.py +15 -0
  46. planar/ai/agent.py +457 -0
  47. planar/ai/agent_utils.py +205 -0
  48. planar/ai/models.py +140 -0
  49. planar/ai/providers.py +1088 -0
  50. planar/ai/test_agent.py +1298 -0
  51. planar/ai/test_agent_serialization.py +229 -0
  52. planar/ai/test_providers.py +463 -0
  53. planar/ai/utils.py +102 -0
  54. planar/app.py +494 -0
  55. planar/cli.py +282 -0
  56. planar/config.py +544 -0
  57. planar/db/.db.py.un~ +0 -0
  58. planar/db/__init__.py +17 -0
  59. planar/db/alembic/env.py +136 -0
  60. planar/db/alembic/script.py.mako +28 -0
  61. planar/db/alembic/versions/3476068c153c_initial_system_tables_migration.py +339 -0
  62. planar/db/alembic.ini +128 -0
  63. planar/db/db.py +318 -0
  64. planar/files/.config.py.un~ +0 -0
  65. planar/files/.local.py.un~ +0 -0
  66. planar/files/.local_filesystem.py.un~ +0 -0
  67. planar/files/.model.py.un~ +0 -0
  68. planar/files/.models.py.un~ +0 -0
  69. planar/files/.s3.py.un~ +0 -0
  70. planar/files/.storage.py.un~ +0 -0
  71. planar/files/.test_files.py.un~ +0 -0
  72. planar/files/__init__.py +2 -0
  73. planar/files/models.py +162 -0
  74. planar/files/storage/.__init__.py.un~ +0 -0
  75. planar/files/storage/.base.py.un~ +0 -0
  76. planar/files/storage/.config.py.un~ +0 -0
  77. planar/files/storage/.context.py.un~ +0 -0
  78. planar/files/storage/.local_directory.py.un~ +0 -0
  79. planar/files/storage/.test_local_directory.py.un~ +0 -0
  80. planar/files/storage/.test_s3.py.un~ +0 -0
  81. planar/files/storage/base.py +61 -0
  82. planar/files/storage/config.py +44 -0
  83. planar/files/storage/context.py +15 -0
  84. planar/files/storage/local_directory.py +188 -0
  85. planar/files/storage/s3.py +220 -0
  86. planar/files/storage/test_local_directory.py +162 -0
  87. planar/files/storage/test_s3.py +299 -0
  88. planar/files/test_files.py +283 -0
  89. planar/human/.human.py.un~ +0 -0
  90. planar/human/.test_human.py.un~ +0 -0
  91. planar/human/__init__.py +2 -0
  92. planar/human/human.py +458 -0
  93. planar/human/models.py +80 -0
  94. planar/human/test_human.py +385 -0
  95. planar/logging/.__init__.py.un~ +0 -0
  96. planar/logging/.attributes.py.un~ +0 -0
  97. planar/logging/.formatter.py.un~ +0 -0
  98. planar/logging/.logger.py.un~ +0 -0
  99. planar/logging/.otel.py.un~ +0 -0
  100. planar/logging/.tracer.py.un~ +0 -0
  101. planar/logging/__init__.py +10 -0
  102. planar/logging/attributes.py +54 -0
  103. planar/logging/context.py +14 -0
  104. planar/logging/formatter.py +113 -0
  105. planar/logging/logger.py +114 -0
  106. planar/logging/otel.py +51 -0
  107. planar/modeling/.mixin.py.un~ +0 -0
  108. planar/modeling/.storage.py.un~ +0 -0
  109. planar/modeling/__init__.py +0 -0
  110. planar/modeling/field_helpers.py +59 -0
  111. planar/modeling/json_schema_generator.py +94 -0
  112. planar/modeling/mixins/__init__.py +10 -0
  113. planar/modeling/mixins/auditable.py +52 -0
  114. planar/modeling/mixins/test_auditable.py +97 -0
  115. planar/modeling/mixins/test_timestamp.py +134 -0
  116. planar/modeling/mixins/test_uuid_primary_key.py +52 -0
  117. planar/modeling/mixins/timestamp.py +53 -0
  118. planar/modeling/mixins/uuid_primary_key.py +19 -0
  119. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  120. planar/modeling/orm/__init__.py +18 -0
  121. planar/modeling/orm/planar_base_entity.py +29 -0
  122. planar/modeling/orm/query_filter_builder.py +122 -0
  123. planar/modeling/orm/reexports.py +15 -0
  124. planar/object_config/.object_config.py.un~ +0 -0
  125. planar/object_config/__init__.py +11 -0
  126. planar/object_config/models.py +114 -0
  127. planar/object_config/object_config.py +378 -0
  128. planar/object_registry.py +100 -0
  129. planar/registry_items.py +65 -0
  130. planar/routers/.__init__.py.un~ +0 -0
  131. planar/routers/.agents_router.py.un~ +0 -0
  132. planar/routers/.crud.py.un~ +0 -0
  133. planar/routers/.decision.py.un~ +0 -0
  134. planar/routers/.event.py.un~ +0 -0
  135. planar/routers/.file_attachment.py.un~ +0 -0
  136. planar/routers/.files.py.un~ +0 -0
  137. planar/routers/.files_router.py.un~ +0 -0
  138. planar/routers/.human.py.un~ +0 -0
  139. planar/routers/.info.py.un~ +0 -0
  140. planar/routers/.models.py.un~ +0 -0
  141. planar/routers/.object_config_router.py.un~ +0 -0
  142. planar/routers/.rule.py.un~ +0 -0
  143. planar/routers/.test_object_config_router.py.un~ +0 -0
  144. planar/routers/.test_workflow_router.py.un~ +0 -0
  145. planar/routers/.workflow.py.un~ +0 -0
  146. planar/routers/__init__.py +13 -0
  147. planar/routers/agents_router.py +197 -0
  148. planar/routers/entity_router.py +143 -0
  149. planar/routers/event.py +91 -0
  150. planar/routers/files.py +142 -0
  151. planar/routers/human.py +151 -0
  152. planar/routers/info.py +131 -0
  153. planar/routers/models.py +170 -0
  154. planar/routers/object_config_router.py +133 -0
  155. planar/routers/rule.py +108 -0
  156. planar/routers/test_agents_router.py +174 -0
  157. planar/routers/test_object_config_router.py +367 -0
  158. planar/routers/test_routes_security.py +169 -0
  159. planar/routers/test_rule_router.py +470 -0
  160. planar/routers/test_workflow_router.py +274 -0
  161. planar/routers/workflow.py +468 -0
  162. planar/rules/.decorator.py.un~ +0 -0
  163. planar/rules/.runner.py.un~ +0 -0
  164. planar/rules/.test_rules.py.un~ +0 -0
  165. planar/rules/__init__.py +23 -0
  166. planar/rules/decorator.py +184 -0
  167. planar/rules/models.py +355 -0
  168. planar/rules/rule_configuration.py +191 -0
  169. planar/rules/runner.py +64 -0
  170. planar/rules/test_rules.py +750 -0
  171. planar/scaffold_templates/app/__init__.py.j2 +0 -0
  172. planar/scaffold_templates/app/db/entities.py.j2 +11 -0
  173. planar/scaffold_templates/app/flows/process_invoice.py.j2 +67 -0
  174. planar/scaffold_templates/main.py.j2 +13 -0
  175. planar/scaffold_templates/planar.dev.yaml.j2 +34 -0
  176. planar/scaffold_templates/planar.prod.yaml.j2 +28 -0
  177. planar/scaffold_templates/pyproject.toml.j2 +10 -0
  178. planar/security/.jwt_middleware.py.un~ +0 -0
  179. planar/security/auth_context.py +148 -0
  180. planar/security/authorization.py +388 -0
  181. planar/security/default_policies.cedar +77 -0
  182. planar/security/jwt_middleware.py +116 -0
  183. planar/security/security_context.py +18 -0
  184. planar/security/tests/test_authorization_context.py +78 -0
  185. planar/security/tests/test_cedar_basics.py +41 -0
  186. planar/security/tests/test_cedar_policies.py +158 -0
  187. planar/security/tests/test_jwt_principal_context.py +179 -0
  188. planar/session.py +40 -0
  189. planar/sse/.constants.py.un~ +0 -0
  190. planar/sse/.example.html.un~ +0 -0
  191. planar/sse/.hub.py.un~ +0 -0
  192. planar/sse/.model.py.un~ +0 -0
  193. planar/sse/.proxy.py.un~ +0 -0
  194. planar/sse/constants.py +1 -0
  195. planar/sse/example.html +126 -0
  196. planar/sse/hub.py +216 -0
  197. planar/sse/model.py +8 -0
  198. planar/sse/proxy.py +257 -0
  199. planar/task_local.py +37 -0
  200. planar/test_app.py +51 -0
  201. planar/test_cli.py +372 -0
  202. planar/test_config.py +512 -0
  203. planar/test_object_config.py +527 -0
  204. planar/test_object_registry.py +14 -0
  205. planar/test_sqlalchemy.py +158 -0
  206. planar/test_utils.py +105 -0
  207. planar/testing/.client.py.un~ +0 -0
  208. planar/testing/.memory_storage.py.un~ +0 -0
  209. planar/testing/.planar_test_client.py.un~ +0 -0
  210. planar/testing/.predictable_tracer.py.un~ +0 -0
  211. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  212. planar/testing/.test_memory_storage.py.un~ +0 -0
  213. planar/testing/.workflow_observer.py.un~ +0 -0
  214. planar/testing/__init__.py +0 -0
  215. planar/testing/memory_storage.py +78 -0
  216. planar/testing/planar_test_client.py +54 -0
  217. planar/testing/synchronizable_tracer.py +153 -0
  218. planar/testing/test_memory_storage.py +143 -0
  219. planar/testing/workflow_observer.py +73 -0
  220. planar/utils.py +70 -0
  221. planar/workflows/.__init__.py.un~ +0 -0
  222. planar/workflows/.builtin_steps.py.un~ +0 -0
  223. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  224. planar/workflows/.context.py.un~ +0 -0
  225. planar/workflows/.contrib.py.un~ +0 -0
  226. planar/workflows/.decorators.py.un~ +0 -0
  227. planar/workflows/.durable_test.py.un~ +0 -0
  228. planar/workflows/.errors.py.un~ +0 -0
  229. planar/workflows/.events.py.un~ +0 -0
  230. planar/workflows/.exceptions.py.un~ +0 -0
  231. planar/workflows/.execution.py.un~ +0 -0
  232. planar/workflows/.human.py.un~ +0 -0
  233. planar/workflows/.lock.py.un~ +0 -0
  234. planar/workflows/.misc.py.un~ +0 -0
  235. planar/workflows/.model.py.un~ +0 -0
  236. planar/workflows/.models.py.un~ +0 -0
  237. planar/workflows/.notifications.py.un~ +0 -0
  238. planar/workflows/.orchestrator.py.un~ +0 -0
  239. planar/workflows/.runtime.py.un~ +0 -0
  240. planar/workflows/.serialization.py.un~ +0 -0
  241. planar/workflows/.step.py.un~ +0 -0
  242. planar/workflows/.step_core.py.un~ +0 -0
  243. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  244. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  245. planar/workflows/.test_concurrency.py.un~ +0 -0
  246. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  247. planar/workflows/.test_human.py.un~ +0 -0
  248. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  249. planar/workflows/.test_orchestrator.py.un~ +0 -0
  250. planar/workflows/.test_race_conditions.py.un~ +0 -0
  251. planar/workflows/.test_serialization.py.un~ +0 -0
  252. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  253. planar/workflows/.test_workflow.py.un~ +0 -0
  254. planar/workflows/.tracing.py.un~ +0 -0
  255. planar/workflows/.types.py.un~ +0 -0
  256. planar/workflows/.util.py.un~ +0 -0
  257. planar/workflows/.utils.py.un~ +0 -0
  258. planar/workflows/.workflow.py.un~ +0 -0
  259. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  260. planar/workflows/.wrappers.py.un~ +0 -0
  261. planar/workflows/__init__.py +42 -0
  262. planar/workflows/context.py +44 -0
  263. planar/workflows/contrib.py +190 -0
  264. planar/workflows/decorators.py +217 -0
  265. planar/workflows/events.py +185 -0
  266. planar/workflows/exceptions.py +34 -0
  267. planar/workflows/execution.py +198 -0
  268. planar/workflows/lock.py +229 -0
  269. planar/workflows/misc.py +5 -0
  270. planar/workflows/models.py +154 -0
  271. planar/workflows/notifications.py +96 -0
  272. planar/workflows/orchestrator.py +383 -0
  273. planar/workflows/query.py +256 -0
  274. planar/workflows/serialization.py +409 -0
  275. planar/workflows/step_core.py +373 -0
  276. planar/workflows/step_metadata.py +357 -0
  277. planar/workflows/step_testing_utils.py +86 -0
  278. planar/workflows/sub_workflow_runner.py +191 -0
  279. planar/workflows/test_concurrency_detection.py +120 -0
  280. planar/workflows/test_lock_timeout.py +140 -0
  281. planar/workflows/test_serialization.py +1195 -0
  282. planar/workflows/test_suspend_deserialization.py +231 -0
  283. planar/workflows/test_workflow.py +1967 -0
  284. planar/workflows/tracing.py +106 -0
  285. planar/workflows/wrappers.py +41 -0
  286. planar-0.5.0.dist-info/METADATA +285 -0
  287. planar-0.5.0.dist-info/RECORD +289 -0
  288. planar-0.5.0.dist-info/WHEEL +4 -0
  289. planar-0.5.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,373 @@
1
+ import inspect
2
+ import json
3
+ import traceback
4
+ from collections.abc import Mapping, Sequence
5
+ from dataclasses import dataclass
6
+ from datetime import datetime, timedelta
7
+ from functools import wraps
8
+ from typing import Callable, Coroutine, Type, cast
9
+
10
+ from sqlmodel import col, select
11
+
12
+ from planar.logging import get_logger
13
+ from planar.session import get_session
14
+ from planar.utils import P, R, T, U, utc_now
15
+ from planar.workflows.context import get_context
16
+ from planar.workflows.exceptions import (
17
+ NonDeterministicStepCallError,
18
+ try_restore_exception,
19
+ )
20
+ from planar.workflows.misc import func_full_name
21
+ from planar.workflows.models import StepStatus, StepType, WorkflowStep
22
+ from planar.workflows.notifications import step_failed, step_running, step_succeeded
23
+ from planar.workflows.serialization import (
24
+ deserialize_result,
25
+ serialize_args,
26
+ serialize_result,
27
+ )
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
def deep_equals(obj1, obj2):
    """Recursively compare two JSON-like objects for equality.

    Mappings are compared key by key and non-text sequences element by
    element, so a ``tuple`` and a ``list`` holding equal items compare equal.
    Every other pair of values (including ``str`` and ``bytes``) is compared
    with the ``!=`` operator.

    Args:
        obj1: First value to compare.
        obj2: Second value to compare.

    Returns:
        ``True`` when both structures are equal, ``False`` otherwise.
    """
    if isinstance(obj1, Mapping) and isinstance(obj2, Mapping):
        if len(obj1) != len(obj2):
            return False
        return all(
            key in obj2 and deep_equals(value, obj2[key])
            for key, value in obj1.items()
        )

    text_types = (str, bytes)
    if (
        isinstance(obj1, Sequence)
        and isinstance(obj2, Sequence)
        # Text is a Sequence too, but it must be treated as a scalar on BOTH
        # sides; otherwise ["a", "b"] would compare equal to "ab" because each
        # character of a string is itself a one-character string.
        and not isinstance(obj1, text_types)
        and not isinstance(obj2, text_types)
    ):
        if len(obj1) != len(obj2):
            return False
        return all(deep_equals(left, right) for left, right in zip(obj1, obj2))

    return not (obj1 != obj2)
58
+
59
+
60
+ @dataclass(kw_only=True, frozen=True)
61
+ class Suspend:
62
+ wakeup_at: datetime | None
63
+ event_key: str | None
64
+ exception: Exception | None
65
+
66
+ def __await__(self):
67
+ result = yield self
68
+ return result
69
+
70
+
71
def suspend_workflow(
    wakeup_at: datetime | None = None,
    interval: timedelta | None = None,
    event_key: str | None = None,
    exception: Exception | None = None,
) -> Suspend:
    """Build a ``Suspend`` marker and record the wake-up condition on the workflow.

    Args:
        wakeup_at: Absolute time at which the workflow should wake up.
        interval: Delay, relative to ``utc_now()``, after which the workflow
            should wake up. Mutually exclusive with ``wakeup_at``.
        event_key: Event the workflow waits for. May be given alone or
            together with a wake-up time.
        exception: When given, a ``Suspend`` carrying only this exception is
            returned immediately and the workflow context is not touched.

    Returns:
        The ``Suspend`` instance describing the suspension.

    Raises:
        ValueError: If both ``interval`` and ``wakeup_at`` are given, or if
            none of ``interval``, ``wakeup_at`` and ``event_key`` is given.
    """
    logger.debug(
        "suspending workflow",
        wakeup_at=wakeup_at,
        # `is not None` rather than truthiness: timedelta(0) is falsy.
        interval_seconds=interval.total_seconds() if interval is not None else None,
        event_key=event_key,
        exception=str(exception) if exception is not None else None,
    )
    if exception is not None:
        return Suspend(wakeup_at=None, event_key=None, exception=exception)

    ctx = get_context()
    workflow = ctx.workflow

    # Compare against None explicitly: a zero-length interval is falsy and
    # would otherwise bypass this check and silently override `wakeup_at`.
    if interval is not None and wakeup_at is not None:
        raise ValueError("Only one of interval or wakeup_at must be provided")

    # Set the workflow waiting_for_event, when provided
    workflow.waiting_for_event = event_key

    if wakeup_at is None and interval is None:
        if event_key is None:
            raise ValueError("Either wakeup_at or interval must be provided")

        # Waiting purely on an event: no timed wake-up.
        workflow.wakeup_at = None
        logger.debug(
            "workflow suspended waiting for event",
            workflow_id=ctx.workflow_id,
            event_key=event_key,
        )
        return Suspend(wakeup_at=None, event_key=event_key, exception=None)

    if interval is not None:
        wakeup_at = utc_now() + interval
    workflow.wakeup_at = wakeup_at
    logger.debug(
        "workflow suspended until",
        workflow_id=ctx.workflow_id,
        wakeup_at=wakeup_at,
        event_key=event_key,
    )
    return Suspend(wakeup_at=wakeup_at, event_key=event_key, exception=None)
118
+
119
+
120
def _step(
    *,
    max_retries: int = 0,
    return_type: Type | None = None,
):
    """Create a decorator that turns a coroutine function into a workflow step.

    Each call of the wrapped function is recorded as a ``WorkflowStep`` row
    keyed by the workflow id and a per-workflow step counter:

    * No existing row: a new row is created with status ``RUNNING`` and the
      function is executed.
    * Existing row ``SUCCEEDED``: the stored result is deserialized and
      returned without calling the function again.
    * Existing row ``FAILED``: the function name and serialized arguments
      must match the stored ones, otherwise the workflow is suspended with a
      ``NonDeterministicStepCallError``. The step is then retried if retries
      remain, or the stored error is restored and raised.

    Args:
        max_retries: Number of retries allowed after a failure. A negative
            value means unlimited retries.
        return_type: Passed to ``deserialize_result`` when rebuilding the
            step result.

    Returns:
        A decorator taking the coroutine function plus optional
        ``step_type`` and ``display_name`` values stored on the step row.
    """

    def decorator(
        func: Callable[P, Coroutine[T, U, R]],
        step_type: StepType = StepType.COMPUTE,
        display_name: str | None = None,
    ) -> Callable[P, Coroutine[T, U, R]]:
        if not inspect.iscoroutinefunction(func):
            raise TypeError("Step functions must be coroutines")

        name = func_full_name(func)

        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            session = get_session()
            ctx = get_context()
            ctx.current_step_id += 1
            logger.debug("executing step", step_name=name, args=args, kwargs=kwargs)
            step_id = ctx.current_step_id
            validate_repeated_call_args = False

            step = (
                await session.exec(
                    select(WorkflowStep)
                    .where(col(WorkflowStep.workflow_id) == ctx.workflow_id)
                    .where(col(WorkflowStep.step_id) == step_id)
                )
            ).first()

            if not step:
                logger.debug(
                    "first time executing step",
                    step_name=name,
                    step_id=step_id,
                    workflow_id=ctx.workflow_id,
                )
                # first time executing this step
                # Get parent from the step stack if available
                parent_step_id = None
                if ctx.step_stack:
                    parent_step_id = ctx.step_stack[-1].step_id

                step = WorkflowStep(
                    step_id=step_id,
                    workflow_id=ctx.workflow_id,
                    function_name=name,
                    step_type=step_type,
                    args=[],
                    kwargs={},
                    status=StepStatus.RUNNING,
                    parent_step_id=parent_step_id,
                    display_name=display_name,
                )
                session.add(step)

            elif step.status == StepStatus.SUCCEEDED:
                logger.info(
                    "step already completed, returning cached result",
                    step_name=name,
                    step_id=step_id,
                )
                # already completed, return cached result
                # have to update the current_step_id in the context
                ctx.current_step_id += step.sub_step_count
                deserialized_result = deserialize_result(
                    func, step.result, return_type, args, kwargs
                )
                return cast(R, deserialized_result)
            elif step.status == StepStatus.FAILED:
                logger.debug(
                    "step previously failed, checking for retry or non-determinism",
                    step_name=name,
                    step_id=step_id,
                )
                # Check that the function name is the same as the previous call. Note
                # that we need to check this before checking max_retries, because if
                # the step is different and has a different max_retries setting, we could
                # try to restore/raise the exception of the initial step
                if step.function_name != name:
                    step.status = StepStatus.FAILED
                    err_msg = (
                        f"Non-deterministic step call detected at step ID {step_id}. "
                        f"Previous function name: {step.function_name}, current: {name}"
                    )
                    logger.warning(
                        "non-deterministic step call detected",
                        step_id=step_id,
                        previous_function_name=step.function_name,
                        current_function_name=name,
                    )
                    await suspend_workflow(
                        exception=NonDeterministicStepCallError(err_msg)
                    )
                    # Explicit raise instead of `assert False`: asserts are
                    # stripped under `python -O`, which would let execution
                    # fall through if this coroutine were ever resumed.
                    raise AssertionError("Non-deterministic step call detected")

                validate_repeated_call_args = True

                if max_retries < 0 or step.retry_count < max_retries:
                    logger.info(
                        "retrying step",
                        step_name=name,
                        step_id=step_id,
                        retry_count=step.retry_count + 1,
                        max_retries=max_retries if max_retries >= 0 else "unlimited",
                    )
                    # failed previously but will be retried
                    step.retry_count += 1
                    step.status = StepStatus.RUNNING
                else:
                    # A failed step is expected to carry its error; fail
                    # loudly (even under `python -O`) if it does not.
                    if not step.error:
                        raise AssertionError
                    logger.warning(
                        "max retries reached for step, raising original error",
                        step_name=name,
                        step_id=step_id,
                    )
                    # max retries reached
                    raise try_restore_exception(step.error)

            # Add step input parameters to the step record
            serialized_args, serialized_kwargs = serialize_args(func, args, kwargs)

            if validate_repeated_call_args:
                logger.debug(
                    "validating repeated call arguments for step",
                    step_name=name,
                    step_id=step_id,
                )
                # Check that the arguments are the same - deep compare args tuple
                if not deep_equals(step.args, serialized_args):
                    step.status = StepStatus.FAILED
                    err_msg = (
                        f"Non-deterministic step call detected at step ID {step_id}. "
                        f"Previous args: {json.dumps(step.args)}, current: {json.dumps(serialized_args)}"
                    )
                    logger.warning(
                        "non-deterministic step call detected on args",
                        step_id=step_id,
                        previous_args=json.dumps(step.args),
                        current_args=json.dumps(serialized_args),
                    )
                    await suspend_workflow(
                        exception=NonDeterministicStepCallError(err_msg)
                    )
                    # See above: must not depend on `assert` being enabled.
                    raise AssertionError("Non-deterministic step call detected")

                # Check keyword arguments determinism - deep compare kwargs dict
                if not deep_equals(step.kwargs, serialized_kwargs):
                    step.status = StepStatus.FAILED
                    err_msg = (
                        f"Non-deterministic step call detected at step ID {step_id}. "
                        f"Previous kwargs: {json.dumps(step.kwargs)}, current: {json.dumps(serialized_kwargs)}"
                    )
                    logger.warning(
                        "non-deterministic step call detected on kwargs",
                        step_id=step_id,
                        previous_kwargs=json.dumps(step.kwargs),
                        current_kwargs=json.dumps(serialized_kwargs),
                    )
                    await suspend_workflow(
                        exception=NonDeterministicStepCallError(err_msg)
                    )
                    # See above: must not depend on `assert` being enabled.
                    raise AssertionError("Non-deterministic step call detected")

            step.args = serialized_args
            step.kwargs = serialized_kwargs

            # Persist the RUNNING step before invoking the user function.
            await session.commit()
            step_running(step)

            ctx.step_stack.append(step)
            logger.debug(
                "step pushed to stack",
                step_name=name,
                step_id=step_id,
                stack_size=len(ctx.step_stack),
            )

            try:
                result = await func(*args, **kwargs)
                step.status = StepStatus.SUCCEEDED
                step.result = serialize_result(func, result)
                step.error = None
                # Number of step ids consumed by steps called inside this one;
                # used to skip over them when the cached result is returned.
                step.sub_step_count = ctx.current_step_id - step_id
                await session.commit()
                step_succeeded(step)
                logger.info(
                    "step succeeded",
                    step_name=name,
                    step_id=step_id,
                    result=step.result,
                )
                # Deserialize the result to ensure consistency
                # between initial run and re-runs (due to suspension).
                deserialized_result = deserialize_result(
                    func, step.result, return_type, args, kwargs
                )
                return cast(R, deserialized_result)
            except BaseException as e:
                # GeneratorExit is propagated untouched, without recording a failure.
                if isinstance(e, GeneratorExit):
                    raise
                logger.exception("exception in step", step_name=name, step_id=step_id)
                # rollback user changes
                await session.rollback()
                step.status = StepStatus.FAILED
                step.error = {
                    "type": type(e).__name__,
                    "message": str(e),
                    "traceback": str(traceback.format_exc()),
                }
                # rollback would have removed the added step (if it was new),
                # so we use `merge` as an "insert or update"
                await session.merge(step)
                await session.commit()
                step_failed(step)

                if max_retries < 0 or step.retry_count < max_retries:
                    logger.info(
                        "step failed, will suspend for retry",
                        step_name=name,
                        step_id=step_id,
                        error=str(e),
                    )
                    # This step is going to be retried, so we will suspend the workflow
                    # TODO add configurable backoff delay
                    await suspend_workflow(interval=timedelta(seconds=5))

                raise e
            finally:
                ctx.step_stack.pop()
                logger.debug(
                    "step popped from stack",
                    step_name=name,
                    step_id=step_id,
                    stack_size=len(ctx.step_stack),
                )

        return wrapper

    return decorator
362
+
363
+
364
@_step()
async def suspend(
    *, interval: timedelta | None = None, wakeup_at: datetime | None = None
):
    """Step that marks itself succeeded and then suspends the workflow.

    Args:
        interval: Forwarded to ``suspend_workflow`` as ``interval``.
        wakeup_at: Forwarded to ``suspend_workflow`` as ``wakeup_at``.
    """
    # The innermost entry of the step stack is the record for this call.
    current_step = get_context().step_stack[-1]
    db_session = get_session()
    current_step.status = StepStatus.SUCCEEDED
    await db_session.merge(current_step)
    await suspend_workflow(wakeup_at=wakeup_at, interval=interval)