planar 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289)
  1. planar/.__init__.py.un~ +0 -0
  2. planar/._version.py.un~ +0 -0
  3. planar/.app.py.un~ +0 -0
  4. planar/.cli.py.un~ +0 -0
  5. planar/.config.py.un~ +0 -0
  6. planar/.context.py.un~ +0 -0
  7. planar/.db.py.un~ +0 -0
  8. planar/.di.py.un~ +0 -0
  9. planar/.engine.py.un~ +0 -0
  10. planar/.files.py.un~ +0 -0
  11. planar/.log_context.py.un~ +0 -0
  12. planar/.log_metadata.py.un~ +0 -0
  13. planar/.logging.py.un~ +0 -0
  14. planar/.object_registry.py.un~ +0 -0
  15. planar/.otel.py.un~ +0 -0
  16. planar/.server.py.un~ +0 -0
  17. planar/.session.py.un~ +0 -0
  18. planar/.sqlalchemy.py.un~ +0 -0
  19. planar/.task_local.py.un~ +0 -0
  20. planar/.test_app.py.un~ +0 -0
  21. planar/.test_config.py.un~ +0 -0
  22. planar/.test_object_config.py.un~ +0 -0
  23. planar/.test_sqlalchemy.py.un~ +0 -0
  24. planar/.test_utils.py.un~ +0 -0
  25. planar/.util.py.un~ +0 -0
  26. planar/.utils.py.un~ +0 -0
  27. planar/__init__.py +26 -0
  28. planar/_version.py +1 -0
  29. planar/ai/.__init__.py.un~ +0 -0
  30. planar/ai/._models.py.un~ +0 -0
  31. planar/ai/.agent.py.un~ +0 -0
  32. planar/ai/.agent_utils.py.un~ +0 -0
  33. planar/ai/.events.py.un~ +0 -0
  34. planar/ai/.files.py.un~ +0 -0
  35. planar/ai/.models.py.un~ +0 -0
  36. planar/ai/.providers.py.un~ +0 -0
  37. planar/ai/.pydantic_ai.py.un~ +0 -0
  38. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  39. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  40. planar/ai/.step.py.un~ +0 -0
  41. planar/ai/.test_agent.py.un~ +0 -0
  42. planar/ai/.test_agent_serialization.py.un~ +0 -0
  43. planar/ai/.test_providers.py.un~ +0 -0
  44. planar/ai/.utils.py.un~ +0 -0
  45. planar/ai/__init__.py +15 -0
  46. planar/ai/agent.py +457 -0
  47. planar/ai/agent_utils.py +205 -0
  48. planar/ai/models.py +140 -0
  49. planar/ai/providers.py +1088 -0
  50. planar/ai/test_agent.py +1298 -0
  51. planar/ai/test_agent_serialization.py +229 -0
  52. planar/ai/test_providers.py +463 -0
  53. planar/ai/utils.py +102 -0
  54. planar/app.py +494 -0
  55. planar/cli.py +282 -0
  56. planar/config.py +544 -0
  57. planar/db/.db.py.un~ +0 -0
  58. planar/db/__init__.py +17 -0
  59. planar/db/alembic/env.py +136 -0
  60. planar/db/alembic/script.py.mako +28 -0
  61. planar/db/alembic/versions/3476068c153c_initial_system_tables_migration.py +339 -0
  62. planar/db/alembic.ini +128 -0
  63. planar/db/db.py +318 -0
  64. planar/files/.config.py.un~ +0 -0
  65. planar/files/.local.py.un~ +0 -0
  66. planar/files/.local_filesystem.py.un~ +0 -0
  67. planar/files/.model.py.un~ +0 -0
  68. planar/files/.models.py.un~ +0 -0
  69. planar/files/.s3.py.un~ +0 -0
  70. planar/files/.storage.py.un~ +0 -0
  71. planar/files/.test_files.py.un~ +0 -0
  72. planar/files/__init__.py +2 -0
  73. planar/files/models.py +162 -0
  74. planar/files/storage/.__init__.py.un~ +0 -0
  75. planar/files/storage/.base.py.un~ +0 -0
  76. planar/files/storage/.config.py.un~ +0 -0
  77. planar/files/storage/.context.py.un~ +0 -0
  78. planar/files/storage/.local_directory.py.un~ +0 -0
  79. planar/files/storage/.test_local_directory.py.un~ +0 -0
  80. planar/files/storage/.test_s3.py.un~ +0 -0
  81. planar/files/storage/base.py +61 -0
  82. planar/files/storage/config.py +44 -0
  83. planar/files/storage/context.py +15 -0
  84. planar/files/storage/local_directory.py +188 -0
  85. planar/files/storage/s3.py +220 -0
  86. planar/files/storage/test_local_directory.py +162 -0
  87. planar/files/storage/test_s3.py +299 -0
  88. planar/files/test_files.py +283 -0
  89. planar/human/.human.py.un~ +0 -0
  90. planar/human/.test_human.py.un~ +0 -0
  91. planar/human/__init__.py +2 -0
  92. planar/human/human.py +458 -0
  93. planar/human/models.py +80 -0
  94. planar/human/test_human.py +385 -0
  95. planar/logging/.__init__.py.un~ +0 -0
  96. planar/logging/.attributes.py.un~ +0 -0
  97. planar/logging/.formatter.py.un~ +0 -0
  98. planar/logging/.logger.py.un~ +0 -0
  99. planar/logging/.otel.py.un~ +0 -0
  100. planar/logging/.tracer.py.un~ +0 -0
  101. planar/logging/__init__.py +10 -0
  102. planar/logging/attributes.py +54 -0
  103. planar/logging/context.py +14 -0
  104. planar/logging/formatter.py +113 -0
  105. planar/logging/logger.py +114 -0
  106. planar/logging/otel.py +51 -0
  107. planar/modeling/.mixin.py.un~ +0 -0
  108. planar/modeling/.storage.py.un~ +0 -0
  109. planar/modeling/__init__.py +0 -0
  110. planar/modeling/field_helpers.py +59 -0
  111. planar/modeling/json_schema_generator.py +94 -0
  112. planar/modeling/mixins/__init__.py +10 -0
  113. planar/modeling/mixins/auditable.py +52 -0
  114. planar/modeling/mixins/test_auditable.py +97 -0
  115. planar/modeling/mixins/test_timestamp.py +134 -0
  116. planar/modeling/mixins/test_uuid_primary_key.py +52 -0
  117. planar/modeling/mixins/timestamp.py +53 -0
  118. planar/modeling/mixins/uuid_primary_key.py +19 -0
  119. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  120. planar/modeling/orm/__init__.py +18 -0
  121. planar/modeling/orm/planar_base_entity.py +29 -0
  122. planar/modeling/orm/query_filter_builder.py +122 -0
  123. planar/modeling/orm/reexports.py +15 -0
  124. planar/object_config/.object_config.py.un~ +0 -0
  125. planar/object_config/__init__.py +11 -0
  126. planar/object_config/models.py +114 -0
  127. planar/object_config/object_config.py +378 -0
  128. planar/object_registry.py +100 -0
  129. planar/registry_items.py +65 -0
  130. planar/routers/.__init__.py.un~ +0 -0
  131. planar/routers/.agents_router.py.un~ +0 -0
  132. planar/routers/.crud.py.un~ +0 -0
  133. planar/routers/.decision.py.un~ +0 -0
  134. planar/routers/.event.py.un~ +0 -0
  135. planar/routers/.file_attachment.py.un~ +0 -0
  136. planar/routers/.files.py.un~ +0 -0
  137. planar/routers/.files_router.py.un~ +0 -0
  138. planar/routers/.human.py.un~ +0 -0
  139. planar/routers/.info.py.un~ +0 -0
  140. planar/routers/.models.py.un~ +0 -0
  141. planar/routers/.object_config_router.py.un~ +0 -0
  142. planar/routers/.rule.py.un~ +0 -0
  143. planar/routers/.test_object_config_router.py.un~ +0 -0
  144. planar/routers/.test_workflow_router.py.un~ +0 -0
  145. planar/routers/.workflow.py.un~ +0 -0
  146. planar/routers/__init__.py +13 -0
  147. planar/routers/agents_router.py +197 -0
  148. planar/routers/entity_router.py +143 -0
  149. planar/routers/event.py +91 -0
  150. planar/routers/files.py +142 -0
  151. planar/routers/human.py +151 -0
  152. planar/routers/info.py +131 -0
  153. planar/routers/models.py +170 -0
  154. planar/routers/object_config_router.py +133 -0
  155. planar/routers/rule.py +108 -0
  156. planar/routers/test_agents_router.py +174 -0
  157. planar/routers/test_object_config_router.py +367 -0
  158. planar/routers/test_routes_security.py +169 -0
  159. planar/routers/test_rule_router.py +470 -0
  160. planar/routers/test_workflow_router.py +274 -0
  161. planar/routers/workflow.py +468 -0
  162. planar/rules/.decorator.py.un~ +0 -0
  163. planar/rules/.runner.py.un~ +0 -0
  164. planar/rules/.test_rules.py.un~ +0 -0
  165. planar/rules/__init__.py +23 -0
  166. planar/rules/decorator.py +184 -0
  167. planar/rules/models.py +355 -0
  168. planar/rules/rule_configuration.py +191 -0
  169. planar/rules/runner.py +64 -0
  170. planar/rules/test_rules.py +750 -0
  171. planar/scaffold_templates/app/__init__.py.j2 +0 -0
  172. planar/scaffold_templates/app/db/entities.py.j2 +11 -0
  173. planar/scaffold_templates/app/flows/process_invoice.py.j2 +67 -0
  174. planar/scaffold_templates/main.py.j2 +13 -0
  175. planar/scaffold_templates/planar.dev.yaml.j2 +34 -0
  176. planar/scaffold_templates/planar.prod.yaml.j2 +28 -0
  177. planar/scaffold_templates/pyproject.toml.j2 +10 -0
  178. planar/security/.jwt_middleware.py.un~ +0 -0
  179. planar/security/auth_context.py +148 -0
  180. planar/security/authorization.py +388 -0
  181. planar/security/default_policies.cedar +77 -0
  182. planar/security/jwt_middleware.py +116 -0
  183. planar/security/security_context.py +18 -0
  184. planar/security/tests/test_authorization_context.py +78 -0
  185. planar/security/tests/test_cedar_basics.py +41 -0
  186. planar/security/tests/test_cedar_policies.py +158 -0
  187. planar/security/tests/test_jwt_principal_context.py +179 -0
  188. planar/session.py +40 -0
  189. planar/sse/.constants.py.un~ +0 -0
  190. planar/sse/.example.html.un~ +0 -0
  191. planar/sse/.hub.py.un~ +0 -0
  192. planar/sse/.model.py.un~ +0 -0
  193. planar/sse/.proxy.py.un~ +0 -0
  194. planar/sse/constants.py +1 -0
  195. planar/sse/example.html +126 -0
  196. planar/sse/hub.py +216 -0
  197. planar/sse/model.py +8 -0
  198. planar/sse/proxy.py +257 -0
  199. planar/task_local.py +37 -0
  200. planar/test_app.py +51 -0
  201. planar/test_cli.py +372 -0
  202. planar/test_config.py +512 -0
  203. planar/test_object_config.py +527 -0
  204. planar/test_object_registry.py +14 -0
  205. planar/test_sqlalchemy.py +158 -0
  206. planar/test_utils.py +105 -0
  207. planar/testing/.client.py.un~ +0 -0
  208. planar/testing/.memory_storage.py.un~ +0 -0
  209. planar/testing/.planar_test_client.py.un~ +0 -0
  210. planar/testing/.predictable_tracer.py.un~ +0 -0
  211. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  212. planar/testing/.test_memory_storage.py.un~ +0 -0
  213. planar/testing/.workflow_observer.py.un~ +0 -0
  214. planar/testing/__init__.py +0 -0
  215. planar/testing/memory_storage.py +78 -0
  216. planar/testing/planar_test_client.py +54 -0
  217. planar/testing/synchronizable_tracer.py +153 -0
  218. planar/testing/test_memory_storage.py +143 -0
  219. planar/testing/workflow_observer.py +73 -0
  220. planar/utils.py +70 -0
  221. planar/workflows/.__init__.py.un~ +0 -0
  222. planar/workflows/.builtin_steps.py.un~ +0 -0
  223. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  224. planar/workflows/.context.py.un~ +0 -0
  225. planar/workflows/.contrib.py.un~ +0 -0
  226. planar/workflows/.decorators.py.un~ +0 -0
  227. planar/workflows/.durable_test.py.un~ +0 -0
  228. planar/workflows/.errors.py.un~ +0 -0
  229. planar/workflows/.events.py.un~ +0 -0
  230. planar/workflows/.exceptions.py.un~ +0 -0
  231. planar/workflows/.execution.py.un~ +0 -0
  232. planar/workflows/.human.py.un~ +0 -0
  233. planar/workflows/.lock.py.un~ +0 -0
  234. planar/workflows/.misc.py.un~ +0 -0
  235. planar/workflows/.model.py.un~ +0 -0
  236. planar/workflows/.models.py.un~ +0 -0
  237. planar/workflows/.notifications.py.un~ +0 -0
  238. planar/workflows/.orchestrator.py.un~ +0 -0
  239. planar/workflows/.runtime.py.un~ +0 -0
  240. planar/workflows/.serialization.py.un~ +0 -0
  241. planar/workflows/.step.py.un~ +0 -0
  242. planar/workflows/.step_core.py.un~ +0 -0
  243. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  244. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  245. planar/workflows/.test_concurrency.py.un~ +0 -0
  246. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  247. planar/workflows/.test_human.py.un~ +0 -0
  248. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  249. planar/workflows/.test_orchestrator.py.un~ +0 -0
  250. planar/workflows/.test_race_conditions.py.un~ +0 -0
  251. planar/workflows/.test_serialization.py.un~ +0 -0
  252. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  253. planar/workflows/.test_workflow.py.un~ +0 -0
  254. planar/workflows/.tracing.py.un~ +0 -0
  255. planar/workflows/.types.py.un~ +0 -0
  256. planar/workflows/.util.py.un~ +0 -0
  257. planar/workflows/.utils.py.un~ +0 -0
  258. planar/workflows/.workflow.py.un~ +0 -0
  259. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  260. planar/workflows/.wrappers.py.un~ +0 -0
  261. planar/workflows/__init__.py +42 -0
  262. planar/workflows/context.py +44 -0
  263. planar/workflows/contrib.py +190 -0
  264. planar/workflows/decorators.py +217 -0
  265. planar/workflows/events.py +185 -0
  266. planar/workflows/exceptions.py +34 -0
  267. planar/workflows/execution.py +198 -0
  268. planar/workflows/lock.py +229 -0
  269. planar/workflows/misc.py +5 -0
  270. planar/workflows/models.py +154 -0
  271. planar/workflows/notifications.py +96 -0
  272. planar/workflows/orchestrator.py +383 -0
  273. planar/workflows/query.py +256 -0
  274. planar/workflows/serialization.py +409 -0
  275. planar/workflows/step_core.py +373 -0
  276. planar/workflows/step_metadata.py +357 -0
  277. planar/workflows/step_testing_utils.py +86 -0
  278. planar/workflows/sub_workflow_runner.py +191 -0
  279. planar/workflows/test_concurrency_detection.py +120 -0
  280. planar/workflows/test_lock_timeout.py +140 -0
  281. planar/workflows/test_serialization.py +1195 -0
  282. planar/workflows/test_suspend_deserialization.py +231 -0
  283. planar/workflows/test_workflow.py +1967 -0
  284. planar/workflows/tracing.py +106 -0
  285. planar/workflows/wrappers.py +41 -0
  286. planar-0.5.0.dist-info/METADATA +285 -0
  287. planar-0.5.0.dist-info/RECORD +289 -0
  288. planar-0.5.0.dist-info/WHEEL +4 -0
  289. planar-0.5.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1967 @@
