planar 0.11.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
planar/data/connection.py CHANGED
@@ -1,4 +1,6 @@
  import asyncio
+ from dataclasses import dataclass
+ from urllib.parse import urlparse

  import ibis
  from ibis.backends.duckdb import Backend as DuckDBBackend
@@ -17,6 +19,85 @@ from planar.session import get_config
  logger = get_logger(__name__)


+ @dataclass
+ class _ConnectionPool:
+     connections: list[DuckDBBackend]
+     cursor: int = 0
+
+
+ # In production a Planar app typically runs with a single data config, so we only
+ # ever have one signature, but we still rotate through a handful of cached
+ # backends to reduce the risk of concurrent calls sharing the same DuckDB
+ # connection. During testing we create many ephemeral configs (temp dirs, sqlite
+ # files, etc.), so the cache also avoids paying the attachment cost on every
+ # request. We keep up to `_MAX_CONNECTIONS_PER_SIGNATURE` backends per signature
+ # and hand them out in round-robin order; concurrency safety ultimately depends on
+ # DuckDB tolerating overlapping use of an individual backend.
+ _connection_cache: dict[int, _ConnectionPool] = {}
+ _cache_lock: asyncio.Lock | None = None
+
+ # Maximum number of cached connections per configuration signature.
+ _MAX_CONNECTIONS_PER_SIGNATURE = 10
+
+
+ def _config_signature(config: PlanarConfig) -> int:
+     """Create a stable signature for caching connections."""
+
+     assert config.data is not None, "data configuration must be set"
+     return hash(config.data)
+
+
+ async def _close_backend(connection: DuckDBBackend) -> None:
+     close_fn = getattr(connection, "close", None)
+     try:
+         if callable(close_fn):
+             await asyncio.to_thread(close_fn)
+     except Exception as exc:
+         logger.warning("failed to close DuckDB connection", error=str(exc))
+
+
+ def _make_aws_s3_secret_query(config: S3Config) -> str:
+     """
+     https://duckdb.org/docs/stable/core_extensions/httpfs/s3api
+     """
+     columns = ["TYPE s3"]
+
+     if config.region:
+         columns.append(
+             f"REGION '{config.region}'",
+         )
+
+     if config.endpoint_url:
+         parsed_url = urlparse(config.endpoint_url)
+         endpoint_host = parsed_url.hostname
+         endpoint_port = f":{parsed_url.port}" if parsed_url.port else ""
+         endpoint = f"{endpoint_host}{endpoint_port}"
+
+         columns.append(f"ENDPOINT '{endpoint}'")
+
+     if config.access_key and config.secret_key:
+         columns.extend(
+             [
+                 "PROVIDER config",
+                 f"KEY_ID '{config.access_key}'",
+                 f"SECRET '{config.secret_key}'",
+             ]
+         )
+     else:
+         columns.extend(
+             [
+                 "PROVIDER credential_chain",
+                 "CHAIN 'env;sts;instance'",
+             ]
+         )
+
+     return f"""
+     CREATE OR REPLACE SECRET secret (
+         {", ".join(columns)}
+     );
+     """
+
+
  async def _create_connection(config: PlanarConfig) -> DuckDBBackend:
      """Create Ibis DuckDB connection with Ducklake."""
      data_config = config.data
@@ -67,6 +148,14 @@ async def _create_connection(config: PlanarConfig) -> DuckDBBackend:
      if isinstance(storage, LocalDirectoryConfig):
          data_path = storage.directory
      elif isinstance(storage, S3Config):
+         await asyncio.to_thread(con.raw_sql, "INSTALL httpfs;")
+         await asyncio.to_thread(con.raw_sql, "LOAD httpfs;")
+
+         await asyncio.to_thread(
+             con.raw_sql,
+             _make_aws_s3_secret_query(storage),
+         )
+
          data_path = f"s3://{storage.bucket_name}/"
      else:
          # Generic fallback
@@ -95,8 +184,21 @@ async def _create_connection(config: PlanarConfig) -> DuckDBBackend:
      return con


- async def _get_connection() -> DuckDBBackend:
-     """Get Ibis connection to Ducklake."""
+ def _get_cache_lock() -> asyncio.Lock:
+     # Create a lock on the first call to this function, or re-create it if the
+     # loop has changed (happens on tests).
+     global _cache_lock
+     loop = asyncio.get_running_loop()
+     lock = _cache_lock
+     if lock is None or getattr(lock, "_loop", None) is not loop:
+         lock = asyncio.Lock()
+         _cache_lock = lock
+     return lock
+
+
+ async def get_connection() -> DuckDBBackend:
+     """Return a cached DuckDB connection using round-robin selection."""
+
      config = get_config()

      if not config.data:
@@ -104,5 +206,39 @@ async def _get_connection() -> DuckDBBackend:
              "Data configuration not found. Please configure 'data' in your planar.yaml"
          )

-     # TODO: Add cached connection pooling or memoize the connection
-     return await _create_connection(config)
+     signature = _config_signature(config)
+     lock = _get_cache_lock()
+
+     async with lock:
+         pool = _connection_cache.get(signature)
+
+         if pool is None:
+             connection = await _create_connection(config)
+             _connection_cache[signature] = _ConnectionPool(connections=[connection])
+             return connection
+
+         if len(pool.connections) < _MAX_CONNECTIONS_PER_SIGNATURE:
+             connection = await _create_connection(config)
+             pool.connections.append(connection)
+             return connection
+
+         connection = pool.connections[pool.cursor]
+         pool.cursor = (pool.cursor + 1) % len(pool.connections)
+         return connection
+
+
+ async def reset_connection_cache() -> None:
+     """Reset the cached DuckDB connection, closing it if necessary."""
+
+     lock = _get_cache_lock()
+
+     async with lock:
+         pools = list(_connection_cache.values())
+         _connection_cache.clear()
+
+     for pool in pools:
+         for connection in pool.connections:
+             await _close_backend(connection)
+
+     global _cache_lock
+     _cache_lock = None
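Note on the new pooling API: a minimal usage sketch, assuming a running Planar app whose planar.yaml has a data section (get_connection() reads it via get_config()); the wrapper coroutine below is illustrative and not part of the package.

import asyncio

from planar.data.connection import get_connection, reset_connection_cache


async def _example() -> None:
    # Calls that share one data-config signature reuse cached DuckDB backends,
    # handed out in round-robin order instead of re-attaching Ducklake each time.
    con = await get_connection()
    tables = await asyncio.to_thread(con.list_tables)
    print(tables)

    # Test teardown (or a config change) can close and drop every cached backend.
    await reset_connection_cache()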
planar/data/dataset.py CHANGED
@@ -6,10 +6,11 @@ from typing import Literal, Self
  import ibis
  import polars as pl
  import pyarrow as pa
+ from ibis.backends.duckdb import Backend as DuckDBBackend
  from ibis.common.exceptions import TableNotFound
  from pydantic import BaseModel

- from planar.data.connection import _get_connection
+ from planar.data.connection import get_connection
  from planar.logging import get_logger

  from .exceptions import DataError, DatasetAlreadyExistsError, DatasetNotFoundError
@@ -67,7 +68,11 @@ class PlanarDataset(BaseModel):

      async def exists(self) -> bool:
          """Check if the dataset exists in Ducklake."""
-         con = await _get_connection()
+         con = await get_connection()
+         return await self._table_exists(con)
+
+     async def _table_exists(self, con: DuckDBBackend) -> bool:
+         """Check for table existence using the provided connection."""

          try:
              # TODO: Query for the table name directly
@@ -88,11 +93,13 @@ class PlanarDataset(BaseModel):
              data: Data to write (Polars DataFrame/LazyFrame, PyArrow Table, or Ibis expression)
              mode: Write mode - "append" or "overwrite"
          """
-         con = await _get_connection()
          overwrite = mode == "overwrite"

          try:
-             if not await self.exists():
+             con = await get_connection()
+             table_exists = await self._table_exists(con)
+
+             if not table_exists:
                  await asyncio.to_thread(
                      con.create_table, self.name, data, overwrite=overwrite
                  )
@@ -133,9 +140,8 @@ class PlanarDataset(BaseModel):
          Returns:
              Ibis table expression that can be further filtered using Ibis methods
          """
-         con = await _get_connection()
-
          try:
+             con = await get_connection()
              table = await asyncio.to_thread(con.table, self.name)

              if columns:
@@ -162,8 +168,8 @@ class PlanarDataset(BaseModel):

      async def delete(self) -> None:
          """Delete the dataset."""
-         con = await _get_connection()
          try:
+             con = await get_connection()
              await asyncio.to_thread(con.drop_table, self.name, force=True)
              logger.info("deleted dataset", dataset_name=self.name)
          except Exception as e:
planar/data/utils.py CHANGED
@@ -1,42 +1,45 @@
  import asyncio
- from typing import TypedDict
+ from collections import defaultdict
+ from typing import Sequence, TypedDict

+ import ibis
  import ibis.expr.datatypes as dt
+ import pyarrow as pa
+ from ibis.backends.duckdb import Backend as DuckDBBackend
  from ibis.common.exceptions import TableNotFound
+ from sqlglot import exp

- from planar.data.connection import _get_connection
+ from planar.data.connection import get_connection
  from planar.data.dataset import PlanarDataset
  from planar.data.exceptions import DatasetNotFoundError
  from planar.logging import get_logger
+ from planar.session import get_config

  logger = get_logger(__name__)


- # TODO: consider connection pooling or memoize the connection
-
-
  async def list_datasets(limit: int = 100, offset: int = 0) -> list[PlanarDataset]:
-     conn = await _get_connection()
-     tables = await asyncio.to_thread(conn.list_tables)
+     conn = await get_connection()
+     tables = sorted(await asyncio.to_thread(conn.list_tables))[offset : offset + limit]
      return [PlanarDataset(name=table) for table in tables]


  async def list_schemas() -> list[str]:
-     METADATA_SCHEMAS = [
-         "information_schema",
-         # FIXME: why is list_databases returning pg_catalog
-         # if the ducklake catalog is sqlite?
-         "pg_catalog",
-     ]
+     config = get_config()
+
+     if config.data is None:
+         return []

-     conn = await _get_connection()
+     METADATA_SCHEMAS = [config.data.catalog_name, "main"]
+
+     conn = await get_connection()

      # in ibis, "databases" are schemas in the traditional sense
      # e.g. psql: schema == ibis: database
      # https://ibis-project.org/concepts/backend-table-hierarchy
      schemas = await asyncio.to_thread(conn.list_databases)

-     return [schema for schema in schemas if schema not in METADATA_SCHEMAS]
+     return [schema for schema in schemas if schema in METADATA_SCHEMAS]


  async def get_dataset(dataset_name: str, schema_name: str = "main") -> PlanarDataset:
@@ -51,7 +54,7 @@ async def get_dataset(dataset_name: str, schema_name: str = "main") -> PlanarDat


  async def get_dataset_row_count(dataset_name: str) -> int:
-     conn = await _get_connection()
+     conn = await get_connection()

      try:
          value = await asyncio.to_thread(
@@ -72,18 +75,135 @@ class DatasetMetadata(TypedDict):
      row_count: int


- async def get_dataset_metadata(
-     dataset_name: str, schema_name: str
- ) -> DatasetMetadata | None:
-     conn = await _get_connection()
+ async def _fetch_column_schemas(
+     conn: DuckDBBackend,
+     dataset_names: Sequence[str],
+     schema_name: str,
+ ) -> dict[str, dict[str, dt.DataType]]:
+     columns = conn.table("columns", database="information_schema")
+     schema_literal = ibis.literal(schema_name)
+     dataset_literals = [ibis.literal(name) for name in dataset_names]
+     filtered = columns.filter(
+         (columns.table_schema == schema_literal)
+         & (columns.table_name.isin(dataset_literals))
+     )
+
+     selected = filtered.select(
+         columns.table_name.name("table_name"),
+         columns.column_name.name("column_name"),
+         columns.ordinal_position.name("ordinal_position"),
+         columns.data_type.name("data_type"),
+         columns.is_nullable.name("is_nullable"),
+     )
+
+     arrow_table: pa.Table = await asyncio.to_thread(selected.to_pyarrow)
+     rows = arrow_table.to_pylist()
+
+     schema_fields: dict[str, list[tuple[int, str, dt.DataType]]] = defaultdict(list)
+     type_mapper = conn.compiler.type_mapper
+
+     for row in rows:
+         table_name = row["table_name"]
+         column_name = row["column_name"]
+         ordinal_position = row["ordinal_position"]
+         dtype = type_mapper.from_string(
+             row["data_type"], nullable=row.get("is_nullable") == "YES"
+         )

-     try:
-         schema, row_count = await asyncio.gather(
-             asyncio.to_thread(conn.get_schema, dataset_name, database=schema_name),
-             get_dataset_row_count(dataset_name),
+         schema_fields[table_name].append((ordinal_position, column_name, dtype))
+
+     ordered_fields: dict[str, dict[str, dt.DataType]] = {}
+     for table_name, fields in schema_fields.items():
+         ordered_fields[table_name] = {
+             column_name: dtype
+             for _, column_name, dtype in sorted(fields, key=lambda entry: entry[0])
+         }
+
+     return ordered_fields
+
+
+ async def _fetch_row_counts(
+     conn: DuckDBBackend,
+     dataset_names: Sequence[str],
+     schema_name: str,
+ ) -> dict[str, int]:
+     if not dataset_names:
+         return {}
+
+     quoted = conn.compiler.quoted
+     count_queries: list[exp.Select] = []
+
+     for dataset_name in dataset_names:
+         table_expr = exp.Table(
+             this=exp.Identifier(this=dataset_name, quoted=quoted),
+             db=(
+                 exp.Identifier(this=schema_name, quoted=quoted) if schema_name else None
+             ),
          )
+         select_expr = (
+             exp.Select()
+             .select(
+                 exp.Literal.string(dataset_name).as_("dataset_name"),
+                 exp.Count(this=exp.Star()).as_("row_count"),
+             )
+             .from_(table_expr)
+         )
+         count_queries.append(select_expr)
+
+     if not count_queries:
+         return {}
+
+     union_query: exp.Expression = count_queries[0]
+     for query in count_queries[1:]:
+         union_query = exp.Union(this=union_query, expression=query, distinct=False)
+
+     def _execute() -> dict[str, int]:
+         with conn._safe_raw_sql(union_query) as cursor:  # type: ignore[attr-defined]
+             rows = cursor.fetchall()

-         return DatasetMetadata(schema=schema.fields, row_count=row_count)
+         return {str(dataset_name): int(row_count) for dataset_name, row_count in rows}

+     return await asyncio.to_thread(_execute)
+
+
+ async def get_datasets_metadata(
+     dataset_names: Sequence[str], schema_name: str
+ ) -> dict[str, DatasetMetadata]:
+     if not dataset_names:
+         return {}
+
+     dataset_list = list(dict.fromkeys(dataset_names))
+     if not dataset_list:
+         return {}
+
+     conn = await get_connection()
+
+     schemas = await _fetch_column_schemas(conn, dataset_list, schema_name)
+     row_counts = await _fetch_row_counts(conn, list(schemas.keys()), schema_name)
+
+     metadata: dict[str, DatasetMetadata] = {}
+
+     for dataset_name in dataset_list:
+         schema = schemas.get(dataset_name)
+         row_count = row_counts.get(dataset_name)
+
+         if not schema or row_count is None:
+             continue
+
+         metadata[dataset_name] = DatasetMetadata(
+             schema=schema,
+             row_count=row_count,
+         )
+
+     return metadata
+
+
+ async def get_dataset_metadata(
+     dataset_name: str, schema_name: str
+ ) -> DatasetMetadata | None:
+     try:
+         metadata = await get_datasets_metadata([dataset_name], schema_name)
      except TableNotFound:
          return None
+
+     return metadata.get(dataset_name)
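To make the _fetch_row_counts construction concrete, here is a standalone sketch (hypothetical table names, no Planar connection needed) of the UNION ALL statement it assembles with sqlglot before running it in a single round trip:

from sqlglot import exp

# One COUNT(*) SELECT per dataset, labelled with the dataset name.
queries = [
    exp.Select()
    .select(
        exp.Literal.string(name).as_("dataset_name"),
        exp.Count(this=exp.Star()).as_("row_count"),
    )
    .from_(exp.Table(this=exp.Identifier(this=name, quoted=True)))
    for name in ["orders", "customers"]  # illustrative names
]

# Chain the selects with UNION ALL (distinct=False), as the helper does.
union = queries[0]
for q in queries[1:]:
    union = exp.Union(this=union, expression=q, distinct=False)

print(union.sql(dialect="duckdb"))
# Roughly: SELECT 'orders' AS dataset_name, COUNT(*) AS row_count FROM "orders"
#          UNION ALL SELECT 'customers' AS dataset_name, COUNT(*) AS row_count FROM "customers"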
planar/db/alembic/env.py CHANGED
@@ -1,9 +1,11 @@
+ import asyncio
  from functools import wraps
  from logging.config import fileConfig

  import alembic.ddl.base as alembic_base
  from alembic import context
- from sqlalchemy import Connection, engine_from_config, pool
+ from sqlalchemy import Connection, pool
+ from sqlalchemy.ext.asyncio import create_async_engine

  from planar.db import PLANAR_FRAMEWORK_METADATA, PLANAR_SCHEMA

@@ -72,6 +74,69 @@ alembic_base.format_table_name = schema_translate_wrapper(
  )


+ async def run_migrations_online_async() -> None:
+     """Run migrations in 'online' mode using async engine for development."""
+     # Import models to ensure they're registered with PLANAR_FRAMEWORK_METADATA
+     try:
+         from planar.files.models import PlanarFileMetadata  # noqa: F401, PLC0415
+         from planar.human.models import HumanTask  # noqa: F401, PLC0415
+         from planar.object_config.models import (  # noqa: F401, PLC0415
+             ObjectConfiguration,
+         )
+         from planar.workflows.models import (  # noqa: PLC0415
+             LockedResource,  # noqa: F401
+             Workflow,  # noqa: F401
+             WorkflowEvent,  # noqa: F401
+             WorkflowStep,  # noqa: F401
+         )
+     except ImportError as e:
+         raise RuntimeError(
+             f"Failed to import system models for migration generation: {e}"
+         )
+
+     config_dict = config.get_section(config.config_ini_section, {})
+     url = config_dict["sqlalchemy.url"]
+     is_sqlite = url.startswith("sqlite://")
+
+     # Create async engine
+     connectable = create_async_engine(
+         url,
+         poolclass=pool.NullPool,
+         execution_options={
+             # SQLite doesn't support schemas, so we need to translate the planar schema
+             # name to None in order to ignore it.
+             "schema_translate_map": sqlite_schema_translate_map if is_sqlite else {},
+         },
+     )
+
+     async with connectable.connect() as connection:
+         is_sqlite = connection.dialect.name == "sqlite"
+         if is_sqlite:
+             connection.dialect.default_schema_name = PLANAR_SCHEMA
+
+         def do_run_migrations(sync_conn):
+             context.configure(
+                 connection=sync_conn,
+                 target_metadata=target_metadata,
+                 # For SQLite, don't use schema since it's not supported
+                 version_table_schema=None if is_sqlite else PLANAR_SCHEMA,
+                 include_schemas=True,
+                 include_name=include_name,
+                 # SQLite doesn't support alter table, so we need to use render_as_batch
+                 # to create the tables in a single transaction. For other databases,
+                 # the batch op is no-op.
+                 # https://alembic.sqlalchemy.org/en/latest/batch.html#running-batch-migrations-for-sqlite-and-other-databases
+                 render_as_batch=True,
+             )
+
+             with context.begin_transaction():
+                 context.run_migrations()
+
+         await connection.run_sync(do_run_migrations)
+
+     await connectable.dispose()
+
+
  def run_migrations_online() -> None:
      """Run migrations in 'online' mode.

@@ -103,62 +168,8 @@ def run_migrations_online() -> None:
          with context.begin_transaction():
              context.run_migrations()
      else:
-         # Development mode: create engine from alembic.ini
-         # Used for alembic to generate migrations
-         # Import models to ensure they're registered with PLANAR_FRAMEWORK_METADATA
-         try:
-             from planar.files.models import PlanarFileMetadata  # noqa: F401, PLC0415
-             from planar.human.models import HumanTask  # noqa: F401, PLC0415
-             from planar.object_config.models import (  # noqa: F401, PLC0415
-                 ObjectConfiguration,
-             )
-             from planar.workflows.models import (  # noqa: PLC0415
-                 LockedResource,  # noqa: F401
-                 Workflow,  # noqa: F401
-                 WorkflowEvent,  # noqa: F401
-                 WorkflowStep,  # noqa: F401
-             )
-         except ImportError as e:
-             raise RuntimeError(
-                 f"Failed to import system models for migration generation: {e}"
-             )
-
-         config_dict = config.get_section(config.config_ini_section, {})
-         url = config_dict["sqlalchemy.url"]
-         is_sqlite = url.startswith("sqlite://")
-         translate_map = sqlite_schema_translate_map if is_sqlite else {}
-         connectable = engine_from_config(
-             config_dict,
-             prefix="sqlalchemy.",
-             poolclass=pool.NullPool,
-             execution_options={
-                 # SQLite doesn't support schemas, so we need to translate the planar schema
-                 # name to None in order to ignore it.
-                 "schema_translate_map": translate_map,
-             },
-         )
-
-         with connectable.connect() as connection:
-             is_sqlite = connection.dialect.name == "sqlite"
-             if is_sqlite:
-                 connection.dialect.default_schema_name = PLANAR_SCHEMA
-
-             context.configure(
-                 connection=connection,
-                 target_metadata=target_metadata,
-                 # For SQLite, don't use schema since it's not supported
-                 version_table_schema=None if is_sqlite else PLANAR_SCHEMA,
-                 include_schemas=True,
-                 include_name=include_name,
-                 # SQLite doesn't support alter table, so we need to use render_as_batch
-                 # to create the tables in a single transaction. For other databases,
-                 # the batch op is no-op.
-                 # https://alembic.sqlalchemy.org/en/latest/batch.html#running-batch-migrations-for-sqlite-and-other-databases
-                 render_as_batch=True,
-             )
-
-             with context.begin_transaction():
-                 context.run_migrations()
+         # Development mode: run migrations asynchronously
+         asyncio.run(run_migrations_online_async())


  if context.is_offline_mode():
planar/db/alembic.ini CHANGED
@@ -70,7 +70,7 @@ version_path_separator = os
  # so it sometimes incorrectly thinks it needs to re-generate things (like indices) that already
  # exist in the database from a prior migration. Using postgres obviates that issue.
  # https://github.com/sqlalchemy/alembic/issues/555
- sqlalchemy.url = postgresql+psycopg2://postgres:postgres@localhost:5432/postgres
+ sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/postgres

  # [post_write_hooks]
  # This section defines scripts or Python functions that are run
@@ -2,7 +2,7 @@ from __future__ import annotations

  from typing import TYPE_CHECKING, Annotated, Literal

- from pydantic import BaseModel, Field, model_validator
+ from pydantic import BaseModel, ConfigDict, Field, model_validator

  from .local_directory import LocalDirectoryStorage
  from .s3 import S3Storage
@@ -15,6 +15,8 @@ class LocalDirectoryConfig(BaseModel):
      backend: Literal["localdir"]
      directory: str

+     model_config = ConfigDict(frozen=True)
+

  class S3Config(BaseModel):
      backend: Literal["s3"]
@@ -25,6 +27,8 @@ class S3Config(BaseModel):
      endpoint_url: str | None = None
      presigned_url_ttl: int = 3600

+     model_config = ConfigDict(frozen=True)
+

  class AzureBlobConfig(BaseModel):
      backend: Literal["azure_blob"]
@@ -39,6 +43,8 @@ class AzureBlobConfig(BaseModel):
      # Common settings
      sas_ttl: int = 3600  # SAS URL expiry time in seconds

+     model_config = ConfigDict(frozen=True)
+
      @model_validator(mode="after")
      def validate_auth_config(self):
          """Ensure exactly one valid authentication configuration."""
planar/human/models.py CHANGED
@@ -7,7 +7,6 @@ from pydantic import BaseModel
  from sqlmodel import JSON, Column, Field

  from planar.db import PlanarInternalBase
- from planar.modeling.field_helpers import JsonSchema
  from planar.modeling.mixins import TimestampMixin
  from planar.modeling.mixins.auditable import AuditableMixin
  from planar.modeling.mixins.uuid_primary_key import UUIDPrimaryKeyMixin
@@ -48,12 +47,12 @@ class HumanTask(
      workflow_name: str

      # Input data for context
-     input_schema: Optional[JsonSchema] = Field(default=None, sa_column=Column(JSON))
+     input_schema: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
      input_data: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
      message: Optional[str] = Field(default=None)

      # Schema for expected output
-     output_schema: JsonSchema = Field(sa_column=Column(JSON))
+     output_schema: dict[str, Any] = Field(sa_column=Column(JSON))
      output_data: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JSON))

      # Suggested data for the form (optional)
@@ -18,6 +18,7 @@ def _in_event_loop_task() -> bool:
          return (
              threading.main_thread() == threading.current_thread()
              and asyncio.get_running_loop() is not None
+             and asyncio.current_task() is not None
          )
      except RuntimeError:
          return False