planar 0.9.3__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. planar/ai/agent.py +2 -1
  2. planar/ai/agent_base.py +24 -5
  3. planar/ai/state.py +17 -0
  4. planar/app.py +18 -1
  5. planar/data/connection.py +108 -0
  6. planar/data/dataset.py +11 -104
  7. planar/data/utils.py +89 -0
  8. planar/db/alembic/env.py +25 -1
  9. planar/files/storage/azure_blob.py +1 -1
  10. planar/registry_items.py +2 -0
  11. planar/routers/dataset_router.py +213 -0
  12. planar/routers/info.py +79 -36
  13. planar/routers/models.py +1 -0
  14. planar/routers/workflow.py +2 -0
  15. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  16. planar/security/authorization.py +31 -3
  17. planar/security/default_policies.cedar +25 -0
  18. planar/testing/fixtures.py +34 -1
  19. planar/testing/planar_test_client.py +1 -1
  20. planar/workflows/decorators.py +2 -1
  21. planar/workflows/wrappers.py +1 -0
  22. {planar-0.9.3.dist-info → planar-0.11.0.dist-info}/METADATA +9 -1
  23. {planar-0.9.3.dist-info → planar-0.11.0.dist-info}/RECORD +25 -72
  24. {planar-0.9.3.dist-info → planar-0.11.0.dist-info}/WHEEL +1 -1
  25. planar/ai/test_agent_serialization.py +0 -229
  26. planar/ai/test_agent_tool_step_display.py +0 -78
  27. planar/data/test_dataset.py +0 -354
  28. planar/files/storage/test_azure_blob.py +0 -435
  29. planar/files/storage/test_local_directory.py +0 -162
  30. planar/files/storage/test_s3.py +0 -299
  31. planar/files/test_files.py +0 -282
  32. planar/human/test_human.py +0 -385
  33. planar/logging/test_formatter.py +0 -327
  34. planar/modeling/mixins/test_auditable.py +0 -97
  35. planar/modeling/mixins/test_timestamp.py +0 -134
  36. planar/modeling/mixins/test_uuid_primary_key.py +0 -52
  37. planar/routers/test_agents_router.py +0 -174
  38. planar/routers/test_files_router.py +0 -49
  39. planar/routers/test_object_config_router.py +0 -367
  40. planar/routers/test_routes_security.py +0 -168
  41. planar/routers/test_rule_router.py +0 -470
  42. planar/routers/test_workflow_router.py +0 -539
  43. planar/rules/test_data/account_dormancy_management.json +0 -223
  44. planar/rules/test_data/airline_loyalty_points_calculator.json +0 -262
  45. planar/rules/test_data/applicant_risk_assessment.json +0 -435
  46. planar/rules/test_data/booking_fraud_detection.json +0 -407
  47. planar/rules/test_data/cellular_data_rollover_system.json +0 -258
  48. planar/rules/test_data/clinical_trial_eligibility_screener.json +0 -437
  49. planar/rules/test_data/customer_lifetime_value.json +0 -143
  50. planar/rules/test_data/import_duties_calculator.json +0 -289
  51. planar/rules/test_data/insurance_prior_authorization.json +0 -443
  52. planar/rules/test_data/online_check_in_eligibility_system.json +0 -254
  53. planar/rules/test_data/order_consolidation_system.json +0 -375
  54. planar/rules/test_data/portfolio_risk_monitor.json +0 -471
  55. planar/rules/test_data/supply_chain_risk.json +0 -253
  56. planar/rules/test_data/warehouse_cross_docking.json +0 -237
  57. planar/rules/test_rules.py +0 -1494
  58. planar/security/tests/test_auth_middleware.py +0 -162
  59. planar/security/tests/test_authorization_context.py +0 -78
  60. planar/security/tests/test_cedar_basics.py +0 -41
  61. planar/security/tests/test_cedar_policies.py +0 -158
  62. planar/security/tests/test_jwt_principal_context.py +0 -179
  63. planar/test_app.py +0 -142
  64. planar/test_cli.py +0 -394
  65. planar/test_config.py +0 -515
  66. planar/test_object_config.py +0 -527
  67. planar/test_object_registry.py +0 -14
  68. planar/test_sqlalchemy.py +0 -193
  69. planar/test_utils.py +0 -105
  70. planar/testing/test_memory_storage.py +0 -143
  71. planar/workflows/test_concurrency_detection.py +0 -120
  72. planar/workflows/test_lock_timeout.py +0 -140
  73. planar/workflows/test_serialization.py +0 -1203
  74. planar/workflows/test_suspend_deserialization.py +0 -231
  75. planar/workflows/test_workflow.py +0 -2005
  76. {planar-0.9.3.dist-info → planar-0.11.0.dist-info}/entry_points.txt +0 -0
planar/test_sqlalchemy.py DELETED
@@ -1,193 +0,0 @@
-from uuid import uuid4
-
-import pytest
-from sqlalchemy.exc import DBAPIError
-from sqlalchemy.ext.asyncio import AsyncEngine
-from sqlmodel import col, insert, select, text
-
-from planar.db import PlanarSession, new_session
-from planar.modeling.orm.planar_base_entity import PlanarBaseEntity
-
-
-class SomeModel(PlanarBaseEntity, table=True):
-    name: str
-    value: int = 0
-
-
-async def test_run_transaction_success(tmp_db_engine):
-    uuid = uuid4()
-    uuid2 = uuid4()
-
-    async def transaction_func(session: PlanarSession):
-        await session.exec(
-            insert(SomeModel).values(id=uuid, name="test_item", value=42)  # type: ignore
-        )
-        await session.exec(
-            insert(SomeModel).values(id=uuid2, name="test_item2", value=42)  # type: ignore
-        )
-
-    async with new_session(tmp_db_engine) as session:
-        session.max_conflict_retries = 3
-        await session.run_transaction(transaction_func, session)
-
-    async with new_session(tmp_db_engine) as session:
-        items = (
-            await session.exec(select(SomeModel).order_by(col(SomeModel.name)))
-        ).all()
-        assert items == [
-            SomeModel(id=uuid, name="test_item", value=42),
-            SomeModel(id=uuid2, name="test_item2", value=42),
-        ]
-
-
-async def test_run_transaction_failure(tmp_db_engine):
-    async def transaction_func(session: PlanarSession):
-        await session.exec(insert(SomeModel).values(name="test_item", value=42))  # type: ignore
-        raise ValueError("Test error")
-        await session.exec(insert(SomeModel).values(name="test_item2", value=42))  # type: ignore
-
-    async with new_session(tmp_db_engine) as session:
-        with pytest.raises(ValueError, match="Test error"):
-            session.max_conflict_retries = 3
-            await session.run_transaction(transaction_func, session)
-
-    async with new_session(tmp_db_engine) as session:
-        items = (await session.exec(select(SomeModel))).all()
-        assert items == []
-
-
-async def test_run_transaction_concurrent_retry_success(tmp_db_engine):
-    attempts = 0
-    uuid = uuid4()
-    uuid2 = uuid4()
-
-    async def transaction_func(session: PlanarSession):
-        nonlocal attempts
-        await session.exec(
-            insert(SomeModel).values(id=uuid, name="test_item", value=42)  # type: ignore
-        )
-        if attempts == 0:
-            attempts += 1
-            raise DBAPIError(
-                "Test error", None, Exception("could not serialize access")
-            )
-        await session.exec(
-            insert(SomeModel).values(id=uuid2, name="test_item2", value=42)  # type: ignore
-        )
-
-    async with new_session(tmp_db_engine) as session:
-        session.max_conflict_retries = 1
-        await session.run_transaction(transaction_func, session)
-
-    async with new_session(tmp_db_engine) as session:
-        items = (
-            await session.exec(select(SomeModel).order_by(col(SomeModel.name)))
-        ).all()
-        assert items == [
-            SomeModel(id=uuid, name="test_item", value=42),
-            SomeModel(id=uuid2, name="test_item2", value=42),
-        ]
-
-
-async def test_run_transaction_concurrent_retry_failure(tmp_db_engine):
-    attempts = 0
-
-    async def transaction_func(session: PlanarSession):
-        nonlocal attempts
-        await session.exec(insert(SomeModel).values(name="test_item", value=42))  # type: ignore
-        if attempts < 2:
-            attempts += 1
-            raise DBAPIError(
-                "Test error", None, Exception("could not serialize access")
-            )
-        await session.exec(insert(SomeModel).values(name="test_item2", value=42))  # type: ignore
-
-    async with new_session(tmp_db_engine) as session:
-        with pytest.raises(DBAPIError, match="Test error"):
-            session.max_conflict_retries = 1
-            await session.run_transaction(transaction_func, session)
-
-    async with new_session(tmp_db_engine) as session:
-        items = (await session.exec(select(SomeModel))).all()
-        assert items == []
-
-
-async def test_serializable_transaction_failure_1(tmp_db_engine: AsyncEngine):
-    if tmp_db_engine.dialect.name != "postgresql":
-        return pytest.skip("Test requires PostgreSQL database")
-
-    async with new_session(tmp_db_engine) as setup_session:
-        # Setup: Insert initial data
-        async with setup_session.begin():
-            setup_session.add(SomeModel(id=uuid4(), name="initial", value=10))
-
-    # Create two sessions
-    async with (
-        new_session(tmp_db_engine) as session1,
-        new_session(tmp_db_engine) as session2,
-    ):
-        # Begin transactions in both sessions
-        await session1.begin()
-        await session2.begin()
-
-        # Set serializable isolation level
-        await session1.set_serializable_isolation()
-        await session2.set_serializable_isolation()
-
-        # Session 1: Read data
-        item1 = (
-            await session1.exec(select(SomeModel).where(SomeModel.name == "initial"))
-        ).one()
-        assert item1.value == 10
-
-        # Session 2: Read the same data
-        item2 = (
-            await session2.exec(select(SomeModel).where(SomeModel.name == "initial"))
-        ).one()
-        assert item2.value == 10
-
-        # Both sessions update the same row
-        item1.value += 5
-        item2.value += 3
-
-        # Session 1: Commit should succeed
-        await session1.commit()
-
-        # Session 2: Commit should fail with serialization error
-        with pytest.raises(DBAPIError, match="could not serialize access"):
-            await session2.commit()
-
-
-async def test_entity_schema_and_planar_schema_presence(tmp_db_engine: AsyncEngine):
-    table_name = SomeModel.__tablename__
-
-    async with new_session(tmp_db_engine) as session:
-        dialect = session.dialect.name
-
-        if dialect == "postgresql":
-            # Verify schemas include 'planar' and the default entity schema 'planar_entity'
-            res = await session.exec(
-                text("select schema_name from information_schema.schemata")  # type: ignore[arg-type]
-            )
-            schemas = {row[0] for row in res}
-            assert "planar" in schemas
-            assert "planar_entity" in schemas
-
-            # Verify SomeModel table is created in the entity schema
-            res = await session.exec(
-                text(
-                    "select table_schema from information_schema.tables where table_name = :tn"
-                ).bindparams(tn=table_name)  # type: ignore[arg-type]
-            )
-            table_schemas = {row[0] for row in res}
-            assert "planar_entity" in table_schemas
-            assert "public" not in table_schemas
-
-        else:
-            # SQLite: no schemas; ensure table exists
-            res = await session.exec(
-                text("select name from sqlite_master where type='table'")  # type: ignore[arg-type]
-            )
-            tables = {row[0] for row in res}
-            assert table_name in tables
-            assert not any(name.startswith("planar.") for name in tables)
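The deleted tests above pin down a retry-on-serialization-conflict contract for `run_transaction`. A minimal sketch of that pattern, written against plain async SQLAlchemy sessions; the helper name and conflict check below are illustrative assumptions, not planar's actual implementation:

```python
from sqlalchemy.exc import DBAPIError


# Hypothetical helper illustrating the retry-on-conflict behavior the
# deleted tests exercise; planar's real run_transaction may differ.
async def run_transaction_with_retries(session, fn, max_conflict_retries=3):
    for attempt in range(max_conflict_retries + 1):
        try:
            async with session.begin():
                # fn does its work inside one transaction: returning
                # commits, raising rolls back.
                return await fn(session)
        except DBAPIError as exc:
            # Retry only genuine serialization conflicts, and only
            # while retries remain; everything else propagates.
            if "could not serialize access" not in str(exc.orig):
                raise
            if attempt == max_conflict_retries:
                raise
```

This matches the behavior asserted above: a conflict on the first attempt is retried transparently, while repeated conflicts (or non-conflict errors) surface to the caller with nothing committed.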
planar/test_utils.py DELETED
@@ -1,105 +0,0 @@
-import asyncio
-import time
-from datetime import UTC, datetime
-
-import pytest
-
-from planar.utils import asyncify, utc_now
-
-
-async def test_asyncify_converts_sync_to_async():
-    """Test that asyncify correctly converts a synchronous function to an asynchronous one."""
-
-    def sync_function(x, y):
-        return x + y
-
-    async_function = asyncify(sync_function)
-
-    # Check that the function is now a coroutine function
-    assert asyncio.iscoroutinefunction(async_function)
-    assert not asyncio.iscoroutinefunction(sync_function)
-
-    # Check that it can be awaited
-    result = await async_function(5, 3)
-    assert result == 8
-
-
-async def test_asyncify_with_args_and_kwargs():
-    """Test that asyncify correctly passes positional and keyword arguments."""
-
-    def complex_function(a, b, c=0, d=0):
-        return a + b + c + d
-
-    async_function = asyncify(complex_function)
-
-    # Test with positional args only
-    result1 = await async_function(1, 2)
-    assert result1 == 3
-
-    # Test with positional and keyword args
-    result2 = await async_function(1, 2, c=3, d=4)
-    assert result2 == 10
-
-
-async def test_asyncify_preserves_exceptions():
-    """Test that asyncify preserves exceptions raised by the wrapped function."""
-
-    def failing_function():
-        raise ValueError("Expected error")
-
-    async_function = asyncify(failing_function)
-
-    with pytest.raises(ValueError, match="Expected error"):
-        await async_function()
-
-
-async def test_asyncify_non_blocking():
-    """Test that asyncify runs the function in a way that doesn't block the event loop."""
-    # This counter will be incremented by a task running concurrently with our slow function
-    counter = 0
-
-    @asyncify
-    def slow_function():
-        time.sleep(0.5)  # This would block the event loop if not run in executor
-        return counter
-
-    # This task will increment the counter while the slow function is running
-    async def increment_counter():
-        nonlocal counter
-        await asyncio.sleep(0.1)  # Short sleep to allow the slow function to start
-        for _ in range(10):
-            counter += 1
-            await asyncio.sleep(0.01)  # Short sleep to yield control
-
-    # Create increment task
-    task = asyncio.create_task(increment_counter())
-
-    # Run the async function
-    assert counter == 0
-    result = await slow_function()
-    # If the event loop was blocked, the counter would be 0
-    assert counter == 10
-    assert result == 10
-
-    await task
-
-
-def test_raises_when_applied_to_async_function():
-    """Test that asyncify raises an error when applied to an async function."""
-
-    async def async_function():
-        pass
-
-    with pytest.raises(ValueError, match="Function is already async"):
-        asyncify(async_function)
-
-
-def test_utc_now_returns_naive_utc():
-    """utc_now should return a naive datetime captured within two timestamps."""
-
-    before = datetime.now(UTC).replace(tzinfo=None)
-    result = utc_now()
-    after = datetime.now(UTC).replace(tzinfo=None)
-
-    assert result.tzinfo is None
-    assert before <= result <= after
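The `asyncify` contract those tests pin down (wrap a sync function so it runs off the event loop, reject functions that are already coroutines) can be sketched in a few lines. This is an illustrative reimplementation of the tested behavior, not planar's code:

```python
import asyncio
import functools


def asyncify(fn):
    """Illustrative sketch of the contract tested above; not planar's code."""
    # The deleted test expects a ValueError for already-async functions.
    if asyncio.iscoroutinefunction(fn):
        raise ValueError("Function is already async")

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        # Run the sync function in a worker thread so the event loop
        # stays free, matching the non-blocking test above.
        return await asyncio.to_thread(fn, *args, **kwargs)

    return wrapper
```

With this shape, `await asyncify(time.sleep)(0.5)` yields to other tasks for the full half second, which is exactly what `test_asyncify_non_blocking` verifies with its counter task.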
planar/testing/test_memory_storage.py DELETED
@@ -1,143 +0,0 @@
-import asyncio
-import uuid
-
-import pytest
-
-from planar.testing.memory_storage import MemoryStorage
-
-
-@pytest.fixture
-async def storage() -> MemoryStorage:
-    """Provides an instance of MemoryStorage."""
-    return MemoryStorage()
-
-
-async def test_put_get_bytes(storage: MemoryStorage):
-    """Test storing and retrieving raw bytes."""
-    test_data = b"some binary data \x00\xff for memory"
-    mime_type = "application/octet-stream"
-
-    ref = await storage.put_bytes(test_data, mime_type=mime_type)
-    assert isinstance(ref, str)
-    try:
-        uuid.UUID(ref)  # Check if ref is a valid UUID string
-    except ValueError:
-        pytest.fail(f"Returned ref '{ref}' is not a valid UUID string")
-
-    retrieved_data, retrieved_mime = await storage.get_bytes(ref)
-
-    assert retrieved_data == test_data
-    assert retrieved_mime == mime_type
-
-    # Check internal state (optional)
-    assert ref in storage._blobs
-    assert ref in storage._mime_types
-    assert storage._blobs[ref] == test_data
-    assert storage._mime_types[ref] == mime_type
-
-
-async def test_put_get_string(storage: MemoryStorage):
-    """Test storing and retrieving a string."""
-    test_string = "Hello, memory! This is a test string with Unicode: éàçü."
-    mime_type = "text/plain"
-    encoding = "utf-16"
-
-    # Store with explicit encoding and mime type
-    ref = await storage.put_string(test_string, encoding=encoding, mime_type=mime_type)
-    expected_mime_type = f"{mime_type}; charset={encoding}"
-
-    retrieved_string, retrieved_mime = await storage.get_string(ref, encoding=encoding)
-
-    assert retrieved_string == test_string
-    assert retrieved_mime == expected_mime_type
-
-    # Test default encoding (utf-8)
-    ref_utf8 = await storage.put_string(test_string, mime_type="text/html")
-    expected_mime_utf8 = "text/html; charset=utf-8"
-    retrieved_string_utf8, retrieved_mime_utf8 = await storage.get_string(ref_utf8)
-    assert retrieved_string_utf8 == test_string
-    assert retrieved_mime_utf8 == expected_mime_utf8
-
-
-async def test_put_get_stream(storage: MemoryStorage):
-    """Test storing data from an async generator stream."""
-    test_chunks = [b"mem_chunk1 ", b"mem_chunk2 ", b"mem_chunk3"]
-    full_data = b"".join(test_chunks)
-    mime_type = "image/gif"
-
-    async def _test_stream():
-        for chunk in test_chunks:
-            yield chunk
-            await asyncio.sleep(0.01)  # Simulate async work
-
-    ref = await storage.put(_test_stream(), mime_type=mime_type)
-
-    stream, retrieved_mime = await storage.get(ref)
-    retrieved_data = b""
-    async for chunk in stream:
-        retrieved_data += chunk
-
-    assert retrieved_data == full_data
-    assert retrieved_mime == mime_type
-
-
-async def test_put_no_mime_type(storage: MemoryStorage):
-    """Test storing data without providing a mime type."""
-    test_data = b"memory data without mime"
-
-    ref = await storage.put_bytes(test_data)
-    retrieved_data, retrieved_mime = await storage.get_bytes(ref)
-
-    assert retrieved_data == test_data
-    assert retrieved_mime is None
-
-    # Check internal state
-    assert ref in storage._blobs
-    assert ref not in storage._mime_types
-
-
-async def test_delete(storage: MemoryStorage):
-    """Test deleting stored data."""
-    ref = await storage.put_bytes(b"to be deleted from memory", mime_type="text/plain")
-
-    # Verify data exists before delete (optional)
-    assert ref in storage._blobs
-    assert ref in storage._mime_types
-
-    await storage.delete(ref)
-
-    # Verify data is gone after delete
-    assert ref not in storage._blobs
-    assert ref not in storage._mime_types
-
-    # Try getting deleted ref
-    with pytest.raises(FileNotFoundError):
-        await storage.get(ref)
-
-
-async def test_get_non_existent(storage: MemoryStorage):
-    """Test getting a reference that does not exist."""
-    non_existent_ref = str(uuid.uuid4())
-    with pytest.raises(FileNotFoundError):
-        await storage.get(non_existent_ref)
-
-
-async def test_delete_non_existent(storage: MemoryStorage):
-    """Test deleting a reference that does not exist (should not raise error)."""
-    non_existent_ref = str(uuid.uuid4())
-    initial_blob_count = len(storage._blobs)
-    initial_mime_count = len(storage._mime_types)
-    try:
-        await storage.delete(non_existent_ref)
-        # Ensure no data was actually deleted
-        assert len(storage._blobs) == initial_blob_count
-        assert len(storage._mime_types) == initial_mime_count
-    except Exception as e:
-        pytest.fail(f"Deleting non-existent ref raised an exception: {e}")
-
-
-async def test_external_url(storage: MemoryStorage):
-    """Test that external_url returns None for memory storage."""
-    ref = await storage.put_bytes(b"some data for url test")
-    url = await storage.external_url(ref)
-    assert url is None
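The storage contract those tests rely on (UUID string refs, `(data, mime_type)` tuples, idempotent deletes, `FileNotFoundError` for unknown refs) is small enough to sketch. The class below is a hypothetical stand-in for illustration, not planar's `MemoryStorage`:

```python
import uuid


class InMemoryBlobStore:
    """Hypothetical stand-in showing the contract the deleted tests check."""

    def __init__(self):
        self._blobs: dict[str, bytes] = {}
        self._mime_types: dict[str, str] = {}

    async def put_bytes(self, data: bytes, mime_type: str | None = None) -> str:
        # Refs are UUID strings, as asserted in test_put_get_bytes.
        ref = str(uuid.uuid4())
        self._blobs[ref] = data
        if mime_type is not None:
            self._mime_types[ref] = mime_type
        return ref

    async def get_bytes(self, ref: str) -> tuple[bytes, str | None]:
        if ref not in self._blobs:
            raise FileNotFoundError(ref)
        return self._blobs[ref], self._mime_types.get(ref)

    async def delete(self, ref: str) -> None:
        # Deleting an unknown ref is a no-op, per test_delete_non_existent.
        self._blobs.pop(ref, None)
        self._mime_types.pop(ref, None)

    async def external_url(self, ref: str) -> None:
        # Memory-backed storage has no externally reachable URL.
        return None
```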
planar/workflows/test_concurrency_detection.py DELETED
@@ -1,120 +0,0 @@
-import asyncio
-import multiprocessing
-import multiprocessing.connection
-from multiprocessing.connection import Connection
-from uuid import UUID
-
-from planar.db import DatabaseManager, new_session
-from planar.session import engine_var, get_engine, session_var
-from planar.workflows.decorators import step, workflow
-from planar.workflows.exceptions import LockResourceFailed
-from planar.workflows.execution import (
-    _DEFAULT_LOCK_DURATION,
-    execute,
-)
-from planar.workflows.lock import lock_workflow
-from planar.workflows.models import Workflow, WorkflowStatus
-
-# bidirectional communication between the test process and the worker processes.
-conn: Connection
-
-
-@step(max_retries=0)
-async def dummy_step():
-    conn.send("waiting")
-    # Wait until "proceed" is received over the pipe.
-    if conn.recv() != "proceed":
-        raise Exception('Expected "proceed"')
-    return "success"
-
-
-@workflow()
-async def dummy_workflow():
-    # Run the dummy step and return its result.
-    result = await dummy_step()
-    return result
-
-
-# Copy of the resume_workflow function which allows more fine-grained control from
-# the test process. This is fine because our goal is to test concurrency detection
-# implemented by the execute function.
-async def resume_with_semaphores(workflow_id: UUID):
-    engine = get_engine()
-    async with new_session(engine) as session:
-        tok = session_var.set(session)
-        try:
-            async with session.begin():
-                workflow = await session.get(Workflow, workflow_id)
-                if not workflow:
-                    raise ValueError(f"Workflow {workflow_id} not found")
-            conn.send("ready")
-            # Wait until "start" is received over the pipe.
-            if conn.recv() != "start":
-                raise Exception('Expected "start"')
-            async with lock_workflow(
-                workflow,
-                _DEFAULT_LOCK_DURATION,
-                retry_count=0,
-            ):
-                await execute(workflow)
-            conn.send("completed")
-        except LockResourceFailed:
-            conn.send("conflict")
-        finally:
-            session_var.reset(tok)
-
-
-# This worker function will be launched as a separate process.
-# It takes the workflow id, db_url and a multiprocess Pipe.
-def worker(wf_id: UUID, db_url: str, connection: Connection):
-    global conn
-    conn = connection
-    # Create a new engine for this process.
-    db_manager = DatabaseManager(db_url)
-    db_manager.connect()
-    engine = db_manager.get_engine()
-    engine_var.set(engine)
-    # Run the resume_with_semaphores coroutine.
-    # We use asyncio.run so that the worker’s event loop is independent.
-    asyncio.run(resume_with_semaphores(wf_id))
-
-
-async def test_concurrent_workflow_execution(tmp_db_url, tmp_db_engine):
-    async with new_session(tmp_db_engine) as session:
-        session_var.set(session)
-        wf: Workflow = await dummy_workflow.start()
-        wf_id = wf.id
-
-        # Launch two separate processes that attempt to resume the workflow concurrently.
-        p1_parent, p1_worker = multiprocessing.Pipe(duplex=True)
-        p2_parent, p2_worker = multiprocessing.Pipe(duplex=True)
-        p1 = multiprocessing.Process(target=worker, args=(wf_id, tmp_db_url, p1_worker))
-        p2 = multiprocessing.Process(target=worker, args=(wf_id, tmp_db_url, p2_worker))
-        p1.start()
-        p2.start()
-        # wait for both workers to fetch the workflow from the database.
-        assert p1_parent.recv() == "ready"
-        assert p2_parent.recv() == "ready"
-        # allow worker 1 to proceed.
-        p1_parent.send("start")
-        # wait for worker 1 to start the workflow and pause in the dummy step.
-        assert p1_parent.recv() == "waiting"
-        # allow worker 2 to proceed.
-        p2_parent.send("start")
-        # worker 2 should fail and will send a "conflict" message.
-        assert p2_parent.recv() == "conflict"
-        # allow worker 1 to proceed
-        p1_parent.send("proceed")
-        # worker 1 should complete the workflow and send a "completed" message.
-        assert p1_parent.recv() == "completed"
-        # cleanup workers
-        p1.join()
-        p2.join()
-
-        await session.refresh(wf)
-        assert wf, f"Workflow {wf_id} not found"
-        # Assert that the workflow completed successfully.
-        assert wf.status == WorkflowStatus.SUCCEEDED, (
-            f"Unexpected workflow status: {wf.status}"
-        )
-        assert wf.result == "success", f"Unexpected workflow result: {wf.result}"
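The behavior this test coordinates across two processes (the first worker acquires the workflow lock, the second immediately gets `LockResourceFailed` because `retry_count=0`) amounts to a compare-and-set on a lease record. A toy in-memory analogue with made-up names, shown only to illustrate the idea; planar does this atomically against the database:

```python
from datetime import datetime, timedelta, timezone


class LockResourceFailed(Exception):
    """Raised when the lease is already held (mirrors the import above)."""


class LeaseTable:
    """Toy analogue of the LockedResource table; not planar's implementation."""

    def __init__(self):
        self._leases: dict[str, datetime] = {}

    def try_acquire(self, key: str, duration: timedelta) -> None:
        now = datetime.now(timezone.utc)
        held_until = self._leases.get(key)
        if held_until is not None and held_until > now:
            # A live lease exists: the second worker lands here.
            raise LockResourceFailed(key)
        self._leases[key] = now + duration

    def release(self, key: str) -> None:
        self._leases.pop(key, None)
```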
planar/workflows/test_lock_timeout.py DELETED
@@ -1,140 +0,0 @@
-import asyncio
-from datetime import timedelta
-
-from planar.db import new_session
-from planar.session import session_var
-from planar.testing.synchronizable_tracer import SynchronizableTracer, TraceSpec
-from planar.utils import utc_now
-from planar.workflows.decorators import workflow
-from planar.workflows.execution import execute
-from planar.workflows.models import (
-    LockedResource,
-    Workflow,
-    WorkflowStatus,
-    workflow_exec_lock_key,
-)
-from planar.workflows.orchestrator import WorkflowOrchestrator
-from planar.workflows.step_core import Suspend, suspend
-from planar.workflows.tracing import tracer_var
-
-
-# Define a long-running workflow.
-@workflow()
-async def long_running_workflow():
-    # Simulate a long-running operation by sleeping 1 second.
-    await asyncio.sleep(1)
-    return "finished"
-
-
-async def test_lock_timer_extension(tmp_db_engine):
-    tracer = SynchronizableTracer()
-    tracer_var.set(tracer)
-    lock_acquired = tracer.instrument(
-        TraceSpec(function_name="lock_resource", message="commit")
-    )
-    lock_heartbeat = tracer.instrument(
-        TraceSpec(function_name="lock_heartbeat", message="commit")
-    )
-
-    async with new_session(tmp_db_engine) as session:
-        # This test verifies that when a workflow is executing, the heartbeat task
-        # (lock_heartbeat) extends the workflow's lock_until field. We run a
-        # long-running workflow (which sleeps for 1 second) with a short lock
-        # duration and heartbeat interval. While the workflow is running we query
-        # the stored workflow record and ensure that lock_until is updated
-        # (extended) by the heartbeat.
-
-        session_var.set(session)
-        # Start the workflow.
-        # Run workflow execution in the background with short durations so
-        # heartbeat kicks in quickly.
-        async with WorkflowOrchestrator.ensure_started(
-            lock_duration=timedelta(seconds=1)
-        ) as orchestrator:
-            wf: Workflow = await long_running_workflow.start()
-            wf_id = wf.id
-            lock_key = workflow_exec_lock_key(wf_id)
-
-            await lock_acquired.wait()
-
-            async with session.begin():
-                locked_resource = await session.get(LockedResource, lock_key)
-                assert locked_resource, "Expected a locked resource record"
-                lock_time_1 = locked_resource.lock_until
-                assert lock_time_1, "Lock time should be set"
-
-            # Wait a bit longer to allow another heartbeat cycle.
-            await lock_heartbeat.wait()
-            async with session.begin():
-                await session.refresh(locked_resource)
-                lock_time_2 = locked_resource.lock_until
-                assert lock_time_2, "Lock time should be set"
-
-            # The lock_time_2 should be later than lock_time_1 if the heartbeat is working.
-            assert lock_time_2 > lock_time_1, (
-                f"Expected lock_until to be extended by heartbeat: {lock_time_1} vs {lock_time_2}"
-            )
-
-            # Let the workflow finish.
-            await orchestrator.wait_for_completion(wf_id)
-
-        # Verify the workflow completed successfully.
-        await session.refresh(wf)
-        assert wf.status == WorkflowStatus.SUCCEEDED
-        assert wf.result == "finished"
-
-
-@workflow()
-async def crashed_worker_workflow():
-    # This workflow uses suspend() to simulate work that is paused.
-    # The first execution returns a Suspend object.
-    # When resumed it completes and returns "completed".
-    # First step: suspend (simulate waiting, e.g. because a worker had locked it).
-    await suspend(interval=timedelta(seconds=5))
-    # After the suspension it resumes here.
-    return "completed"
-
-
-async def test_orchestrator_resumes_crashed_worker(tmp_db_engine):
-    # This test simulates the scenario where a worker has “crashed” after
-    # locking a workflow. We start a workflow that suspends. Then we add a
-    # LockedResource record with an expired lock_until time to simulate a
-    # crashed worker.
-
-    # Invoking the workflow_orchestrator (which polls for suspended workflows
-    # whose wakeup time is reached or that have expired locks) should cause
-    # the workflow to be resumed. Finally, we verify that the workflow
-    # completes successfully. Start the workflow – its first execution will
-    # suspend.
-    async with new_session(tmp_db_engine) as session:
-        session_var.set(session)
-        wf = await crashed_worker_workflow.start()
-
-        result = await execute(wf)
-        assert isinstance(result, Suspend), (
-            "Expected the workflow to suspend on first execution."
-        )
-        # Simulate a crashed worker by directly changing the workflow record.
-        await session.refresh(wf)
-        # Force wakeup_at and lock_until to be in the past.
-        past_time = utc_now() - timedelta(seconds=1)
-        wf.wakeup_at = past_time
-        session.add(LockedResource(lock_key=f"workflow:{wf.id}", lock_until=past_time))
-        # Ensure it is marked as running, which would not normally be picked by
-        # the orchestrator
-        await session.commit()
-
-        # Now run the orchestrator, which polls for suspended workflows with
-        # wakeup_at <= now.
-        # We use a short poll interval.
-        async with WorkflowOrchestrator.ensure_started(
-            poll_interval=0.2
-        ) as orchestrator:
-            await orchestrator.wait_for_completion(wf.id)
-
-        await session.refresh(wf)
-        assert wf.status == WorkflowStatus.SUCCEEDED, (
-            f"Expected workflow status 'success' but got {wf.status}"
-        )
-        assert wf.result == "completed", (
-            f"Expected workflow result 'completed' but got {wf.result}"
-        )
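Both deleted tests hinge on a heartbeat that keeps pushing `lock_until` forward while work is in flight, so a live worker never loses its lease while a crashed one stops renewing and is eventually picked up by the orchestrator. A minimal sketch of that loop with invented names, not planar's implementation:

```python
import asyncio
from datetime import datetime, timedelta, timezone


async def run_with_lock_heartbeat(work, lease, lock_duration=timedelta(seconds=1)):
    # Renew the lease at half its duration: a live worker always renews
    # before expiry, while a crashed one stops renewing, letting the
    # orchestrator reclaim the workflow once lock_until has passed.
    async def heartbeat():
        while True:
            lease["lock_until"] = datetime.now(timezone.utc) + lock_duration
            await asyncio.sleep(lock_duration.total_seconds() / 2)

    hb = asyncio.create_task(heartbeat())
    try:
        return await work()
    finally:
        hb.cancel()


# Usage sketch: the lease's lock_until keeps advancing while work runs.
# lease = {"lock_until": None}
# asyncio.run(run_with_lock_heartbeat(lambda: asyncio.sleep(1), lease))
```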