1
+ import asyncio
2
+ import gc
3
+ import json
4
+ from collections import defaultdict
5
+ from datetime import datetime, timedelta
6
+ from decimal import Decimal
7
+ from uuid import UUID
8
+
9
+ import pytest
10
+ from freezegun import freeze_time
11
+ from pydantic import BaseModel
12
+ from sqlmodel import col, select
13
+ from sqlmodel.ext.asyncio.session import AsyncSession
14
+
15
+ from planar.session import get_session
16
+ from planar.testing.workflow_observer import WorkflowObserver
17
+ from planar.utils import utc_now
18
+ from planar.workflows.decorators import (
19
+ __AS_STEP_CACHE,
20
+ __is_workflow_step,
21
+ as_step,
22
+ step,
23
+ workflow,
24
+ )
25
+ from planar.workflows.exceptions import NonDeterministicStepCallError
26
+ from planar.workflows.execution import execute, lock_and_execute
27
+ from planar.workflows.models import (
28
+ StepStatus,
29
+ StepType,
30
+ Workflow,
31
+ WorkflowStatus,
32
+ WorkflowStep,
33
+ )
34
+ from planar.workflows.notifications import Notification, workflow_notification_context
35
+ from planar.workflows.orchestrator import WorkflowOrchestrator
36
+ from planar.workflows.step_core import (
37
+ Suspend,
38
+ suspend,
39
+ )
40
+ from planar.workflows.step_testing_utils import (
41
+ get_step_ancestors,
42
+ get_step_children,
43
+ get_step_descendants,
44
+ get_step_parent,
45
+ )
46
+
47
+
48
# =============================================================================
# Test 1 – Basic Workflow Lifecycle
# =============================================================================
async def test_workflow_lifecycle(session: AsyncSession):
    """A trivial workflow runs to completion and persists status + result."""

    @workflow()
    async def sample_workflow():
        return "success"

    run = await sample_workflow.start()
    await execute(run)

    persisted = await session.get(Workflow, run.id)
    assert persisted is not None
    assert persisted.status == WorkflowStatus.SUCCEEDED
    assert persisted.result == "success"
