planar 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +19 -3
  3. planar/ai/agent_base.py +1 -5
  4. planar/ai/agent_utils.py +0 -72
  5. planar/ai/models.py +30 -0
  6. planar/ai/pydantic_ai.py +12 -11
  7. planar/app.py +6 -11
  8. planar/config.py +6 -1
  9. planar/data/__init__.py +17 -0
  10. planar/data/config.py +49 -0
  11. planar/data/dataset.py +263 -0
  12. planar/data/exceptions.py +19 -0
  13. planar/data/test_dataset.py +354 -0
  14. planar/db/db.py +39 -21
  15. planar/dependencies.py +30 -0
  16. planar/files/test_files.py +6 -7
  17. planar/modeling/mixins/test_auditable.py +2 -2
  18. planar/modeling/orm/planar_base_entity.py +4 -1
  19. planar/routers/agents_router.py +52 -4
  20. planar/routers/test_agents_router.py +2 -2
  21. planar/routers/test_files_router.py +2 -2
  22. planar/routers/test_object_config_router.py +2 -2
  23. planar/routers/test_routes_security.py +3 -2
  24. planar/routers/test_rule_router.py +2 -2
  25. planar/routers/test_workflow_router.py +6 -8
  26. planar/rules/__init__.py +12 -18
  27. planar/scaffold_templates/app/flows/process_invoice.py.j2 +1 -2
  28. planar/scaffold_templates/planar.dev.yaml.j2 +9 -0
  29. planar/scaffold_templates/planar.prod.yaml.j2 +14 -0
  30. planar/scaffold_templates/pyproject.toml.j2 +2 -2
  31. planar/test_sqlalchemy.py +36 -1
  32. planar/testing/fixtures.py +3 -17
  33. planar/testing/workflow_observer.py +2 -2
  34. planar/workflows/notifications.py +39 -3
  35. planar/workflows/test_lock_timeout.py +4 -4
  36. {planar-0.8.0.dist-info → planar-0.9.1.dist-info}/METADATA +27 -13
  37. {planar-0.8.0.dist-info → planar-0.9.1.dist-info}/RECORD +39 -33
  38. {planar-0.8.0.dist-info → planar-0.9.1.dist-info}/WHEEL +0 -0
  39. {planar-0.8.0.dist-info → planar-0.9.1.dist-info}/entry_points.txt +0 -0
planar/db/db.py CHANGED
@@ -20,7 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
20
20
 
21
21
  import planar
22
22
  from planar.logging import get_logger
23
- from planar.modeling.orm.planar_base_entity import PLANAR_APPLICATION_METADATA
23
+ from planar.modeling.orm.planar_base_entity import (
24
+ PLANAR_APPLICATION_METADATA,
25
+ PLANAR_ENTITY_SCHEMA,
26
+ )
24
27
  from planar.utils import P, R, T, U, exponential_backoff_with_jitter
25
28
 
26
29
 
@@ -170,9 +173,12 @@ class DatabaseManager:
170
173
  def __init__(
171
174
  self,
172
175
  db_url: str | URL,
176
+ *,
177
+ entity_schema: str = PLANAR_ENTITY_SCHEMA,
173
178
  ):
174
179
  self.db_url = make_url(db_url) if isinstance(db_url, str) else db_url
175
180
  self.engine: AsyncEngine | None = None
181
+ self.entity_schema = entity_schema
176
182
 
177
183
  def _create_sqlite_engine(self, url: URL) -> AsyncEngine:
178
184
  # in practice this high timeout is only use
@@ -189,9 +195,14 @@ class DatabaseManager:
189
195
  # even though it is the default value.
190
196
  autocommit=LEGACY_TRANSACTION_CONTROL,
191
197
  ),
192
- # SQLite doesn't support schemas, so we need to translate the planar schema
193
- # name to None.
194
- execution_options={"schema_translate_map": {"planar": None}},
198
+ # SQLite doesn't support schemas, so we need to translate the planar and user
199
+ # schema names to None.
200
+ execution_options={
201
+ "schema_translate_map": {
202
+ "planar": None,
203
+ PLANAR_ENTITY_SCHEMA: None,
204
+ }
205
+ },
195
206
  )
196
207
 
197
208
  def do_begin(conn: Connection):
