planar 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- planar/ai/agent.py +2 -1
- planar/ai/agent_base.py +24 -5
- planar/ai/state.py +17 -0
- planar/ai/test_agent_tool_step_display.py +1 -1
- planar/app.py +5 -0
- planar/data/connection.py +108 -0
- planar/data/dataset.py +11 -104
- planar/data/test_dataset.py +45 -41
- planar/data/utils.py +89 -0
- planar/db/alembic/env.py +25 -1
- planar/files/storage/azure_blob.py +1 -1
- planar/registry_items.py +2 -0
- planar/routers/dataset_router.py +213 -0
- planar/routers/models.py +1 -0
- planar/routers/test_dataset_router.py +429 -0
- planar/routers/test_workflow_router.py +26 -1
- planar/routers/workflow.py +2 -0
- planar/security/authorization.py +31 -3
- planar/security/default_policies.cedar +25 -0
- planar/testing/fixtures.py +30 -0
- planar/testing/planar_test_client.py +1 -1
- planar/workflows/decorators.py +2 -1
- planar/workflows/wrappers.py +1 -0
- {planar-0.9.3.dist-info → planar-0.10.0.dist-info}/METADATA +1 -1
- {planar-0.9.3.dist-info → planar-0.10.0.dist-info}/RECORD +27 -22
- {planar-0.9.3.dist-info → planar-0.10.0.dist-info}/WHEEL +1 -1
- {planar-0.9.3.dist-info → planar-0.10.0.dist-info}/entry_points.txt +0 -0
planar/data/utils.py
ADDED
```diff
@@ -0,0 +1,89 @@
+import asyncio
+from typing import TypedDict
+
+import ibis.expr.datatypes as dt
+from ibis.common.exceptions import TableNotFound
+
+from planar.data.connection import _get_connection
+from planar.data.dataset import PlanarDataset
+from planar.data.exceptions import DatasetNotFoundError
+from planar.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+# TODO: consider connection pooling or memoize the connection
+
+
+async def list_datasets(limit: int = 100, offset: int = 0) -> list[PlanarDataset]:
+    conn = await _get_connection()
+    tables = await asyncio.to_thread(conn.list_tables)
+    return [PlanarDataset(name=table) for table in tables]
+
+
+async def list_schemas() -> list[str]:
+    METADATA_SCHEMAS = [
+        "information_schema",
+        # FIXME: why is list_databases returning pg_catalog
+        # if the ducklake catalog is sqlite?
+        "pg_catalog",
+    ]
+
+    conn = await _get_connection()
+
+    # in ibis, "databases" are schemas in the traditional sense
+    # e.g. psql: schema == ibis: database
+    # https://ibis-project.org/concepts/backend-table-hierarchy
+    schemas = await asyncio.to_thread(conn.list_databases)
+
+    return [schema for schema in schemas if schema not in METADATA_SCHEMAS]
+
+
+async def get_dataset(dataset_name: str, schema_name: str = "main") -> PlanarDataset:
+    # TODO: add schema_name as a parameter
+
+    dataset = PlanarDataset(name=dataset_name)
+
+    if not await dataset.exists():
+        raise DatasetNotFoundError(f"Dataset {dataset_name} not found")
+
+    return dataset
+
+
+async def get_dataset_row_count(dataset_name: str) -> int:
+    conn = await _get_connection()
+
+    try:
+        value = await asyncio.to_thread(
+            lambda conn, dataset_name: conn.table(dataset_name).count().to_polars(),
+            conn,
+            dataset_name,
+        )
+
+        assert isinstance(value, int), "Scalar must be an integer"
+
+        return value
+    except TableNotFound:
+        raise  # re-raise the exception and allow the caller to handle it
+
+
+class DatasetMetadata(TypedDict):
+    schema: dict[str, dt.DataType]
+    row_count: int
+
+
+async def get_dataset_metadata(
+    dataset_name: str, schema_name: str
+) -> DatasetMetadata | None:
+    conn = await _get_connection()
+
+    try:
+        schema, row_count = await asyncio.gather(
+            asyncio.to_thread(conn.get_schema, dataset_name, database=schema_name),
+            get_dataset_row_count(dataset_name),
+        )
+
+        return DatasetMetadata(schema=schema.fields, row_count=row_count)
+
+    except TableNotFound:
+        return None
```
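A minimal usage sketch for the new helpers (not part of the package diff): it assumes a running Planar app with a configured data connection, and only calls the functions shown in the hunk above.

```python
# Sketch: enumerate datasets and report their metadata using the helpers
# added in planar/data/utils.py. Assumes a Planar app with a configured
# data connection; behavior mirrors the hunk above.
import asyncio

from planar.data.utils import get_dataset_metadata, list_datasets, list_schemas


async def describe_datasets() -> None:
    print("schemas:", await list_schemas())

    for dataset in await list_datasets(limit=10):
        metadata = await get_dataset_metadata(dataset.name, "main")
        if metadata is None:
            # get_dataset_metadata returns None when the table disappears
            # between listing and inspection
            continue
        print(dataset.name, metadata["row_count"], list(metadata["schema"]))


if __name__ == "__main__":
    asyncio.run(describe_datasets())
```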
planar/db/alembic/env.py
CHANGED
```diff
@@ -1,5 +1,7 @@
+from functools import wraps
 from logging.config import fileConfig
 
+import alembic.ddl.base as alembic_base
 from alembic import context
 from sqlalchemy import Connection, engine_from_config, pool
 
@@ -48,6 +50,28 @@ def include_name(name, type_, _):
     return True
 
 
+sqlite_schema_translate_map = {PLANAR_SCHEMA: None}
+
+
+def schema_translate_wrapper(f):
+    @wraps(f)
+    def format_table_name_with_schema(compiler, name, schema):
+        # when on sqlite, we need to translate the schema to None
+        is_sqlite = compiler.dialect.name == "sqlite"
+        if is_sqlite:
+            translated_schema = sqlite_schema_translate_map.get(schema, schema)
+        else:
+            translated_schema = schema
+        return f(compiler, name, translated_schema)
+
+    return format_table_name_with_schema
+
+
+alembic_base.format_table_name = schema_translate_wrapper(
+    alembic_base.format_table_name
+)
+
+
 def run_migrations_online() -> None:
     """Run migrations in 'online' mode.
 
@@ -102,7 +126,7 @@ def run_migrations_online() -> None:
     config_dict = config.get_section(config.config_ini_section, {})
     url = config_dict["sqlalchemy.url"]
     is_sqlite = url.startswith("sqlite://")
-    translate_map =
+    translate_map = sqlite_schema_translate_map if is_sqlite else {}
     connectable = engine_from_config(
         config_dict,
         prefix="sqlalchemy.",
```
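For context, the translate map feeds SQLAlchemy's schema-translation feature, while the `format_table_name` wrapper covers Alembic's own DDL rendering. A standalone sketch of the SQLAlchemy side is below; the `PLANAR_SCHEMA` value and the demo table are illustrative assumptions, not taken from the package.

```python
# Sketch of schema translation on SQLite: tables declared under a named
# schema are rendered without a schema when the map sends it to None.
from sqlalchemy import Column, Integer, MetaData, Table, create_engine

PLANAR_SCHEMA = "planar"  # stand-in for the constant used in env.py
metadata = MetaData(schema=PLANAR_SCHEMA)
demo = Table("demo", metadata, Column("id", Integer, primary_key=True))

engine = create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    # Map the "planar" schema to None so SQLite sees plain table names.
    conn = conn.execution_options(schema_translate_map={PLANAR_SCHEMA: None})
    metadata.create_all(conn)
    conn.execute(demo.insert().values(id=1))
    print(conn.execute(demo.select()).all())  # [(1,)]
```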
planar/files/storage/azure_blob.py
CHANGED
```diff
@@ -278,7 +278,7 @@ class AzureBlobStorage(Storage):
 
         elif self.auth_method.name == "AZURE_AD":
             # Generate a User Delegation SAS signed with a user delegation key
-            start_time = datetime.
+            start_time = datetime.now(UTC)
             user_delegation_key = await self.client.get_user_delegation_key(
                 key_start_time=start_time, key_expiry_time=expiry_time
             )
```
planar/registry_items.py
CHANGED
```diff
@@ -47,6 +47,7 @@ class RegisteredWorkflow:
     input_schema: dict[str, Any]
     output_schema: dict[str, Any]
     pydantic_model: Type[BaseModel]
+    is_interactive: bool
 
     @staticmethod
     def from_workflow(workflow: "WorkflowWrapper") -> "RegisteredWorkflow":
@@ -63,4 +64,5 @@ class RegisteredWorkflow:
                 workflow.original_fn
             ),
             pydantic_model=create_pydantic_model_for_workflow(workflow),
+            is_interactive=workflow.is_interactive,
         )
```
planar/routers/dataset_router.py
ADDED
```diff
@@ -0,0 +1,213 @@
+import io
+from typing import AsyncGenerator
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+from fastapi import APIRouter, HTTPException, Query
+from fastapi.responses import StreamingResponse
+from ibis.common.exceptions import TableNotFound
+from pydantic import BaseModel
+
+from planar.data.exceptions import DatasetNotFoundError
+from planar.data.utils import (
+    get_dataset,
+    get_dataset_metadata,
+    list_datasets,
+    list_schemas,
+)
+from planar.logging import get_logger
+from planar.security.authorization import (
+    DatasetAction,
+    DatasetResource,
+    validate_authorization_for,
+)
+
+logger = get_logger(__name__)
+
+
+class DatasetMetadata(BaseModel):
+    name: str
+    table_schema: dict
+    row_count: int
+
+
+def create_dataset_router() -> APIRouter:
+    router = APIRouter(tags=["Planar Datasets"])
+
+    @router.get("/schemas", response_model=list[str])
+    async def get_schemas():
+        validate_authorization_for(
+            DatasetResource(), DatasetAction.DATASET_LIST_SCHEMAS
+        )
+        schemas = await list_schemas()
+        return schemas
+
+    @router.get("/metadata", response_model=list[DatasetMetadata])
+    async def list_planar_datasets(
+        limit: int = Query(100, ge=1, le=1000),
+        offset: int = Query(0, ge=0),
+        schema_name: str = Query("main"),
+    ):
+        validate_authorization_for(DatasetResource(), DatasetAction.DATASET_LIST)
+        datasets = await list_datasets(limit, offset)
+
+        response = []
+        for dataset in datasets:
+            metadata = await get_dataset_metadata(dataset.name, schema_name)
+
+            if not metadata:
+                continue
+
+            schema = metadata["schema"]
+            row_count = metadata["row_count"]
+
+            response.append(
+                DatasetMetadata(
+                    name=dataset.name,
+                    row_count=row_count,
+                    table_schema={
+                        field_name: str(field_type)
+                        for field_name, field_type in schema.items()
+                    },
+                )
+            )
+
+        return response
+
+    @router.get("/metadata/{dataset_name}", response_model=DatasetMetadata)
+    async def get_planar_dataset(dataset_name: str, schema_name: str = "main"):
+        validate_authorization_for(
+            DatasetResource(dataset_name=dataset_name),
+            DatasetAction.DATASET_VIEW_DETAILS,
+        )
+        try:
+            metadata = await get_dataset_metadata(dataset_name, schema_name)
+
+            if not metadata:
+                raise HTTPException(
+                    status_code=404, detail=f"Dataset {dataset_name} not found"
+                )
+
+            schema = metadata["schema"]
+            row_count = metadata["row_count"]
+
+            return DatasetMetadata(
+                name=dataset_name,
+                row_count=row_count,
+                table_schema={
+                    field_name: str(field_type)
+                    for field_name, field_type in schema.items()
+                },
+            )
+        except (DatasetNotFoundError, TableNotFound):
+            raise HTTPException(
+                status_code=404, detail=f"Dataset {dataset_name} not found"
+            )
+
+    @router.get(
+        "/content/{dataset_name}/arrow-stream", response_class=StreamingResponse
+    )
+    async def stream_dataset_content(
+        dataset_name: str,
+        batch_size: int = Query(100, ge=1, le=1000),
+        limit: int | None = Query(None, ge=1),
+    ):
+        validate_authorization_for(
+            DatasetResource(dataset_name=dataset_name),
+            DatasetAction.DATASET_STREAM_CONTENT,
+        )
+        try:
+            dataset = await get_dataset(dataset_name)
+
+            # Apply limit parameter if specified
+            table = await dataset.read(limit=limit)
+
+            schema = table.schema().to_pyarrow()
+
+            async def stream_content() -> AsyncGenerator[bytes, None]:
+                sink = io.BytesIO()
+
+                try:
+                    with pa.ipc.new_stream(sink, schema) as writer:
+                        yield sink.getvalue()  # yield the schema
+
+                        batch_count = 0
+                        for batch in table.to_pyarrow_batches(chunk_size=batch_size):
+                            # reset the sink to only stream
+                            # the current batch
+                            # we don't want to stream the schema or previous
+                            # batches again
+                            sink.seek(0)
+                            sink.truncate(0)
+
+                            writer.write_batch(batch)
+                            yield sink.getvalue()
+                            batch_count += 1
+
+                        # For empty datasets, ensure we have a complete stream
+                        if batch_count == 0:
+                            # Write an empty batch to ensure valid Arrow stream format
+                            empty_batch = pa.RecordBatch.from_arrays(
+                                [pa.array([], type=field.type) for field in schema],
+                                schema=schema,
+                            )
+                            sink.seek(0)
+                            sink.truncate(0)
+                            writer.write_batch(empty_batch)
+                            yield sink.getvalue()
+                finally:
+                    # Explicit BytesIO cleanup for memory safety
+                    sink.close()
+
+            return StreamingResponse(
+                stream_content(),
+                media_type="application/vnd.apache.arrow.stream",
+                headers={
+                    "Content-Disposition": f"attachment; filename={dataset_name}.arrow",
+                    "X-Batch-Size": str(batch_size),
+                    "X-Row-Limit": str(limit) if limit else "unlimited",
+                },
+            )
+        except (DatasetNotFoundError, TableNotFound):
+            raise HTTPException(
+                status_code=404, detail=f"Dataset {dataset_name} not found"
+            )
+
+    @router.get("/content/{dataset_name}/download")
+    async def download_dataset(dataset_name: str, schema_name: str = "main"):
+        validate_authorization_for(
+            DatasetResource(dataset_name=dataset_name),
+            DatasetAction.DATASET_DOWNLOAD,
+        )
+        try:
+            arrow_buffer = pa.BufferOutputStream()
+            dataset = await get_dataset(dataset_name, schema_name)
+
+            pyarrow_table = await dataset.to_pyarrow()
+
+            pq.write_table(pyarrow_table, arrow_buffer)
+
+            if arrow_buffer.tell() == 0:
+                logger.warning(
+                    "Dataset is empty",
+                    dataset_name=dataset_name,
+                    schema_name=schema_name,
+                )
+
+            buffer = arrow_buffer.getvalue()
+            parquet_bytes = buffer.to_pybytes()
+            bytes_io = io.BytesIO(parquet_bytes)
+
+            return StreamingResponse(
+                bytes_io,
+                media_type="application/x-parquet",
+                headers={
+                    "Content-Disposition": f"attachment; filename={dataset_name}.parquet"
+                },
+            )
+        except (DatasetNotFoundError, TableNotFound):
+            raise HTTPException(
+                status_code=404, detail=f"Dataset {dataset_name} not found"
+            )
+
+    return router
```
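A sketch of how a client might consume the new arrow-stream endpoint. The host, port, and router prefix are assumptions (the mount point is configured in planar/app.py, which is not shown in this diff); the Arrow handling itself uses only documented pyarrow APIs.

```python
# Sketch: download an Arrow IPC stream from the dataset router and load it
# into a pyarrow Table. The URL below is hypothetical.
import httpx
import pyarrow.ipc as ipc

BASE_URL = "http://localhost:8000"  # assumed host/port
ENDPOINT = "/content/my_dataset/arrow-stream"  # prefix assumed; see planar/app.py

resp = httpx.get(f"{BASE_URL}{ENDPOINT}", params={"batch_size": 500})
resp.raise_for_status()

# The body is a complete Arrow IPC stream: the schema message first, then one
# message per record batch, matching the server's yields above.
reader = ipc.open_stream(resp.content)
table = reader.read_all()
print(table.num_rows, table.schema)
```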