62
+
63
+
64
# =============================================================================
# Test 2 – Session Context Is Set
# =============================================================================
async def test_session_context_is_set(session: AsyncSession):
    """The session visible inside a workflow is the fixture-provided one."""

    @workflow()
    async def session_workflow():
        # Ensure that the session returned is the one we set from the fixture.
        assert get_session() is session
        return "success"

    run = await session_workflow.start()
    await execute(run)

    persisted = await session.get(Workflow, run.id)
    assert persisted
    assert persisted.status == WorkflowStatus.SUCCEEDED
    assert persisted.result == "success"
81
+
82
+
83
# =============================================================================
# Test 3 – Step Execution and Tracking
# =============================================================================
async def test_step_execution(session: AsyncSession):
    """Each step invocation leaves a SUCCEEDED WorkflowStep row behind."""

    @step()
    async def step1():
        return "step1_result"

    @step()
    async def step2():
        return "step2_result"

    @workflow()
    async def multistep_workflow():
        await step1()
        await step2()
        return "done"

    run = await multistep_workflow.start()
    await execute(run)

    recorded = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == run.id)
        )
    ).all()
    assert len(recorded) == 2

    # function_name is dotted; compare against the trailing component only.
    short_names = {entry.function_name.split(".")[-1] for entry in recorded}
    assert "step1" in short_names
    assert "step2" in short_names

    for entry in recorded:
        assert entry.status == StepStatus.SUCCEEDED
        assert entry.workflow_id == run.id
116
+
117
+
118
# =============================================================================
# Test 4 – Step Error Handling
# =============================================================================
async def test_step_error_handling(session: AsyncSession):
    """A raising step fails the workflow and records the error message."""

    @step()
    async def failing_step():
        raise ValueError("Intentional failure")

    @workflow()
    async def error_workflow():
        await failing_step()
        return "done"

    run = await error_workflow.start()
    with pytest.raises(ValueError, match="Intentional failure"):
        await execute(run)

    persisted = await session.get(Workflow, run.id)
    assert persisted
    assert persisted.status == WorkflowStatus.FAILED

    failed_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == run.id)
        )
    ).one()
    assert failed_entry.error is not None
    assert "Intentional failure" in failed_entry.error["message"]
145
+
146
+
147
# =============================================================================
# Test 5 – Workflow Resumption (Retry a Failing Step)
# =============================================================================
async def test_workflow_resumption(session: AsyncSession):
    """A failed step suspends the workflow; resuming after the fix succeeds."""
    should_fail = True

    @step(max_retries=1)
    async def dynamic_step():
        nonlocal should_fail
        if should_fail:
            raise RuntimeError("Temporary failure")
        return "done"

    @workflow()
    async def resumable_workflow():
        return await dynamic_step()

    run = await resumable_workflow.start()

    # First execution: the step fails, so execute() yields a Suspend marker
    # and the workflow remains pending.
    first_result = await execute(run)
    assert isinstance(first_result, Suspend)
    persisted = await session.get(Workflow, run.id)
    assert persisted
    assert persisted.status == WorkflowStatus.PENDING

    # Clear the failure condition and resume to completion.
    should_fail = False
    second_result = await execute(run)
    persisted = await session.get(Workflow, run.id)
    assert persisted
    assert persisted.status == WorkflowStatus.SUCCEEDED
    assert second_result == "done"
179
+
180
+
181
# =============================================================================
# Test 6 – Input Data Persistence
# =============================================================================
async def test_input_data_persistence(session: AsyncSession):
    """Positional start() arguments are persisted on the workflow row."""

    @workflow()
    async def data_workflow(a: int, b: int):
        return a + b

    run = await data_workflow.start(10, 20)
    await execute(run)

    persisted = await session.get(Workflow, run.id)
    assert persisted
    assert persisted.args == [10, 20]
    assert persisted.kwargs == {}
    assert persisted.result == 30
196
+
197
+
198
# =============================================================================
# Test 7 – Completed Workflow Resumption
# =============================================================================
async def test_completed_workflow_resumption(session: AsyncSession):
    """Re-executing a finished workflow keeps returning the stored result."""

    @workflow()
    async def completed_workflow():
        return "final_result"

    run = await completed_workflow.start()
    first_result = await execute(run)
    second_result = await execute(run)
    assert first_result == "final_result"
    assert second_result == "final_result"
211
+
212
+
213
# =============================================================================
# Test 8 – Step Idempotency
# =============================================================================
async def test_step_idempotency(session: AsyncSession):
    """A completed step is replayed from its record, not re-executed."""
    execution_count = 0

    @step()
    async def idempotent_step():
        nonlocal execution_count
        execution_count += 1
        return "idempotent"

    @workflow()
    async def idempotent_workflow():
        await idempotent_step()
        return "done"

    run = await idempotent_workflow.start()
    await execute(run)
    # Resuming must reuse the recorded step result rather than run it again.
    await execute(run)
    assert execution_count == 1
235
+
236
+
237
# =============================================================================
# Test 9 – Error Traceback Storage (Adjusted)
# =============================================================================
async def test_error_traceback_storage(session: AsyncSession):
    """The step error record carries the exception message."""

    @step()
    async def error_step():
        raise ValueError("Error with traceback")

    @workflow()
    async def traceback_workflow():
        await error_step()
        return "done"

    run = await traceback_workflow.start()
    with pytest.raises(ValueError, match="Error with traceback"):
        await execute(run)

    failed_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == run.id)
        )
    ).one()
    assert failed_entry.error is not None
    # The new engine does not store full tracebacks, so we check only the error message.
    assert "Error with traceback" in failed_entry.error["message"]
262
+
263
+
264
# =============================================================================
# Test 10 – Empty Workflow (No Steps)
# =============================================================================
async def test_empty_workflow(session: AsyncSession):
    """A stepless workflow succeeds and creates no WorkflowStep rows."""

    @workflow()
    async def empty_workflow():
        return "direct_result"

    run = await empty_workflow.start()
    await execute(run)

    persisted = await session.get(Workflow, run.id)
    assert persisted
    assert persisted.status == WorkflowStatus.SUCCEEDED
    assert persisted.result == "direct_result"

    # Verify no DurableStep records were created.
    leftover = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == run.id)
        )
    ).first()
    assert leftover is None
285
+
286
+
287
# =============================================================================
# Test 11 – Complex Workflow with Retries and Data Persistence
# =============================================================================
async def test_complex_workflow_with_retries_and_data_persistence(
    session: AsyncSession,
):
    """Drive a three-step chained workflow where every step fails once.

    Each step has max_retries=1 and raises on its first attempt, so the
    workflow needs four execute() passes to finish: each pass advances one
    step past its transient failure. After each pass the test inspects the
    WorkflowStep rows to confirm per-step status, result, and attempt counts.
    """
    step1_attempts = 0
    step2_attempts = 0
    step3_attempts = 0

    @step(max_retries=1)
    async def step1(input_val: int) -> int:
        # Fails on the first call only, then returns input + 10.
        nonlocal step1_attempts
        step1_attempts += 1
        if step1_attempts == 1:
            raise RuntimeError("Step 1 temporary failure")
        return input_val + 10

    @step(max_retries=1)
    async def step2(input_val: int) -> int:
        # Fails on the first call only, then doubles the input.
        nonlocal step2_attempts
        step2_attempts += 1
        if step2_attempts == 1:
            raise RuntimeError("Step 2 temporary failure")
        return input_val * 2

    @step(max_retries=1)
    async def step3(input_val: int) -> int:
        # Fails on the first call only, then returns input - 5.
        nonlocal step3_attempts
        step3_attempts += 1
        if step3_attempts == 1:
            raise RuntimeError("Step 3 temporary failure")
        return input_val - 5

    @workflow()
    async def chained_workflow(initial_input: int) -> int:
        # Pipeline: ((initial + 10) * 2) - 5.
        r1 = await step1(initial_input)
        r2 = await step2(r1)
        r3 = await step3(r2)
        return r3

    wf = await chained_workflow.start(5)
    # First run: step1 fails → workflow suspended.
    await execute(wf)
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.PENDING
    # Locate step1's row by matching its dotted function_name.
    step1_entry = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(col(WorkflowStep.function_name).like("%step1%"))
        )
    ).one()
    assert step1_entry.status == StepStatus.FAILED
    assert step1_attempts == 1

    # Second run: step1 succeeds, step2 fails.
    await execute(wf)
    updated_wf = await session.get(Workflow, wf.id)
    step1_entry = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(col(WorkflowStep.function_name).like("%step1%"))
        )
    ).one()
    assert step1_entry.status == StepStatus.SUCCEEDED
    assert step1_entry.result == 15  # 5 + 10
    assert step1_attempts == 2
    step2_entry = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(col(WorkflowStep.function_name).like("%step2%"))
        )
    ).one()
    assert step2_entry.status == StepStatus.FAILED
    assert step2_attempts == 1

    # Third run: step2 succeeds, step3 fails.
    await execute(wf)
    updated_wf = await session.get(Workflow, wf.id)
    step2_entry = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(col(WorkflowStep.function_name).like("%step2%"))
        )
    ).one()
    assert step2_entry.status == StepStatus.SUCCEEDED
    assert step2_entry.result == 30  # 15 * 2
    assert step2_attempts == 2
    step3_entry = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(col(WorkflowStep.function_name).like("%step3%"))
        )
    ).one()
    assert step3_entry.status == StepStatus.FAILED
    assert step3_attempts == 1

    # Fourth run: step3 succeeds → final result.
    final_result = await execute(wf)
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.SUCCEEDED
    assert final_result == 25  # 30 - 5
    assert updated_wf.result == 25
    step3_entry = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(col(WorkflowStep.function_name).like("%step3%"))
        )
    ).one()
    assert step3_entry.status == StepStatus.SUCCEEDED
    assert step3_entry.result == 25
    assert step3_attempts == 2

    # Verify workflow input data persistence.
    assert updated_wf.args == [5]
    assert updated_wf.kwargs == {}