@@ -202,7 +213,12 @@ class DatabaseManager:
202
213
  return engine
203
214
 
204
215
  def _create_postgresql_engine(self, url: URL) -> AsyncEngine:
205
- engine = create_async_engine(url)
216
+ # Map default (PLANAR_ENTITY_SCHEMA) schema to the configured entity schema for user tables.
217
+ # Leave the system table schema ('planar') unmapped so system tables are not overridden.
218
+ schema_map = {PLANAR_ENTITY_SCHEMA: self.entity_schema}
219
+ engine = create_async_engine(
220
+ url, execution_options={"schema_translate_map": schema_map}
221
+ )
206
222
 
207
223
  return engine
208
224
 
@@ -214,6 +230,12 @@ class DatabaseManager:
214
230
 
215
231
  db_backend = self.db_url.get_backend_name()
216
232
 
233
+ if self.entity_schema == PLANAR_SCHEMA:
234
+ logger.warning(
235
+ "entity_schema is set to 'planar'; mixing user and system tables in the same schema is discouraged",
236
+ entity_schema=self.entity_schema,
237
+ )
238
+
217
239
  match db_backend:
218
240
  case "sqlite":
219
241
  logger.info(
@@ -293,27 +315,23 @@ class DatabaseManager:
293
315
  else:
294
316
  # Ensure planar schema exists
295
317
  await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {PLANAR_SCHEMA}"))
318
+ # Ensure the configured entity schema exists
319
+ if self.entity_schema != PLANAR_SCHEMA:
320
+ await conn.execute(
321
+ text(f"CREATE SCHEMA IF NOT EXISTS {self.entity_schema}")
322
+ )
296
323
 
297
- async def migrate(self, use_alembic: bool):
324
+ async def migrate(self):
298
325
  """
299
326
  Runs database migrations.
300
327
  By default, uses SQLModel.metadata.create_all.
301
- Set use_alembic=True to use Alembic (requires Alembic setup).
302
328
  """
303
329
  if not self.engine:
304
330
  raise RuntimeError("Database engine not initialized. Call connect() first.")
305
331
 
306
- logger.info("starting database migration")
307
- if use_alembic:
308
- logger.info("using alembic for migrations")
309
- await self._setup_database()
310
- await self._run_system_migrations()
311
- # For now user migrations are not supported, so we fall back to SQLModel.metadata.create_all
312
- async with self.engine.begin() as conn:
313
- await conn.run_sync(PLANAR_APPLICATION_METADATA.create_all)
314
-
315
- else:
316
- async with self.engine.begin() as conn:
317
- await self._setup_database()
318
- await conn.run_sync(PLANAR_FRAMEWORK_METADATA.create_all)
319
- await conn.run_sync(PLANAR_APPLICATION_METADATA.create_all)
332
+ logger.info("starting database migration with alembic")
333
+ await self._setup_database()
334
+ await self._run_system_migrations()
335
+ # For now user migrations are not supported, so we fall back to SQLModel.metadata.create_all
336
+ async with self.engine.begin() as conn:
337
+ await conn.run_sync(PLANAR_APPLICATION_METADATA.create_all)
planar/dependencies.py ADDED
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from importlib import import_module
5
+ from typing import Mapping, Tuple
6
+
7
+
8
+ # mapping: public_name -> (relative_submodule, attribute_in_submodule)
9
+ # This is a PEP 562 compliant way to lazily import modules
10
+ # which is a way to avoid circular dependencies in __init__.py.
11
+ def lazy_exports(module_name: str, mapping: Mapping[str, Tuple[str, str]]) -> None:
12
+ mod = sys.modules[module_name]
13
+ mod.__all__ = list(mapping.keys()) # type: ignore
14
+
15
+ def __getattr__(name: str):
16
+ try:
17
+ submod, attr = mapping[name]
18
+ except KeyError:
19
+ raise AttributeError(
20
+ f"module {module_name!r} has no attribute {name!r}"
21
+ ) from None
22
+ obj = getattr(import_module(submod, module_name), attr)
23
+ setattr(mod, name, obj) # cache
24
+ return obj
25
+
26
+ def __dir__():
27
+ return sorted(set(mod.__dict__.keys()) | set(mod.__all__))
28
+
29
+ mod.__getattr__ = __getattr__ # PEP 562
30
+ mod.__dir__ = __dir__
@@ -19,15 +19,14 @@ from planar.workflows.decorators import workflow
19
19
  from planar.workflows.execution import execute
20
20
  from planar.workflows.models import Workflow
21
21
 
22
- app = PlanarApp(
23
- config=sqlite_config(":memory:"),
24
- title="Planar app for testing file workflows",
25
- description="Testing",
26
- )
27
-
28
22
 
29
23
  @pytest.fixture(name="app")
30
- def app_fixture():
24
+ def app_fixture(tmp_db_path: str):
25
+ app = PlanarApp(
26
+ config=sqlite_config(tmp_db_path),
27
+ title="Planar app for testing file workflows",
28
+ description="Testing",
29
+ )
31
30
  yield app
32
31
 
33
32
 
@@ -37,10 +37,10 @@ class TestAuditableModel(AuditableMixin, SQLModel, table=True):
37
37
 
38
38
 
39
39
  @pytest.fixture
40
- async def session(mem_db_engine):
40
+ async def session(tmp_db_engine):
41
41
  """Create a database session."""
42
42
 
43
- async with new_session(mem_db_engine) as session:
43
+ async with new_session(tmp_db_engine) as session:
44
44
  await (await session.connection()).run_sync(SQLModel.metadata.create_all)
45
45
  yield session
46
46
 
@@ -12,7 +12,10 @@ from .reexports import SQLModel
12
12
  logger = get_logger("orm.PlanarBaseEntity")
13
13
 
14
14
 
15
- PLANAR_APPLICATION_METADATA = MetaData()
15
+ # Default schema for all entity / user tables, but can be overridden by the user
16
+ # in planar configuration, which db.py uses.
17
+ PLANAR_ENTITY_SCHEMA = "planar_entity"
18
+ PLANAR_APPLICATION_METADATA = MetaData(schema=PLANAR_ENTITY_SCHEMA)
16
19
 
17
20
 
18
21
  class PlanarBaseEntity(UUIDPrimaryKeyMixin, AuditableMixin, SQLModel, table=False):
@@ -1,12 +1,13 @@
1
1
  import asyncio
2
- from typing import Any
2
+ import json
3
+ from typing import Any, AsyncGenerator
3
4
 
4
5
  from fastapi import APIRouter, BackgroundTasks, HTTPException
5
6
  from fastapi.responses import StreamingResponse
6
7
  from pydantic import BaseModel
7
8
 
8
- from planar.ai.agent_utils import AgentEventEmitter, AgentEventType, agent_configuration
9
- from planar.ai.models import AgentConfig
9
+ from planar.ai.agent_utils import agent_configuration
10
+ from planar.ai.models import AgentConfig, AgentEventEmitter, AgentEventType
10
11
  from planar.ai.utils import AgentSerializeable, serialize_agent
11
12
  from planar.logging import get_logger
12
13
  from planar.object_config.object_config import ConfigValidationError
@@ -17,6 +18,7 @@ from planar.security.authorization import (
17
18
  validate_authorization_for,
18
19
  )
19
20
  from planar.session import get_engine, session_context
21
+ from planar.utils import utc_now
20
22
 
21
23
  logger = get_logger(__name__)
22
24
 
@@ -29,6 +31,52 @@ class AgentSimulationData[T](BaseModel):
29
31
  input_value: str | T
30
32
 
31
33
 
34
+ class SimulationAgentEvent:
35
+ def __init__(
36
+ self,
37
+ event_type: AgentEventType,
38
+ data: BaseModel | str | None,
39
+ ):
40
+ self.event_type = event_type
41
+ self.data = data
42
+ self.timestamp = utc_now().isoformat()
43
+
44
+
45
+ class SimulationAgentEventEmitter(AgentEventEmitter):
46
+ def __init__(self):
47
+ self.queue: asyncio.Queue[SimulationAgentEvent] = asyncio.Queue()
48
+
49
+ def emit(self, event_type: AgentEventType, data: BaseModel | str | None):
50
+ event = SimulationAgentEvent(event_type, data)
51
+ self.queue.put_nowait(event)
52
+
53
+ async def get_events(self) -> AsyncGenerator[str, None]:
54
+ while True:
55
+ event = await self.queue.get()
56
+
57
+ if isinstance(event.data, BaseModel):
58
+ data = {
59
+ "data": event.data.model_dump(),
60
+ "event_type": event.event_type,
61
+ }
62
+ else:
63
+ data = {
64
+ "data": event.data,
65
+ "event_type": event.event_type,
66
+ }
67
+
68
+ yield f"data: {json.dumps(data)}\n\n"
69
+
70
+ self.queue.task_done()
71
+
72
+ if event.event_type in (AgentEventType.COMPLETED, AgentEventType.ERROR):
73
+ break
74
+
75
+ def is_empty(self) -> bool:
76
+ """Check if the queue is empty."""
77
+ return self.queue.empty()
78
+
79
+
32
80
  class AgentEvent(BaseModel):
33
81
  """Model representing a single event emitted by the agent."""
34
82
 
@@ -147,7 +195,7 @@ def create_agent_router(object_registry: ObjectRegistry) -> APIRouter:
147
195
  logger.warning("agent not found for simulation", agent_name=agent_name)
148
196
  raise HTTPException(status_code=404, detail="Agent not found")
149
197
 
150
- emitter = AgentEventEmitter()
198
+ emitter = SimulationAgentEventEmitter()
151
199
 
152
200
  # Create a copy of the request data to avoid sharing data between tasks
153
201
  request_copy = request.model_copy()
@@ -15,10 +15,10 @@ from planar.testing.planar_test_client import PlanarTestClient
15
15
 
16
16
 
17
17
  @pytest.fixture(name="app")
18
- def app_fixture():
18
+ def app_fixture(tmp_db_path: str):
19
19
  """Create a test app with agents."""
20
20
  app = PlanarApp(
21
- config=sqlite_config(":memory:"),
21
+ config=sqlite_config(tmp_db_path),
22
22
  title="Test app for agent router",
23
23
  description="Testing agent endpoints",
24
24
  )
@@ -10,9 +10,9 @@ from planar.testing.planar_test_client import PlanarTestClient
10
10
 
11
11
 
12
12
  @pytest.fixture(name="app")
13
- def app_fixture():
13
+ def app_fixture(tmp_db_path: str):
14
14
  return PlanarApp(
15
- config=sqlite_config(":memory:"),
15
+ config=sqlite_config(tmp_db_path),
16
16
  title="Test app for files router",
17
17
  description="Testing files endpoints",
18
18
  )
@@ -41,10 +41,10 @@ class OutputFromTestRule(BaseModel):
41
41
 
42
42
 
43
43
  @pytest.fixture(name="app")
44
- def app_fixture():
44
+ def app_fixture(tmp_db_path: str):
45
45
  """Create a test app with agents and rules."""
46
46
  app = PlanarApp(
47
- config=sqlite_config(":memory:"),
47
+ config=sqlite_config(tmp_db_path),
48
48
  title="Test app for object config router",
49
49
  description="Testing object configuration endpoints",
50
50
  )
@@ -61,8 +61,9 @@ def restrictive_policy_file(tmp_path):
61
61
 
62
62
 
63
63
  @pytest.fixture(name="app_with_restricted_authz")
64
- def create_app_with_restricted_authz(restrictive_policy_file):
65
- config = sqlite_config("test_authz_router.db")
64
+ def create_app_with_restricted_authz(tmp_path, restrictive_policy_file):
65
+ db_path = tmp_path / "test_authz_router.db"
66
+ config = sqlite_config(str(db_path))
66
67
  config.security = SecurityConfig(
67
68
  authz=AuthzConfig(enabled=True, policy_file=restrictive_policy_file)
68
69
  )
@@ -83,9 +83,9 @@ def pricing_rule_with_wrong_type(
83
83
 
84
84
 
85
85
  @pytest.fixture(name="app")
86
- def app_fixture():
86
+ def app_fixture(tmp_db_path: str):
87
87
  app = PlanarApp(
88
- config=sqlite_config(":memory:"),
88
+ config=sqlite_config(tmp_db_path),
89
89
  title="Test app for agent router",
90
90
  description="Testing agent endpoints",
91
91
  )
@@ -120,18 +120,16 @@ async def file_processing_workflow(file: PlanarFile):
120
120
  )
121
121
 
122
122
 
123
- app = PlanarApp(
124
- config=sqlite_config("test_workflow_router.db"),
125
- title="Test Workflow Router API",
126
- description="API for testing workflow routers",
127
- )
128
-
129
-
130
123
  # ------ TESTS ------
131
124
 
132
125
 
133
126
  @pytest.fixture(name="app")
134
- def app_fixture():
127
+ def app_fixture(tmp_db_path: str):
128
+ app = PlanarApp(
129
+ config=sqlite_config(tmp_db_path),
130
+ title="Test Workflow Router API",
131
+ description="API for testing workflow routers",
132
+ )
135
133
  # Re-register workflows since ObjectRegistry gets reset before each test
136
134
  app.register_workflow(expense_approval_workflow)
137
135
  app.register_workflow(file_processing_workflow)
planar/rules/__init__.py CHANGED
@@ -1,23 +1,17 @@
1
- import importlib
2
- from typing import Any
1
+ from typing import TYPE_CHECKING, Mapping, Tuple
3
2
 
4
- _DEFERRED_IMPORTS = {
5
- "rule": ".decorator",
6
- "Rule": ".models",
7
- "RuleSerializeable": ".models",
3
+ from planar.dependencies import lazy_exports
4
+
5
+ _DEFERRED_IMPORTS: Mapping[str, Tuple[str, str]] = {
6
+ "rule": (".decorator", "rule"),
7
+ "Rule": (".models", "Rule"),
8
+ "RuleSerializeable": (".models", "RuleSerializeable"),
8
9
  }
9
10
 
11
+ if TYPE_CHECKING:
12
+ from .decorator import rule
13
+ from .models import Rule, RuleSerializeable
10
14
 
11
- def __getattr__(name: str) -> Any:
12
- """
13
- Lazily import modules to avoid circular dependencies.
14
- This is called by the Python interpreter when a module attribute is accessed
15
- that cannot be found in the module's __dict__.
16
- PEP 562
17
- """
18
- if name in _DEFERRED_IMPORTS:
19
- module_path = _DEFERRED_IMPORTS[name]
20
- module = importlib.import_module(module_path, __name__)
21
- return getattr(module, name)
15
+ __all__ = ["Rule", "RuleSerializeable", "rule"]
22
16
 
23
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
17
+ lazy_exports(__name__, _DEFERRED_IMPORTS)
@@ -1,5 +1,4 @@
1
1
  from planar.ai import Agent
2
- from planar.ai.providers import OpenAI
3
2
  from planar.files import PlanarFile
4
3
  from planar.human import Human
5
4
  from planar.rules.decorator import rule
@@ -23,7 +22,7 @@ class RuleOutput(BaseModel):
23
22
 
24
23
  invoice_agent = Agent(
25
24
  name="Invoice Agent",
26
- model=OpenAI.gpt_4_1,
25
+ model="openai:gpt-4.1",
27
26
  tools=[],
28
27
  max_turns=1,
29
28
  system_prompt="Extract vendor and amount from invoice text.",
@@ -32,3 +32,12 @@ logging:
32
32
  # the following lines to enable INFO level for the whole application (except sqlalchemy.engine, which must be enabled above)
33
33
  # "":
34
34
  # level: INFO
35
+
36
+ # Uncomment to enable data features with Ducklake
37
+ # data:
38
+ # catalog:
39
+ # type: duckdb
40
+ # path: .data/catalog.ducklake
41
+ # storage:
42
+ # backend: localdir
43
+ # directory: .data/ducklake_files
@@ -30,3 +30,17 @@ security:
30
30
  ai_providers:
31
31
  openai:
32
32
  api_key: ${OPENAI_API_KEY}
33
+
34
+ # Uncomment to enable data features with Ducklake
35
+ # data:
36
+ # catalog:
37
+ # type: postgres
38
+ # host: ${DB_HOST}
39
+ # port: ${DB_PORT}
40
+ # user: ${DB_USER}
41
+ # password: ${DB_PASSWORD}
42
+ # db: ducklake_catalog
43
+ # storage:
44
+ # backend: s3
45
+ # region: us-west-2
46
+ # bucket_name: ${S3_DATA_BUCKET}
@@ -3,8 +3,8 @@ name = "{{ name }}"
3
3
  version = "0.1.0"
4
4
  requires-python = ">=3.12"
5
5
  dependencies = [
6
- "planar>=0.6.0",
6
+ "planar>=0.9.0",
7
7
  ]
8
8
 
9
9
  [[tool.uv.index]]
10
- url = "https://coplane.github.io/planar/simple/"
10
+ url = "https://coplane.github.io/planar/simple/"
planar/test_sqlalchemy.py CHANGED
@@ -3,7 +3,7 @@ from uuid import uuid4
3
3
  import pytest
4
4
  from sqlalchemy.exc import DBAPIError
5
5
  from sqlalchemy.ext.asyncio import AsyncEngine
6
- from sqlmodel import col, insert, select
6
+ from sqlmodel import col, insert, select, text
7
7
 
8
8
  from planar.db import PlanarSession, new_session
9
9
  from planar.modeling.orm.planar_base_entity import PlanarBaseEntity
@@ -156,3 +156,38 @@ async def test_serializable_transaction_failure_1(tmp_db_engine: AsyncEngine):
156
156
  # Session 2: Commit should fail with serialization error
157
157
  with pytest.raises(DBAPIError, match="could not serialize access"):
158
158
  await session2.commit()
159
+
160
+
161
+ async def test_entity_schema_and_planar_schema_presence(tmp_db_engine: AsyncEngine):
162
+ table_name = SomeModel.__tablename__
163
+
164
+ async with new_session(tmp_db_engine) as session:
165
+ dialect = session.dialect.name
166
+
167
+ if dialect == "postgresql":
168
+ # Verify schemas include 'planar' and the default entity schema 'planar_entity'
169
+ res = await session.exec(
170
+ text("select schema_name from information_schema.schemata") # type: ignore[arg-type]
171
+ )
172
+ schemas = {row[0] for row in res}
173
+ assert "planar" in schemas
174
+ assert "planar_entity" in schemas
175
+
176
+ # Verify SomeModel table is created in the entity schema
177
+ res = await session.exec(
178
+ text(
179
+ "select table_schema from information_schema.tables where table_name = :tn"
180
+ ).bindparams(tn=table_name) # type: ignore[arg-type]
181
+ )
182
+ table_schemas = {row[0] for row in res}
183
+ assert "planar_entity" in table_schemas
184
+ assert "public" not in table_schemas
185
+
186
+ else:
187
+ # SQLite: no schemas; ensure table exists
188
+ res = await session.exec(
189
+ text("select name from sqlite_master where type='table'") # type: ignore[arg-type]
190
+ )
191
+ tables = {row[0] for row in res}
192
+ assert table_name in tables
193
+ assert not any(name.startswith("planar.") for name in tables)
@@ -254,23 +254,9 @@ async def tmp_db_engine(tmp_db_url: str):
254
254
  yield engine
255
255
 
256
256
 
257
- @pytest.fixture()
258
- async def mem_db_engine(tmp_db_engine):
259
- # Memory databases don't work well with aiosqlite due to the "database
260
- # table is locked" error (which doesn't respect timeouts), so just use a
261
- # temporary db file for now until we figure out a fix.
262
- yield tmp_db_engine
263
-
264
- # name = uuid4()
265
- # async with engine_context(
266
- # f"sqlite+aiosqlite:///file:{name}?mode=memory&cache=shared&uri=true"
267
- # ) as engine:
268
- # yield engine
269
-
270
-
271
257
  @pytest.fixture(name="session")
272
- async def session_fixture(mem_db_engine):
273
- async with new_session(mem_db_engine) as session:
258
+ async def session_fixture(tmp_db_engine):
259
+ async with new_session(tmp_db_engine) as session:
274
260
  tok = session_var.set(session)
275
261
  yield session
276
262
  session_var.reset(tok)
@@ -318,7 +304,7 @@ async def tracer_fixture():
318
304
  async def engine_context(url: str):
319
305
  db_manager = DatabaseManager(url)
320
306
  db_manager.connect()
321
- await db_manager.migrate(use_alembic=True)
307
+ await db_manager.migrate()
322
308
  engine = db_manager.get_engine()
323
309
  tok = engine_var.set(engine)
324
310
  yield engine
@@ -21,9 +21,9 @@ class WorkflowObserver:
21
21
  ) -> UUID:
22
22
  """Extract workflow_id from notification data"""
23
23
  if isinstance(notification.data, Workflow):
24
- return notification.data.id
24
+ return notification.workflow_id
25
25
  else:
26
- return notification.data.workflow_id
26
+ return notification.workflow_id
27
27
 
28
28
  def on_workflow_notification(self, notification: WorkflowNotification):
29
29
  workflow_id = UUID(str(self._get_workflow_id_from_notification(notification)))
@@ -2,10 +2,12 @@ from contextlib import asynccontextmanager
2
2
  from contextvars import ContextVar
3
3
  from enum import Enum
4
4
  from typing import Callable, Union
5
+ from uuid import UUID
5
6
 
6
7
  from pydantic import BaseModel
7
8
 
8
9
  from planar.logging import get_logger
10
+ from planar.workflows.context import get_context
9
11
  from planar.workflows.models import Workflow, WorkflowStep
10
12
 
11
13
  logger = get_logger(__name__)
@@ -20,11 +22,19 @@ class Notification(str, Enum):
20
22
  STEP_RUNNING = "step-running"
21
23
  STEP_SUCCEEDED = "step-succeeded"
22
24
  STEP_FAILED = "step-failed"
25
+ AGENT_TEXT = "agent-text"
26
+ AGENT_THINK = "agent-think"
27
+
28
+
29
+ class AgentEventData(BaseModel):
30
+ data: str
31
+ step_id: int
23
32
 
24
33
 
25
34
  class WorkflowNotification(BaseModel):
26
35
  kind: Notification
27
- data: Union[Workflow, WorkflowStep]
36
+ workflow_id: UUID
37
+ data: Union[Workflow, WorkflowStep, AgentEventData]
28
38
 
29
39
 
30
40
  WorkflowNotificationCallback = Callable[[WorkflowNotification], None]
@@ -38,7 +48,9 @@ def workflow_notify(workflow: Workflow, kind: Notification):
38
48
  callback = workflow_notification_callback_var.get(None)
39
49
  if callback is not None:
40
50
  logger.debug("notifying workflow event", kind=kind, workflow_id=workflow.id)
41
- callback(WorkflowNotification(kind=kind, data=workflow))
51
+ callback(
52
+ WorkflowNotification(kind=kind, workflow_id=workflow.id, data=workflow)
53
+ )
42
54
 
43
55
 
44
56
  def workflow_started(workflow: Workflow):
@@ -70,7 +82,9 @@ def step_notify(step: WorkflowStep, kind: Notification):
70
82
  workflow_id=step.workflow_id,
71
83
  step_id=step.step_id,
72
84
  )
73
- callback(WorkflowNotification(kind=kind, data=step))
85
+ callback(
86
+ WorkflowNotification(kind=kind, workflow_id=step.workflow_id, data=step)
87
+ )
74
88
 
75
89
 
76
90
  def step_running(step: WorkflowStep):
@@ -85,6 +99,28 @@ def step_failed(step: WorkflowStep):
85
99
  return step_notify(step, Notification.STEP_FAILED)
86
100
 
87
101
 
102
+ def agent_notify(kind: Notification, data: str):
103
+ callback = workflow_notification_callback_var.get(None)
104
+ if callback is not None:
105
+ context = get_context()
106
+ logger.debug("notifying agent event", kind=kind)
107
+ callback(
108
+ WorkflowNotification(
109
+ kind=kind,
110
+ workflow_id=context.workflow_id,
111
+ data=AgentEventData(data=data, step_id=context.step_stack[-1].step_id),
112
+ )
113
+ )
114
+
115
+
116
+ def agent_think(data: str):
117
+ agent_notify(Notification.AGENT_THINK, data)
118
+
119
+
120
+ def agent_text(data: str):
121
+ agent_notify(Notification.AGENT_TEXT, data)
122
+
123
+
88
124
  @asynccontextmanager
89
125
  async def workflow_notification_context(callback: WorkflowNotificationCallback):
90
126
  """Context manager for setting up and tearing down Workflow notification context"""