411
+
412
+
413
# =============================================================================
# Test 12 – Step Retries
# =============================================================================
async def test_step_retries(session: AsyncSession):
    """Exhaust a step's retry budget and verify retry_count bookkeeping.

    The step always raises with max_retries=3. Each execute() pass consumes
    one attempt: passes 1–3 suspend the workflow with retry_count 0..2;
    pass 4 exceeds the budget and re-raises; pass 5 re-raises without
    running the step again (attempt count stays frozen).
    """
    retry_limit = 3
    attempt_count = 0

    @step(max_retries=retry_limit)
    async def retry_step():
        # Always fails; the counter tracks how many times the body ran.
        nonlocal attempt_count
        attempt_count += 1
        raise RuntimeError("Temporary failure")

    @workflow()
    async def retry_workflow():
        await retry_step()
        return "done"

    wf = await retry_workflow.start()

    # Attempt 1
    await execute(wf)
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.PENDING
    step_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    # retry_count starts at 0 for the first (non-retry) attempt.
    assert step_entry.retry_count == 0
    assert attempt_count == 1

    # Attempt 2
    await execute(wf)
    step_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert step_entry.retry_count == 1
    assert attempt_count == 2

    # Attempt 3
    await execute(wf)
    step_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert step_entry.retry_count == 2
    assert attempt_count == 3

    # Attempt 4 – exceed retries so that execution raises.
    with pytest.raises(RuntimeError, match="Temporary failure"):
        await execute(wf)
    step_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert step_entry.retry_count == 3
    assert attempt_count == 4

    # Further execution should not increment attempts.
    with pytest.raises(RuntimeError, match="Temporary failure"):
        await execute(wf)
    step_entry = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert step_entry.retry_count == 3
    assert attempt_count == 4
487
+
488
+
489
# =============================================================================
# Test 13 – Looped Step Execution
# =============================================================================
async def test_looped_step_execution(session: AsyncSession):
    """Calling the same step in a loop records one row per iteration,
    with sequential step_ids."""
    loop_count = 3

    @step()
    async def say_hello_step():
        return "hello"

    @workflow()
    async def looped_workflow(count: int):
        for _ in range(count):
            await say_hello_step()
        return "done"

    run = await looped_workflow.start(loop_count)
    await execute(run)

    recorded = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == run.id)
            .order_by(col(WorkflowStep.step_id))
        )
    ).all()
    assert len(recorded) == loop_count

    # step_id is 1-based and increments per invocation.
    for expected_id, entry in enumerate(recorded, start=1):
        assert entry.function_name.split(".")[-1] == "say_hello_step"
        assert entry.step_id == expected_id
        assert entry.status == StepStatus.SUCCEEDED
520
+
521
+
522
# =============================================================================
# Test 14 – Basic Sleep Functionality
# =============================================================================
async def test_basic_sleep_functionality(session: AsyncSession):
    """suspend(interval=...) parks the workflow and sets wakeup_at.

    Uses freezegun so the 10-second wakeup deadline is deterministic;
    advancing frozen time past the deadline lets execute() finish the run.
    """
    with freeze_time("2024-01-01 00:00:00") as frozen_time:

        @workflow()
        async def sleeping_workflow():
            await suspend(interval=timedelta(seconds=10))
            return "awake"

        wf = await sleeping_workflow.start()
        result = await execute(wf)
        updated_wf = await session.get(Workflow, wf.id)
        assert updated_wf
        # The suspend step should have returned a Suspend object.
        assert isinstance(result, Suspend)
        assert updated_wf.status == WorkflowStatus.PENDING
        # Frozen start (00:00:00) + 10s interval. Compared as a naive
        # datetime — assumes wakeup_at is stored naive/UTC; TODO confirm.
        expected_wakeup = datetime(2024, 1, 1, 0, 0, 10)
        assert updated_wf.wakeup_at == expected_wakeup

        # Check that the suspend step record has function_name 'suspend'
        sleep_step = (
            await session.exec(
                select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
            )
        ).one()
        assert sleep_step.function_name.split(".")[-1] == "suspend"

        # Move time forward and resume.
        frozen_time.move_to("2024-01-01 00:00:11")
        final_result = await execute(wf)
        assert final_result == "awake"
        # NOTE(review): updated_wf is not re-fetched here — this relies on
        # execute() mutating the same identity-mapped object in this session.
        assert updated_wf.status == WorkflowStatus.SUCCEEDED
556
+
557
+
558
+ # =============================================================================
559
+ # Test 15 – Worker Skips Sleeping Workflows
560
+ # =============================================================================
561
async def test_worker_skips_sleeping_workflows(session: AsyncSession):
    """A suspended workflow must not appear in the worker's ready query
    until its wakeup time has passed."""

    @workflow()
    async def sleeping_workflow():
        await suspend(interval=timedelta(minutes=5))
        return "done"

    wf = await sleeping_workflow.start()

    # First pass: the workflow suspends itself for five minutes.
    outcome = await execute(wf)
    assert isinstance(outcome, Suspend)

    # Reproduce the worker's polling query: PENDING workflows whose wakeup
    # time is already due.
    ready_query = (
        select(Workflow)
        .where(Workflow.status == WorkflowStatus.PENDING)
        .where(col(Workflow.wakeup_at) <= utc_now())
    )
    due = (await session.exec(ready_query)).all()
    # The wakeup time is still five minutes in the future, so nothing is due.
    assert len(due) == 0

    # Executing directly (bypassing the worker's schedule) still resumes it.
    outcome = await execute(wf)
    assert outcome == "done"
585
+
586
+
587
+ # =============================================================================
588
+ # Test 16 – Multiple Sleep Steps
589
+ # =============================================================================
590
async def test_multiple_sleep_steps(session: AsyncSession):
    """Two consecutive suspends each record a step and push the workflow's
    wakeup time further out by their respective intervals (uses real time)."""

    @workflow()
    async def multi_sleep_workflow():
        await suspend(interval=timedelta(seconds=2))
        await suspend(interval=timedelta(seconds=4))
        return 42

    start_date = utc_now()
    wf = await multi_sleep_workflow.start()
    assert wf
    # First run: suspend for 2 seconds.
    result = await execute(wf)
    assert isinstance(result, Suspend)
    await session.refresh(wf)
    assert wf.wakeup_at
    # Allow up to 1s of scheduling slack on top of the 2s interval.
    assert (wf.wakeup_at - start_date) >= timedelta(seconds=2)
    assert (wf.wakeup_at - start_date) <= timedelta(seconds=3)

    # Wait out the first interval in real time and resume.
    await asyncio.sleep(2)
    result = await execute(wf)
    assert isinstance(result, Suspend)
    await session.refresh(wf)
    # Second suspend: ~2s already elapsed + 4s interval => >= 6s after start.
    assert (wf.wakeup_at - start_date) >= timedelta(seconds=6)
    assert (wf.wakeup_at - start_date) <= timedelta(seconds=7)

    # Verify that two suspend steps were recorded.
    sleep_steps = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .order_by(col(WorkflowStep.step_id))
        )
    ).all()
    assert len(sleep_steps) == 2
    assert [s.step_id for s in sleep_steps] == [1, 2]

    # Final execution after the second sleep has fully elapsed.
    await asyncio.sleep(4.5)
    final_result = await execute(wf)
    assert final_result == 42
631
+
632
+
633
+ # =============================================================================
634
+ # Test 17 – Looped Execution with Step Dependencies
635
+ # =============================================================================
636
async def test_looped_execution_with_step_dependencies(session: AsyncSession):
    """Each loop iteration's step fails once then succeeds on retry; the
    chained results must replay deterministically across executions."""
    step_attempts = defaultdict(int)
    expected_results = []

    @step(max_retries=1)
    async def process_step(input_val: int) -> int:
        # Fail the first attempt for every distinct input to force a retry.
        step_attempts[input_val] += 1
        if step_attempts[input_val] == 1:
            raise RuntimeError(f"Temporary failure for input {input_val}")
        return input_val + 5

    @workflow()
    async def looped_dependency_workflow(initial: int) -> int:
        # Reset on every (re)execution so the list reflects the final replay.
        nonlocal expected_results
        expected_results = []
        current = initial
        for _ in range(3):
            current = await process_step(current)
            expected_results.append(current)
        return current

    wf = await looped_dependency_workflow.start(10)
    # Run through several execution attempts until the workflow finishes.
    for _ in range(6):
        try:
            await execute(wf)
        except Exception:
            pass

    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.SUCCEEDED
    # 10 → 15 → 20 → 25
    assert updated_wf.result == 25

    steps = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .order_by(col(WorkflowStep.step_id))
        )
    ).all()
    assert len(steps) == 3
    assert all("process_step" in s.function_name for s in steps)
    assert [s.result for s in steps] == [15, 20, 25]

    # Each step should have retried exactly once.
    for s in steps:
        assert s.retry_count == 1

    # Verify that error messages were recorded on failed attempts.
    # NOTE(review): this filters on SUCCEEDED steps — presumably retried
    # steps keep the last failure's error payload after succeeding; confirm.
    step_errors = (
        await session.exec(
            select(WorkflowStep.error)
            .where(WorkflowStep.workflow_id == wf.id)
            .where(WorkflowStep.status == StepStatus.SUCCEEDED)
        )
    ).all()
    for err in step_errors:
        if err:
            assert "Temporary failure" in err["message"]

    assert expected_results == [15, 20, 25]
699
+
700
+
701
async def test_handling_step_errors(session: AsyncSession):
    """A workflow can catch a step's exception and suspend from inside the
    except block; on resume it continues after the suspend point."""

    @step(max_retries=0)
    async def step1():
        raise ValueError("Step 1 error")

    @workflow()
    async def step_try_catch_workflow():
        try:
            await step1()
        except ValueError:
            # Suspend the workflow in the except block
            await suspend(interval=timedelta(seconds=5))
            return "handled"
        return "done"

    # Start the workflow
    wf = await step_try_catch_workflow.start()

    # First execution: should raise ValueError in step1, catch it, call
    # sleep(...) -> suspended
    result = await execute(wf)
    # Expect a Suspend object because the workflow is waiting
    assert isinstance(result, Suspend)

    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf is not None
    assert updated_wf.status == WorkflowStatus.PENDING
    assert updated_wf.wakeup_at is not None

    # Verify that two step records were created:
    # NOTE(review): no order_by here — indexing steps[0]/steps[1] presumably
    # relies on insertion order; confirm.
    steps = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).all()
    assert len(steps) == 2
    # The first step (step1) failed
    assert steps[0].status == StepStatus.FAILED
    assert steps[0].result is None

    # --- Second execution: after wakeup time
    final_result = await execute(wf)
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    # Now the workflow should resume and finish, returning "handled"
    assert final_result == "handled"
    assert updated_wf.status == WorkflowStatus.SUCCEEDED
    assert updated_wf.result == "handled"

    # Finally, verify the step records remain as expected.
    steps = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).all()
    assert len(steps) == 2
    assert steps[0].status == StepStatus.FAILED
    assert steps[0].result is None
    assert steps[1].status == StepStatus.SUCCEEDED
    assert steps[1].error is None
761
+
762
+
763
async def test_exceute_properly_intercepts_coroutine(session: AsyncSession):
    """Step interception must work through nested step / non-step coroutine
    calls: only the decorated steps (step1, step2, step3) produce
    WorkflowStep records, not the plain coroutines between them.

    (The "exceute" typo in the name is kept so the test id stays stable.)
    """

    async def shell(cmd: str):
        # Plain helper outside the step machinery: echoes via a subprocess.
        proc = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE
        )
        stdout, _ = await proc.communicate()
        return stdout.decode().strip()

    @step()
    async def step1():
        # 10 (non_step1) + 10 (non_step2, starting at count=10) => "echoing 20"
        echo_output = await non_step1()
        assert echo_output == "echoing 20"
        count = int(echo_output.split()[-1])
        for _ in range(10):
            await asyncio.sleep(0.01)
            count += 1
        return count

    async def non_step1():
        count = 0
        for _ in range(10):
            await asyncio.sleep(0.01)
            count += 1
        return await step2(count)

    @step()
    async def step2(count: int):
        return await non_step2(count)

    async def non_step2(count: int):
        for _ in range(count):
            await asyncio.sleep(0.01)
            count += 1
        return await shell(f"echo echoing {count}")

    @step()
    async def step3(count: int):
        for _ in range(10):
            await asyncio.sleep(0.01)
            count += 1
        return count

    @workflow()
    async def nested_step_and_non_step_calls():
        count = await step1()
        count = await step3(count)
        return count

    wf = await nested_step_and_non_step_calls.start()
    await execute(wf)
    assert wf.status == WorkflowStatus.SUCCEEDED
    assert wf.result == 40

    steps = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .order_by(col(WorkflowStep.step_id))
        )
    ).all()

    assert all(s.status == StepStatus.SUCCEEDED for s in steps)
    # Fix: the original mutated s.function_name in place on session-attached
    # ORM objects before asserting; compare against the stripped name instead.
    assert tuple(s.function_name.split(".")[-1] for s in steps) == (
        "step1",
        "step2",
        "step3",
    )
    assert tuple(s.result for s in steps) == (30, "echoing 20", 40)
833
+
834
+
835
async def test_sub_workflows(session: AsyncSession):
    """Sub-workflows awaited from a parent run to completion: the parent
    records one start_workflow_step per child, and each child records its
    own step plus the suspend performed inside that step."""

    @step()
    async def step1(n: int) -> Decimal:
        await suspend(interval=timedelta(seconds=0.1))
        return Decimal(1 + n)

    @step()
    async def step2(n: int) -> Decimal:
        await suspend(interval=timedelta(seconds=0.1))
        return Decimal(2 + n)

    @step()
    async def step3(n: int) -> Decimal:
        await suspend(interval=timedelta(seconds=0.1))
        return Decimal(3 + n)

    @workflow()
    async def workflow1(n: int) -> Decimal:
        return await step1(n)

    @workflow()
    async def workflow2(n: int) -> Decimal:
        return await step2(n)

    @workflow()
    async def workflow3(n: int) -> Decimal:
        return await step3(n)

    @workflow()
    async def call_sub_workflows() -> Decimal:
        # Awaiting a workflow inside a workflow runs it as a sub-workflow.
        w1 = await workflow1(1)
        w2 = await workflow2(2)
        w3 = await workflow3(3)
        assert w1 == Decimal(2)
        assert w2 == Decimal(4)
        assert w3 == Decimal(6)
        return w1 + w2 + w3

    async with WorkflowOrchestrator.ensure_started() as orchestrator:
        wf = await call_sub_workflows.start()
        result = await orchestrator.wait_for_completion(wf.id)

    await session.refresh(wf)
    assert wf.status == WorkflowStatus.SUCCEEDED
    assert result == Decimal(12)

    # Build a structural snapshot of every workflow with its ordered steps.
    all_workflows = []
    workflows = (
        await session.exec(select(Workflow).order_by(col(Workflow.created_at)))
    ).all()
    for w in workflows:
        steps = (
            await session.exec(
                select(WorkflowStep)
                .where(col(WorkflowStep.workflow_id) == w.id)
                .order_by(col(WorkflowStep.step_id))
            )
        ).all()
        all_workflows.append(
            {
                "status": w.status,
                "function_name": w.function_name.split(".")[-1],
                "steps": [
                    {
                        "step_id": s.step_id,
                        "step_status": s.status,
                        "function_name": s.function_name.split(".")[-1],
                    }
                    for s in steps
                ],
            }
        )

    # Exact expected snapshot: parent first (by created_at), then children.
    assert all_workflows == [
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "call_sub_workflows",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
                {
                    "step_id": 3,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
            ],
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "workflow1",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step1",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "workflow2",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step2",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "workflow3",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step3",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
        },
    ]
979
+
980
+
981
@pytest.mark.xfail(reason="Not supported for now")
async def test_sub_workflows_concurrent_execution(session: AsyncSession):
    """Concurrent sub-workflows via asyncio.gather — currently expected to
    fail (xfail); documents the intended snapshot once supported."""

    @step()
    async def step1(n: int):
        await suspend(interval=timedelta(seconds=0.1))
        return 1 + n

    @step()
    async def step2(n: int):
        await suspend(interval=timedelta(seconds=0.1))
        return 2 + n

    @step()
    async def step3(n: int):
        await suspend(interval=timedelta(seconds=0.1))
        return 3 + n

    @workflow()
    async def workflow1(n: int):
        return await step1(n)

    @workflow()
    async def workflow2(n: int):
        return await step2(n)

    @workflow()
    async def workflow3(n: int):
        return await step3(n)

    @workflow()
    async def concurrent_call_sub_workflows():
        # All three sub-workflows are started concurrently.
        w1, w2, w3 = await asyncio.gather(workflow1(1), workflow2(2), workflow3(3))
        return w1 + w2 + w3

    async with WorkflowOrchestrator.ensure_started() as orchestrator:
        wf = await concurrent_call_sub_workflows.start()
        await orchestrator.wait_for_completion(wf.id)

    await session.refresh(wf)
    assert wf.status == WorkflowStatus.SUCCEEDED
    assert wf.result == 12

    # Structural snapshot of every workflow with its ordered steps.
    all_workflows = []
    workflows = (
        await session.exec(select(Workflow).order_by(col(Workflow.created_at)))
    ).all()
    for w in workflows:
        steps = (
            await session.exec(
                select(WorkflowStep)
                .where(col(WorkflowStep.workflow_id) == w.id)
                .order_by(col(WorkflowStep.step_id))
            )
        ).all()
        all_workflows.append(
            {
                "status": w.status,
                "function_name": w.function_name.split(".")[-1],
                "steps": [
                    {
                        "step_id": s.step_id,
                        "step_status": s.status,
                        "function_name": s.function_name.split(".")[-1],
                    }
                    for s in steps
                ],
                "result": w.result,
            }
        )

    assert all_workflows == [
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "concurrent_call_sub_workflows",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
                {
                    "step_id": 3,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
            ],
            "result": 12,
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "workflow1",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step1",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
            "result": 2,
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "workflow2",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step2",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
            "result": 4,
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "workflow3",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step3",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
            "result": 6,
        },
    ]
1126
+
1127
+
1128
@pytest.mark.xfail(reason="Not supported for now")
async def test_step_can_be_scheduled_as_tasks(session: AsyncSession):
    """Steps gathered as asyncio tasks inside another step — currently
    expected to fail (xfail). The snapshot documents the intended design:
    each gathered step is promoted to its own auto_workflow."""

    @step()
    async def step1():
        # Gather promotes each child step to a task / sub-workflow.
        s2, s3, s4 = await asyncio.gather(step2(), step3(), step4())
        return s2 + s3 + s4

    @step()
    async def step2():
        await suspend(interval=timedelta(seconds=0.1))
        return 2

    @step()
    async def step3():
        await suspend(interval=timedelta(seconds=0.1))
        return 3

    @step()
    async def step4():
        await suspend(interval=timedelta(seconds=0.1))
        return 4

    @workflow()
    async def execute_steps_in_parallel():
        return await step1()

    async with WorkflowOrchestrator.ensure_started() as orchestrator:
        wf = await execute_steps_in_parallel.start()
        await orchestrator.wait_for_completion(wf.id)

    await session.refresh(wf)
    assert wf.status == WorkflowStatus.SUCCEEDED
    assert wf.result == 9

    # Structural snapshot of every workflow with its ordered steps.
    all_workflows = []
    workflows = (
        await session.exec(select(Workflow).order_by(col(Workflow.created_at)))
    ).all()
    for w in workflows:
        steps = (
            await session.exec(
                select(WorkflowStep)
                .where(col(WorkflowStep.workflow_id) == w.id)
                .order_by(col(WorkflowStep.step_id))
            )
        ).all()
        all_workflows.append(
            {
                "status": w.status,
                "function_name": w.function_name.split(".")[-1],
                "steps": [
                    {
                        "step_id": s.step_id,
                        "step_status": s.status,
                        "function_name": s.function_name.split(".")[-1],
                    }
                    for s in steps
                ],
                "result": w.result,
            }
        )

    assert all_workflows == [
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "execute_steps_in_parallel",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step1",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
                {
                    "step_id": 3,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
                {
                    "step_id": 4,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
            ],
            "result": 9,
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "auto_workflow",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step2",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
            "result": 2,
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "auto_workflow",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step3",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
            "result": 3,
        },
        {
            "status": WorkflowStatus.SUCCEEDED,
            "function_name": "auto_workflow",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "step4",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "suspend",
                },
            ],
            "result": 4,
        },
    ]
1270
+
1271
+
1272
async def test_nested_workflow_started_from_nested_step_failed(session: AsyncSession):
    """A sub-workflow started from a step fails; the failure must propagate
    to the parent and both workflows end FAILED with the expected steps."""

    @step()
    async def update_inbound_document_with_classification(
        item_id: str, classification: str
    ) -> bool:
        # Always fails — this is the error that must bubble up to the parent.
        await asyncio.sleep(0.1)
        raise Exception(f"some issue with {item_id}/{classification}")

    @workflow()
    async def classify_inbound_document(item_id: str, attachment_id: str):
        await update_inbound_document_with_classification(item_id, "classified")

    @step()
    async def upload_documents_from_email(limit: int) -> list[str]:
        await asyncio.sleep(0.1)
        return [
            json.dumps({"item_id": "doc 1", "attachment_id": "attachment 1"}),
        ]

    @step()
    async def start_classify_inbound_document_workflow(
        inbound_document_with_attachment: str,
    ):
        # Starts the classify sub-workflow from inside a step.
        obj = json.loads(inbound_document_with_attachment)
        await classify_inbound_document(obj["item_id"], obj["attachment_id"])

    @workflow()
    async def email_documents_uploader(limit: int = 10) -> list[str]:
        inbound_documents_with_attachments = await upload_documents_from_email(limit)
        for doc in inbound_documents_with_attachments:
            await start_classify_inbound_document_workflow(doc)
        return inbound_documents_with_attachments

    wf = await email_documents_uploader.start()
    async with WorkflowOrchestrator.ensure_started(poll_interval=1) as orchestrator:
        # The sub-workflow's exception surfaces through the parent's wait.
        with pytest.raises(Exception, match="some issue with doc 1/classified"):
            await orchestrator.wait_for_completion(wf.id)

    # Structural snapshot of every workflow with its ordered steps.
    all_workflows = []
    workflows = (
        await session.exec(select(Workflow).order_by(col(Workflow.created_at)))
    ).all()
    for w in workflows:
        steps = (
            await session.exec(
                select(WorkflowStep)
                .where(col(WorkflowStep.workflow_id) == w.id)
                .order_by(col(WorkflowStep.step_id))
            )
        ).all()
        all_workflows.append(
            {
                "status": w.status,
                "function_name": w.function_name.split(".")[-1],
                "steps": [
                    {
                        "step_id": s.step_id,
                        "step_status": s.status,
                        "function_name": s.function_name.split(".")[-1],
                    }
                    for s in steps
                ],
            }
        )

    assert all_workflows == [
        {
            "status": WorkflowStatus.FAILED,
            "function_name": "email_documents_uploader",
            "steps": [
                {
                    "step_id": 1,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "upload_documents_from_email",
                },
                {
                    "step_id": 2,
                    "step_status": StepStatus.FAILED,
                    "function_name": "start_classify_inbound_document_workflow",
                },
                {
                    "step_id": 3,
                    "step_status": StepStatus.SUCCEEDED,
                    "function_name": "start_workflow_step",
                },
            ],
        },
        {
            "status": WorkflowStatus.FAILED,
            "function_name": "classify_inbound_document",
            "steps": [
                {
                    "function_name": "update_inbound_document_with_classification",
                    "step_id": 1,
                    "step_status": StepStatus.FAILED,
                }
            ],
        },
    ]
1371
+
1372
+
1373
+ # =============================================================================
1374
+ # Tests for Non-Deterministic Step Call Detection
1375
+ # =============================================================================
1376
async def test_non_deterministic_step_detection_args(session: AsyncSession):
    """Replaying a failed step with different positional args (a changed
    nested model field) must raise NonDeterministicStepCallError."""
    # Track whether we're in first or second execution attempt
    is_first_execution = [True]

    class ConfigModel(BaseModel):
        name: str
        value: int
        nested: dict[str, str]

    @step(max_retries=1)
    async def failing_step_with_model(config: ConfigModel) -> str:
        # First execution will always fail
        if is_first_execution[0]:
            is_first_execution[0] = False
            raise RuntimeError("First attempt fails deliberately")

        # Return something (won't matter for the test)
        return f"Processed {config.name} with value {config.value}"

    @workflow()
    async def model_workflow() -> str:
        # First execution will use this config
        config = ConfigModel(name="test-config", value=42, nested={"key": "original"})

        # On retry, we'll modify the config in a non-deterministic way
        if not is_first_execution[0]:
            # This change should be detected as non-deterministic
            config = ConfigModel(
                name="test-config",
                value=42,
                nested={"key": "modified"},  # Change in nested field
            )

        return await failing_step_with_model(config)

    # Start and execute the workflow
    wf = await model_workflow.start()

    # First execution will fail but set up for retry
    await execute(wf)

    # Verify the workflow is in pending state with a failed step
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.PENDING

    # Find the step record
    s = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert s.status == StepStatus.FAILED
    assert s.retry_count == 0

    # Second execution should fail with NonDeterministicStepCallError
    # because the nested field was changed
    with pytest.raises(
        NonDeterministicStepCallError,
        match="Non-deterministic step call detected at step ID 1. Previous args",
    ) as excinfo:
        await execute(wf)

    # Verify error message contains information about the non-deterministic input
    err_msg = str(excinfo.value)
    assert "Non-deterministic step call detected" in err_msg
    assert "nested" in err_msg or "original" in err_msg or "modified" in err_msg
1443
+
1444
+
1445
async def test_non_deterministic_step_detection_kwargs(session: AsyncSession):
    """Replaying a failed step with a changed keyword argument (flag flips
    False -> True) must raise NonDeterministicStepCallError."""
    # Track whether we're in first or second execution attempt
    is_first_execution = [True]

    class ConfigModel(BaseModel):
        name: str
        value: int
        options: dict[str, bool]

    @step(max_retries=1)
    async def failing_step_with_kwargs(
        basic_value: int, *, config: ConfigModel, flag: bool = False
    ) -> str:
        # First execution will always fail
        if is_first_execution[0]:
            is_first_execution[0] = False
            raise RuntimeError("First attempt fails deliberately")

        # Return something (won't matter for the test)
        return f"Processed with {basic_value} and {config.name}"

    @workflow()
    async def kwargs_workflow() -> str:
        # First execution will use these values
        basic_value = 100
        config = ConfigModel(
            name="config-1", value=42, options={"debug": True, "verbose": False}
        )
        flag = False

        # On retry, we'll modify the kwargs in a non-deterministic way
        if not is_first_execution[0]:
            # This kwargs change should be detected as non-deterministic
            flag = True  # Changed from False to True

        return await failing_step_with_kwargs(basic_value, config=config, flag=flag)

    # Start and execute the workflow
    wf = await kwargs_workflow.start()

    # First execution will fail but set up for retry
    await execute(wf)

    # Verify the workflow is in pending state with a failed step
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.PENDING

    # Find the step record
    s = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert s.status == StepStatus.FAILED
    assert s.retry_count == 0

    # Second execution should fail with NonDeterministicStepCallError
    # because the flag kwarg was changed
    with pytest.raises(
        NonDeterministicStepCallError,
        match="Non-deterministic step call detected at step ID 1. Previous kwargs",
    ) as excinfo:
        await execute(wf)

    # Verify error message contains information about the non-deterministic input
    err_msg = str(excinfo.value)
    assert "Non-deterministic step call detected" in err_msg
    assert "flag" in err_msg
1514
+
1515
+
1516
async def test_non_deterministic_step_detection_function(session: AsyncSession):
    """Replaying with a different step function at the same step ID must
    raise NonDeterministicStepCallError."""
    # Track whether we're in first or second execution attempt
    is_first_execution = [True]

    @step(max_retries=1)
    async def first_step(value: int) -> int:
        # Flip the flag so the retry takes the other branch in the workflow.
        is_first_execution[0] = False
        raise RuntimeError("First step fails deliberately")

    @step()
    async def second_step(value: int) -> int:
        return value * 2

    @workflow()
    async def different_step_workflow() -> int:
        initial_value = 5

        # On first execution, call first_step
        if is_first_execution[0]:
            return await first_step(initial_value)
        else:
            # On retry, call a completely different step
            # This should be detected as non-deterministic
            return await second_step(initial_value)

    # Start and execute the workflow
    wf = await different_step_workflow.start()

    # First execution will fail but set up for retry
    await execute(wf)

    # Verify the workflow is in pending state with a failed step
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.PENDING

    # Find the step record
    s = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).one()
    assert s.status == StepStatus.FAILED
    assert s.retry_count == 0

    # Second execution should fail with NonDeterministicStepCallError
    # because we're calling a completely different step
    with pytest.raises(
        NonDeterministicStepCallError,
        match="Non-deterministic step call detected at step ID 1. Previous function name",
    ) as excinfo:
        await execute(wf)

    # Verify error message contains information about the non-deterministic function call
    err_msg = str(excinfo.value)
    assert "Non-deterministic step call detected" in err_msg
    assert "first_step" in err_msg and "second_step" in err_msg
1573
+
1574
+
1575
async def test_task_cancellation(session: AsyncSession):
    """Cancellation inside steps: a step that catches CancelledError
    succeeds with its fallback value; one that doesn't is recorded FAILED
    with a CancelledError, which the workflow itself can still catch."""

    @step()
    async def handled_cancellation_step():
        try:
            # Schedule our own cancellation, then block long enough for it.
            asyncio.create_task(canceller(asyncio.current_task()))
            await asyncio.sleep(10)
            return "completed"
        except asyncio.CancelledError:
            return "cancelled"

    @step()
    async def unhandled_cancellation_step():
        # Same setup, but the cancellation propagates out of the step.
        asyncio.create_task(canceller(asyncio.current_task()))
        await asyncio.sleep(10)
        return "completed2"

    @workflow()
    async def cancellation_workflow():
        result = await handled_cancellation_step()
        try:
            return await unhandled_cancellation_step()
        except asyncio.CancelledError:
            return f'first step result: "{result}". second step cancelled'

    async def canceller(task: asyncio.Task | None):
        assert task
        task.cancel()

    async with WorkflowOrchestrator.ensure_started() as orchestrator:
        wf = await cancellation_workflow.start()
        await orchestrator.wait_for_completion(wf.id)

    await session.refresh(wf)
    steps = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).all()
    assert len(steps) == 2
    assert steps[0].status == StepStatus.SUCCEEDED
    assert steps[0].result == "cancelled"
    assert steps[1].status == StepStatus.FAILED
    assert steps[1].error
    assert steps[1].error["type"] == "CancelledError"
    assert wf.status == WorkflowStatus.SUCCEEDED
    assert wf.result == 'first step result: "cancelled". second step cancelled'
1621
+
1622
+
1623
async def test_as_step_helper(session: AsyncSession):
    """Exercise as_step(): promoting a plain coroutine to a step, cache
    identity on repeated calls, pass-through for existing steps, and
    weak-reference cache cleanup after garbage collection."""
    # Collect garbage first so the cache holds only live entries.
    gc.collect()

    # Every cache-size assertion below is relative to this baseline.
    baseline = len(__AS_STEP_CACHE)

    # A plain coroutine function, not yet a step.
    async def regular_function(value: int) -> int:
        return value * 2

    assert not __is_workflow_step(regular_function)

    # Wrapping promotes it to a step...
    step_function = as_step(regular_function, step_type=StepType.COMPUTE)
    assert __is_workflow_step(step_function)

    # ...and a second wrap hits the cache, returning the identical object.
    assert as_step(regular_function, step_type=StepType.COMPUTE) is step_function

    @workflow()
    async def as_step_workflow(input_value: int) -> int:
        return await step_function(input_value)

    wf = await as_step_workflow.start(5)
    result = await execute(wf)

    # The workflow succeeded and computed the doubled value.
    updated_wf = await session.get(Workflow, wf.id)
    assert updated_wf
    assert updated_wf.status == WorkflowStatus.SUCCEEDED
    assert result == 10  # 5 * 2

    # Exactly one step row was recorded for the wrapped function.
    rows = (
        await session.exec(
            select(WorkflowStep).where(WorkflowStep.workflow_id == wf.id)
        )
    ).all()
    assert len(rows) == 1
    assert rows[0].status == StepStatus.SUCCEEDED
    assert rows[0].result == 10

    # A function that is already a step comes back unchanged.
    @step()
    async def already_step_function(value: int) -> int:
        return value + 10

    assert (
        as_step(already_step_function, step_type=StepType.COMPUTE)
        is already_step_function
    )

    @workflow()
    async def existing_step_workflow(input_value: int) -> int:
        return await already_step_function(input_value)

    wf2 = await existing_step_workflow.start(7)
    result2 = await execute(wf2)

    updated_wf2 = await session.get(Workflow, wf2.id)
    assert updated_wf2
    assert updated_wf2.status == WorkflowStatus.SUCCEEDED
    assert result2 == 17  # 7 + 10

    # Only one new cache entry so far: the wrapped regular_function
    # (already_step_function bypassed the cache).
    assert len(__AS_STEP_CACHE) == baseline + 1

    def make_transient():
        # Build a short-lived function and register it in the cache.
        async def transient(x: int) -> int:
            return x * 3

        as_step(transient, step_type=StepType.COMPUTE)
        return transient

    transient_fn = make_transient()

    # While referenced, its cache entry is present...
    assert len(__AS_STEP_CACHE) == baseline + 2

    # Drop the only strong reference.
    transient_fn = None
    assert transient_fn is None  # use the variable to keep the linter quiet

    gc.collect()

    # ...and after collection the WeakKeyDictionary entry is gone,
    # proving the cache does not leak memory.
    assert len(__AS_STEP_CACHE) == baseline + 1
1726
+
1727
+
1728
async def test_workflow_notifications(session: AsyncSession):
    """All workflow lifecycle notifications are delivered, in order, across
    a failing first execution and a successful retry."""
    observer = WorkflowObserver()
    exec_count = 0

    @step(max_retries=1)
    async def some_step():
        nonlocal exec_count
        if exec_count == 0:
            # Fail exactly once so the workflow suspends and is retried.
            exec_count += 1
            raise Exception("First execution")
        return "success"

    @workflow()
    async def notification_workflow():
        await some_step()
        return "done"

    async def wait_notifications(workflow_id: UUID):
        expected = (
            # First execution: the step fails and the workflow suspends.
            Notification.WORKFLOW_STARTED,
            Notification.WORKFLOW_RESUMED,
            Notification.STEP_RUNNING,
            Notification.STEP_FAILED,
            Notification.WORKFLOW_SUSPENDED,
            # Second execution: the retry and the workflow both succeed.
            Notification.WORKFLOW_RESUMED,
            Notification.STEP_RUNNING,
            Notification.STEP_SUCCEEDED,
            Notification.WORKFLOW_SUCCEEDED,
        )
        for event in expected:
            await observer.wait(event, workflow_id)

    async with workflow_notification_context(observer.on_workflow_notification):
        wf = await notification_workflow.start()
        wait_task = asyncio.create_task(wait_notifications(wf.id))
        await lock_and_execute(wf)  # execution 1: fails and suspends
        await lock_and_execute(wf)  # execution 2: succeeds

        # The watcher finishing proves every notification arrived in order.
        await wait_task
1771
+
1772
+
1773
+ # =============================================================================
1774
+ # Test for Step Hierarchy Implementation
1775
+ # =============================================================================
1776
async def test_step_hierarchy_implementation(session: AsyncSession):
    """Nested step calls record parent/child links, and the hierarchy
    helpers (parent, children, descendants, ancestors) traverse them."""

    @step()
    async def parent_step():
        return await child_step()

    @step()
    async def child_step():
        return await grandchild_step()

    @step()
    async def grandchild_step():
        return "done"

    @workflow()
    async def hierarchy_workflow():
        return await parent_step()

    wf = await hierarchy_workflow.start()
    await lock_and_execute(wf)

    rows = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .order_by(col(WorkflowStep.step_id))
        )
    ).all()

    # One row per nesting level, ordered by step_id.
    assert len(rows) == 3
    top, middle, bottom = rows

    # Short function names match the call chain.
    assert top.function_name.split(".")[-1] == "parent_step"
    assert middle.function_name.split(".")[-1] == "child_step"
    assert bottom.function_name.split(".")[-1] == "grandchild_step"

    # Ancestors always receive smaller step_ids than their descendants.
    for descendant_id in (middle.step_id, bottom.step_id):
        assert top.step_id < descendant_id

    # Parent pointers form the expected chain: top -> middle -> bottom.
    assert top.parent_step_id is None
    assert middle.parent_step_id == top.step_id
    assert bottom.parent_step_id == middle.step_id

    # 1. get_step_parent: the grandchild's parent is the child.
    parent = await get_step_parent(bottom)
    assert parent is not None
    assert parent.step_id == middle.step_id
    assert parent.function_name == middle.function_name

    # 2. get_step_children: the top step has exactly one child.
    children = await get_step_children(top)
    assert len(children) == 1
    assert children[0].step_id == middle.step_id

    # 3. get_step_descendants: the top step's subtree is child + grandchild.
    descendants = await get_step_descendants(top)
    assert len(descendants) == 2
    assert sorted(d.step_id for d in descendants) == [
        middle.step_id,
        bottom.step_id,
    ]

    # 4. get_step_ancestors: nearest ancestor first, then the grandparent.
    ancestors = await get_step_ancestors(bottom)
    assert len(ancestors) == 2
    assert ancestors[0].step_id == middle.step_id
    assert ancestors[1].step_id == top.step_id
1855
+
1856
+
1857
async def test_basic_step_parent_child(session: AsyncSession):
    """A step invoked from inside another step records that caller as its
    parent; the outermost step has no parent."""

    @step()
    async def parent():
        return await child()

    @step()
    async def child():
        return "done"

    @workflow()
    async def parent_child_workflow():
        return await parent()

    wf = await parent_child_workflow.start()
    await lock_and_execute(wf)

    rows = (
        await session.exec(
            select(WorkflowStep)
            .where(WorkflowStep.workflow_id == wf.id)
            .order_by(col(WorkflowStep.step_id))
        )
    ).all()

    assert len(rows) == 2
    outer, inner = rows

    # The outer step runs first, so it is assigned the smaller step_id.
    assert outer.step_id < inner.step_id

    # The inner step points back at the outer one; the outer is a root.
    assert inner.parent_step_id == outer.step_id
    assert outer.parent_step_id is None
1892
+
1893
+
1894
async def test_child_workflow_called_as_function_has_parent_id(session: AsyncSession):
    """Awaiting a workflow inline from another workflow links the child to
    the caller via parent_id."""

    @workflow()
    async def child_workflow():
        return "child_result"

    @workflow()
    async def parent_workflow():
        # Awaiting the child as a plain coroutine should record parent_id.
        inner = await child_workflow()
        return f"parent got: {inner}"

    parent_wf = await parent_workflow.start()

    async with WorkflowOrchestrator.ensure_started() as orchestrator:
        await orchestrator.wait_for_completion(parent_wf.id)

    # The parent finished and saw the child's result.
    await session.refresh(parent_wf)
    assert parent_wf.status == WorkflowStatus.SUCCEEDED
    assert parent_wf.result == "parent got: child_result"

    # Exactly two workflow rows exist; the non-parent one is the child.
    workflows = (
        await session.exec(select(Workflow).order_by(col(Workflow.created_at)))
    ).all()
    assert len(workflows) == 2
    child_wf = next(w for w in workflows if w.id != parent_wf.id)

    # The inline call recorded the parent relationship.
    assert child_wf.parent_id == parent_wf.id
    assert child_wf.status == WorkflowStatus.SUCCEEDED
    assert child_wf.result == "child_result"
1928
+
1929
+
1930
async def test_child_workflow_called_as_start_step(session: AsyncSession):
    """Launching a child via start_step() detaches it from the caller:
    the child runs independently and gets no parent_id."""
    child_workflow_id = None

    @workflow()
    async def child_workflow():
        return "child_result"

    @workflow()
    async def parent_workflow():
        nonlocal child_workflow_id
        # start_step only enqueues the child; no parent link is recorded.
        child_workflow_id = await child_workflow.start_step()
        return f"started child: {child_workflow_id}"

    parent_wf = await parent_workflow.start()

    async with WorkflowOrchestrator.ensure_started() as orchestrator:
        await orchestrator.wait_for_completion(parent_wf.id)
        # The parent captured the child's id before finishing.
        assert child_workflow_id
        await orchestrator.wait_for_completion(child_workflow_id)

    await session.refresh(parent_wf)
    assert parent_wf.status == WorkflowStatus.SUCCEEDED

    # Exactly two workflow rows exist; the non-parent one is the child.
    workflows = (
        await session.exec(select(Workflow).order_by(col(Workflow.created_at)))
    ).all()
    assert len(workflows) == 2
    child_wf = next(w for w in workflows if w.id != parent_wf.id)

    # Detached launch: the child carries no parent_id but still succeeded.
    assert child_wf.parent_id is None
    assert child_wf.status == WorkflowStatus.SUCCEEDED
    assert child_wf.result == "child_result"