planar 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- planar/_version.py +1 -1
- planar/ai/agent.py +67 -30
- planar/ai/pydantic_ai.py +570 -0
- planar/ai/pydantic_ai_agent.py +329 -0
- planar/ai/test_agent.py +2 -2
- planar/app.py +64 -20
- planar/cli.py +39 -27
- planar/config.py +45 -36
- planar/db/db.py +2 -1
- planar/files/storage/azure_blob.py +343 -0
- planar/files/storage/base.py +7 -0
- planar/files/storage/config.py +70 -7
- planar/files/storage/s3.py +6 -6
- planar/files/storage/test_azure_blob.py +435 -0
- planar/logging/formatter.py +17 -4
- planar/logging/test_formatter.py +327 -0
- planar/registry_items.py +2 -1
- planar/routers/agents_router.py +3 -1
- planar/routers/files.py +11 -2
- planar/routers/models.py +14 -1
- planar/routers/test_files_router.py +49 -0
- planar/routers/test_routes_security.py +5 -7
- planar/routers/test_workflow_router.py +270 -3
- planar/routers/workflow.py +95 -36
- planar/rules/models.py +36 -39
- planar/rules/test_data/account_dormancy_management.json +223 -0
- planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
- planar/rules/test_data/applicant_risk_assessment.json +435 -0
- planar/rules/test_data/booking_fraud_detection.json +407 -0
- planar/rules/test_data/cellular_data_rollover_system.json +258 -0
- planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
- planar/rules/test_data/customer_lifetime_value.json +143 -0
- planar/rules/test_data/import_duties_calculator.json +289 -0
- planar/rules/test_data/insurance_prior_authorization.json +443 -0
- planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
- planar/rules/test_data/order_consolidation_system.json +375 -0
- planar/rules/test_data/portfolio_risk_monitor.json +471 -0
- planar/rules/test_data/supply_chain_risk.json +253 -0
- planar/rules/test_data/warehouse_cross_docking.json +237 -0
- planar/rules/test_rules.py +750 -6
- planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
- planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
- planar/scaffold_templates/pyproject.toml.j2 +1 -1
- planar/security/auth_context.py +21 -0
- planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
- planar/security/authorization.py +9 -15
- planar/security/tests/test_auth_middleware.py +162 -0
- planar/sse/proxy.py +4 -9
- planar/test_app.py +92 -1
- planar/test_cli.py +81 -59
- planar/test_config.py +17 -14
- planar/testing/fixtures.py +325 -0
- planar/testing/planar_test_client.py +5 -2
- planar/utils.py +41 -1
- planar/workflows/execution.py +1 -1
- planar/workflows/orchestrator.py +5 -0
- planar/workflows/serialization.py +12 -6
- planar/workflows/step_core.py +3 -1
- planar/workflows/test_serialization.py +9 -1
- {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/METADATA +30 -5
- planar-0.7.0.dist-info/RECORD +169 -0
- planar/.__init__.py.un~ +0 -0
- planar/._version.py.un~ +0 -0
- planar/.app.py.un~ +0 -0
- planar/.cli.py.un~ +0 -0
- planar/.config.py.un~ +0 -0
- planar/.context.py.un~ +0 -0
- planar/.db.py.un~ +0 -0
- planar/.di.py.un~ +0 -0
- planar/.engine.py.un~ +0 -0
- planar/.files.py.un~ +0 -0
- planar/.log_context.py.un~ +0 -0
- planar/.log_metadata.py.un~ +0 -0
- planar/.logging.py.un~ +0 -0
- planar/.object_registry.py.un~ +0 -0
- planar/.otel.py.un~ +0 -0
- planar/.server.py.un~ +0 -0
- planar/.session.py.un~ +0 -0
- planar/.sqlalchemy.py.un~ +0 -0
- planar/.task_local.py.un~ +0 -0
- planar/.test_app.py.un~ +0 -0
- planar/.test_config.py.un~ +0 -0
- planar/.test_object_config.py.un~ +0 -0
- planar/.test_sqlalchemy.py.un~ +0 -0
- planar/.test_utils.py.un~ +0 -0
- planar/.util.py.un~ +0 -0
- planar/.utils.py.un~ +0 -0
- planar/ai/.__init__.py.un~ +0 -0
- planar/ai/._models.py.un~ +0 -0
- planar/ai/.agent.py.un~ +0 -0
- planar/ai/.agent_utils.py.un~ +0 -0
- planar/ai/.events.py.un~ +0 -0
- planar/ai/.files.py.un~ +0 -0
- planar/ai/.models.py.un~ +0 -0
- planar/ai/.providers.py.un~ +0 -0
- planar/ai/.pydantic_ai.py.un~ +0 -0
- planar/ai/.pydantic_ai_agent.py.un~ +0 -0
- planar/ai/.pydantic_ai_provider.py.un~ +0 -0
- planar/ai/.step.py.un~ +0 -0
- planar/ai/.test_agent.py.un~ +0 -0
- planar/ai/.test_agent_serialization.py.un~ +0 -0
- planar/ai/.test_providers.py.un~ +0 -0
- planar/ai/.utils.py.un~ +0 -0
- planar/db/.db.py.un~ +0 -0
- planar/files/.config.py.un~ +0 -0
- planar/files/.local.py.un~ +0 -0
- planar/files/.local_filesystem.py.un~ +0 -0
- planar/files/.model.py.un~ +0 -0
- planar/files/.models.py.un~ +0 -0
- planar/files/.s3.py.un~ +0 -0
- planar/files/.storage.py.un~ +0 -0
- planar/files/.test_files.py.un~ +0 -0
- planar/files/storage/.__init__.py.un~ +0 -0
- planar/files/storage/.base.py.un~ +0 -0
- planar/files/storage/.config.py.un~ +0 -0
- planar/files/storage/.context.py.un~ +0 -0
- planar/files/storage/.local_directory.py.un~ +0 -0
- planar/files/storage/.test_local_directory.py.un~ +0 -0
- planar/files/storage/.test_s3.py.un~ +0 -0
- planar/human/.human.py.un~ +0 -0
- planar/human/.test_human.py.un~ +0 -0
- planar/logging/.__init__.py.un~ +0 -0
- planar/logging/.attributes.py.un~ +0 -0
- planar/logging/.formatter.py.un~ +0 -0
- planar/logging/.logger.py.un~ +0 -0
- planar/logging/.otel.py.un~ +0 -0
- planar/logging/.tracer.py.un~ +0 -0
- planar/modeling/.mixin.py.un~ +0 -0
- planar/modeling/.storage.py.un~ +0 -0
- planar/modeling/orm/.planar_base_model.py.un~ +0 -0
- planar/object_config/.object_config.py.un~ +0 -0
- planar/routers/.__init__.py.un~ +0 -0
- planar/routers/.agents_router.py.un~ +0 -0
- planar/routers/.crud.py.un~ +0 -0
- planar/routers/.decision.py.un~ +0 -0
- planar/routers/.event.py.un~ +0 -0
- planar/routers/.file_attachment.py.un~ +0 -0
- planar/routers/.files.py.un~ +0 -0
- planar/routers/.files_router.py.un~ +0 -0
- planar/routers/.human.py.un~ +0 -0
- planar/routers/.info.py.un~ +0 -0
- planar/routers/.models.py.un~ +0 -0
- planar/routers/.object_config_router.py.un~ +0 -0
- planar/routers/.rule.py.un~ +0 -0
- planar/routers/.test_object_config_router.py.un~ +0 -0
- planar/routers/.test_workflow_router.py.un~ +0 -0
- planar/routers/.workflow.py.un~ +0 -0
- planar/rules/.decorator.py.un~ +0 -0
- planar/rules/.runner.py.un~ +0 -0
- planar/rules/.test_rules.py.un~ +0 -0
- planar/security/.jwt_middleware.py.un~ +0 -0
- planar/sse/.constants.py.un~ +0 -0
- planar/sse/.example.html.un~ +0 -0
- planar/sse/.hub.py.un~ +0 -0
- planar/sse/.model.py.un~ +0 -0
- planar/sse/.proxy.py.un~ +0 -0
- planar/testing/.client.py.un~ +0 -0
- planar/testing/.memory_storage.py.un~ +0 -0
- planar/testing/.planar_test_client.py.un~ +0 -0
- planar/testing/.predictable_tracer.py.un~ +0 -0
- planar/testing/.synchronizable_tracer.py.un~ +0 -0
- planar/testing/.test_memory_storage.py.un~ +0 -0
- planar/testing/.workflow_observer.py.un~ +0 -0
- planar/workflows/.__init__.py.un~ +0 -0
- planar/workflows/.builtin_steps.py.un~ +0 -0
- planar/workflows/.concurrency_tracing.py.un~ +0 -0
- planar/workflows/.context.py.un~ +0 -0
- planar/workflows/.contrib.py.un~ +0 -0
- planar/workflows/.decorators.py.un~ +0 -0
- planar/workflows/.durable_test.py.un~ +0 -0
- planar/workflows/.errors.py.un~ +0 -0
- planar/workflows/.events.py.un~ +0 -0
- planar/workflows/.exceptions.py.un~ +0 -0
- planar/workflows/.execution.py.un~ +0 -0
- planar/workflows/.human.py.un~ +0 -0
- planar/workflows/.lock.py.un~ +0 -0
- planar/workflows/.misc.py.un~ +0 -0
- planar/workflows/.model.py.un~ +0 -0
- planar/workflows/.models.py.un~ +0 -0
- planar/workflows/.notifications.py.un~ +0 -0
- planar/workflows/.orchestrator.py.un~ +0 -0
- planar/workflows/.runtime.py.un~ +0 -0
- planar/workflows/.serialization.py.un~ +0 -0
- planar/workflows/.step.py.un~ +0 -0
- planar/workflows/.step_core.py.un~ +0 -0
- planar/workflows/.sub_workflow_runner.py.un~ +0 -0
- planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
- planar/workflows/.test_concurrency.py.un~ +0 -0
- planar/workflows/.test_concurrency_detection.py.un~ +0 -0
- planar/workflows/.test_human.py.un~ +0 -0
- planar/workflows/.test_lock_timeout.py.un~ +0 -0
- planar/workflows/.test_orchestrator.py.un~ +0 -0
- planar/workflows/.test_race_conditions.py.un~ +0 -0
- planar/workflows/.test_serialization.py.un~ +0 -0
- planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
- planar/workflows/.test_workflow.py.un~ +0 -0
- planar/workflows/.tracing.py.un~ +0 -0
- planar/workflows/.types.py.un~ +0 -0
- planar/workflows/.util.py.un~ +0 -0
- planar/workflows/.utils.py.un~ +0 -0
- planar/workflows/.workflow.py.un~ +0 -0
- planar/workflows/.workflow_wrapper.py.un~ +0 -0
- planar/workflows/.wrappers.py.un~ +0 -0
- planar-0.5.0.dist-info/RECORD +0 -289
- {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/WHEEL +0 -0
- {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/entry_points.txt +0 -0
planar/config.py
CHANGED
@@ -5,13 +5,14 @@ import os
|
|
5
5
|
import sys
|
6
6
|
from enum import Enum
|
7
7
|
from pathlib import Path
|
8
|
-
from typing import Annotated, Any, Dict, Literal
|
8
|
+
from typing import Annotated, Any, Dict, Literal
|
9
9
|
|
10
10
|
import boto3
|
11
11
|
import yaml
|
12
12
|
from dotenv import load_dotenv
|
13
13
|
from pydantic import (
|
14
14
|
BaseModel,
|
15
|
+
ConfigDict,
|
15
16
|
Field,
|
16
17
|
HttpUrl,
|
17
18
|
SecretStr,
|
@@ -45,8 +46,8 @@ class LogLevel(str, Enum):
|
|
45
46
|
|
46
47
|
class LoggerConfig(BaseModel):
|
47
48
|
level: LogLevel = LogLevel.INFO
|
48
|
-
propagate:
|
49
|
-
file:
|
49
|
+
propagate: bool | None = False
|
50
|
+
file: str | None = None
|
50
51
|
|
51
52
|
|
52
53
|
class SQLiteConfig(BaseModel):
|
@@ -66,11 +67,11 @@ class PostgreSQLConfig(BaseModel):
|
|
66
67
|
driver: Literal["postgresql", "postgresql+asyncpg"] = (
|
67
68
|
"postgresql+asyncpg" # Allow async PostgreSQL
|
68
69
|
)
|
69
|
-
host:
|
70
|
-
port:
|
71
|
-
user:
|
72
|
-
password:
|
73
|
-
db:
|
70
|
+
host: str | None = None
|
71
|
+
port: int | None = None
|
72
|
+
user: str | None = None
|
73
|
+
password: str | None = None
|
74
|
+
db: str | None
|
74
75
|
|
75
76
|
def connection_url(self) -> URL:
|
76
77
|
driver = self.driver
|
@@ -92,15 +93,15 @@ class OpenAIConfig(BaseModel):
|
|
92
93
|
"""Configuration for OpenAI provider."""
|
93
94
|
|
94
95
|
api_key: SecretStr
|
95
|
-
base_url:
|
96
|
-
organization:
|
96
|
+
base_url: str | None = None
|
97
|
+
organization: str | None = None
|
97
98
|
|
98
99
|
|
99
100
|
class AnthropicConfig(BaseModel):
|
100
101
|
"""Configuration for Anthropic provider."""
|
101
102
|
|
102
103
|
api_key: SecretStr
|
103
|
-
base_url:
|
104
|
+
base_url: str | None = None
|
104
105
|
|
105
106
|
|
106
107
|
class GeminiConfig(BaseModel):
|
@@ -112,9 +113,9 @@ class GeminiConfig(BaseModel):
|
|
112
113
|
class AIProvidersConfig(BaseModel):
|
113
114
|
"""Configuration for AI providers."""
|
114
115
|
|
115
|
-
openai:
|
116
|
-
anthropic:
|
117
|
-
gemini:
|
116
|
+
openai: OpenAIConfig | None = None
|
117
|
+
anthropic: AnthropicConfig | None = None
|
118
|
+
gemini: GeminiConfig | None = None
|
118
119
|
|
119
120
|
|
120
121
|
DatabaseConfig = Annotated[
|
@@ -124,7 +125,7 @@ DatabaseConfig = Annotated[
|
|
124
125
|
|
125
126
|
class AppConfig(BaseModel):
|
126
127
|
db_connection: str
|
127
|
-
max_db_conflict_retries:
|
128
|
+
max_db_conflict_retries: int | None = None
|
128
129
|
|
129
130
|
|
130
131
|
def default_storage_config() -> StorageConfig:
|
@@ -162,31 +163,27 @@ PROD_CORS_CONFIG = CorsConfig(
|
|
162
163
|
|
163
164
|
|
164
165
|
class JWTConfig(BaseModel):
|
165
|
-
enabled: bool = False
|
166
166
|
client_id: str | None = None
|
167
167
|
org_id: str | None = None
|
168
168
|
additional_exclusion_paths: list[str] | None = Field(default_factory=list)
|
169
169
|
|
170
170
|
@model_validator(mode="after")
|
171
171
|
def validate_client_id(cls, instance):
|
172
|
-
if instance.
|
173
|
-
raise ValueError("client_id
|
174
|
-
if instance.client_id and not instance.enabled:
|
175
|
-
raise ValueError(
|
176
|
-
"You cannot specify a client_id without enabling JWT - did you mean to set enabled=True?"
|
177
|
-
)
|
172
|
+
if not instance.client_id or not instance.org_id:
|
173
|
+
raise ValueError("Both client_id and org_id required to enable JWT")
|
178
174
|
return instance
|
179
175
|
|
180
176
|
|
181
|
-
|
177
|
+
# Coplane ORG JWT config
|
182
178
|
JWT_COPLANE_CONFIG = JWTConfig(
|
183
|
-
|
179
|
+
client_id="client_01JSJHJPKG09TMSK6NHJP0S180",
|
180
|
+
org_id="org_01JY4QP57Y7H4EQ7HT3BGN7TNK",
|
184
181
|
)
|
185
182
|
|
186
183
|
|
187
184
|
class OtelConfig(BaseModel):
|
188
185
|
collector_endpoint: HttpUrl
|
189
|
-
resource_attributes:
|
186
|
+
resource_attributes: dict[str, str] | None = None
|
190
187
|
|
191
188
|
|
192
189
|
def install_otel_provider(otel_config: OtelConfig):
|
@@ -206,19 +203,31 @@ class AuthzConfig(BaseModel):
|
|
206
203
|
policy_file: str | None = None
|
207
204
|
|
208
205
|
|
206
|
+
class ServiceTokenConfig(BaseModel):
|
207
|
+
token: str | None = Field(None, min_length=1)
|
208
|
+
|
209
|
+
|
210
|
+
class SecurityConfig(BaseModel):
|
211
|
+
cors: CorsConfig = PROD_CORS_CONFIG
|
212
|
+
jwt: JWTConfig | None = None
|
213
|
+
service_token: ServiceTokenConfig | None = None
|
214
|
+
authz: AuthzConfig | None = None
|
215
|
+
|
216
|
+
|
209
217
|
class PlanarConfig(BaseModel):
|
210
218
|
db_connections: Dict[str, DatabaseConfig | str]
|
211
219
|
app: AppConfig
|
212
|
-
ai_providers:
|
213
|
-
storage:
|
220
|
+
ai_providers: AIProvidersConfig | None = None
|
221
|
+
storage: StorageConfig | None = default_storage_config()
|
214
222
|
sse_hub: str | bool = False
|
215
|
-
cors: CorsConfig = PROD_CORS_CONFIG
|
216
223
|
environment: Environment = Environment.DEV
|
217
|
-
|
218
|
-
logging:
|
224
|
+
security: SecurityConfig = SecurityConfig()
|
225
|
+
logging: dict[str, LoggerConfig] | None = None
|
219
226
|
use_alembic: bool | None = True
|
220
|
-
otel:
|
221
|
-
|
227
|
+
otel: OtelConfig | None = None
|
228
|
+
|
229
|
+
# forbid extra keys in the config to prevent accidental misconfiguration
|
230
|
+
model_config = ConfigDict(extra="forbid")
|
222
231
|
|
223
232
|
@model_validator(mode="after")
|
224
233
|
def validate_db_connection_reference(cls, instance):
|
@@ -467,14 +476,14 @@ def load_environment_aware_config[ConfigClass]() -> PlanarConfig:
|
|
467
476
|
|
468
477
|
if env == "dev":
|
469
478
|
base_config = sqlite_config(db_path="planar_dev.db")
|
470
|
-
base_config.
|
479
|
+
base_config.security = SecurityConfig(cors=LOCAL_CORS_CONFIG)
|
471
480
|
base_config.environment = Environment.DEV
|
472
|
-
base_config.jwt = JWT_DISABLED_CONFIG
|
473
481
|
else:
|
474
482
|
base_config = sqlite_config(db_path="planar.db")
|
475
|
-
base_config.cors = PROD_CORS_CONFIG
|
476
483
|
base_config.environment = Environment.PROD
|
477
|
-
base_config.
|
484
|
+
base_config.security = SecurityConfig(
|
485
|
+
cors=PROD_CORS_CONFIG, jwt=JWT_COPLANE_CONFIG
|
486
|
+
)
|
478
487
|
|
479
488
|
# Convert base config to dict for merging
|
480
489
|
# Use by_alias=False to work with Python field names before validation
|
planar/db/db.py
CHANGED
@@ -176,7 +176,8 @@ class DatabaseManager:
|
|
176
176
|
|
177
177
|
def _create_sqlite_engine(self, url: URL) -> AsyncEngine:
|
178
178
|
# in practice this high timeout is only use
|
179
|
-
timeout = int(str(url.query.get("timeout",
|
179
|
+
timeout = int(str(url.query.get("timeout", 60)))
|
180
|
+
logger.info("Setting up SQLite engine with timeout", timeout=timeout)
|
180
181
|
|
181
182
|
engine = create_async_engine(
|
182
183
|
url,
|
@@ -0,0 +1,343 @@
|
|
1
|
+
import uuid
|
2
|
+
from datetime import UTC, datetime, timedelta
|
3
|
+
from enum import Enum
|
4
|
+
from typing import AsyncGenerator, override
|
5
|
+
from urllib.parse import urlparse
|
6
|
+
|
7
|
+
from azure.core.exceptions import ResourceNotFoundError
|
8
|
+
from azure.storage.blob import BlobSasPermissions, generate_blob_sas
|
9
|
+
from azure.storage.blob.aio import BlobServiceClient
|
10
|
+
|
11
|
+
from planar.logging import get_logger
|
12
|
+
|
13
|
+
from .base import Storage
|
14
|
+
|
15
|
+
logger = get_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class AzureAuthMethod(Enum):
|
19
|
+
CONNECTION_STRING = "connection_string"
|
20
|
+
ACCOUNT_KEY = "account_key"
|
21
|
+
AZURE_AD = "azure_ad"
|
22
|
+
|
23
|
+
|
24
|
+
class AzureBlobStorage(Storage):
|
25
|
+
"""Stores files and mime types in Azure Blob Storage using the async SDK."""
|
26
|
+
|
27
|
+
def __init__(
|
28
|
+
self,
|
29
|
+
container_name: str,
|
30
|
+
connection_string: str | None = None,
|
31
|
+
account_url: str | None = None,
|
32
|
+
use_azure_ad: bool = False,
|
33
|
+
account_key: str | None = None,
|
34
|
+
sas_ttl: int = 3600,
|
35
|
+
):
|
36
|
+
"""
|
37
|
+
Initializes AzureBlobStorage.
|
38
|
+
|
39
|
+
Args:
|
40
|
+
container_name: The name of the Azure Storage container.
|
41
|
+
connection_string: Full connection string (includes all credentials).
|
42
|
+
account_url: Storage account URL (e.g., 'https://<account>.blob.core.windows.net').
|
43
|
+
use_azure_ad: Whether to use DefaultAzureCredential for Azure AD auth.
|
44
|
+
account_key: Storage account key (used with account_url).
|
45
|
+
sas_ttl: Time in seconds for which SAS URLs are valid.
|
46
|
+
"""
|
47
|
+
# Import Azure dependencies when the class is instantiated
|
48
|
+
try:
|
49
|
+
from azure.storage.blob.aio import BlobServiceClient
|
50
|
+
except ImportError as e:
|
51
|
+
raise ImportError(
|
52
|
+
"Azure storage dependencies are not installed. "
|
53
|
+
"Install with: pip install planar[azure]"
|
54
|
+
) from e
|
55
|
+
|
56
|
+
self.container_name = container_name
|
57
|
+
self.sas_ttl = sas_ttl
|
58
|
+
self.client: "BlobServiceClient"
|
59
|
+
|
60
|
+
self.auth_method: AzureAuthMethod
|
61
|
+
self._account_name: str | None = None
|
62
|
+
self._account_key: str | None = None
|
63
|
+
|
64
|
+
from azure.storage.blob._shared.policies_async import ExponentialRetry
|
65
|
+
|
66
|
+
client_kwargs = {
|
67
|
+
"connection_timeout": 10,
|
68
|
+
"read_timeout": 40,
|
69
|
+
"retry_policy": ExponentialRetry(
|
70
|
+
retry_total=2,
|
71
|
+
),
|
72
|
+
}
|
73
|
+
|
74
|
+
# Initialize BlobServiceClient based on auth method
|
75
|
+
if connection_string:
|
76
|
+
self.client = BlobServiceClient.from_connection_string(
|
77
|
+
connection_string,
|
78
|
+
**client_kwargs,
|
79
|
+
)
|
80
|
+
self.auth_method = AzureAuthMethod.CONNECTION_STRING
|
81
|
+
# Extract account name and key from the connection string for SAS signing
|
82
|
+
self._account_name = self._extract_account_name_from_connection_string(
|
83
|
+
connection_string
|
84
|
+
)
|
85
|
+
self._account_key = self._extract_account_key_from_connection_string(
|
86
|
+
connection_string
|
87
|
+
)
|
88
|
+
|
89
|
+
elif use_azure_ad:
|
90
|
+
if not account_url:
|
91
|
+
raise ValueError(
|
92
|
+
"account_url is required when using Azure AD authentication"
|
93
|
+
)
|
94
|
+
from azure.identity.aio import DefaultAzureCredential
|
95
|
+
|
96
|
+
credential = DefaultAzureCredential()
|
97
|
+
self.client = BlobServiceClient(
|
98
|
+
account_url=account_url, credential=credential, **client_kwargs
|
99
|
+
)
|
100
|
+
self.auth_method = AzureAuthMethod.AZURE_AD
|
101
|
+
self._account_key = None
|
102
|
+
self._account_name = self._extract_account_name_from_account_url(
|
103
|
+
account_url
|
104
|
+
)
|
105
|
+
|
106
|
+
elif account_key:
|
107
|
+
if not account_url:
|
108
|
+
raise ValueError(
|
109
|
+
"account_url is required when using account key authentication"
|
110
|
+
)
|
111
|
+
self.client = BlobServiceClient(
|
112
|
+
account_url=account_url, credential=account_key, **client_kwargs
|
113
|
+
)
|
114
|
+
self.auth_method = AzureAuthMethod.ACCOUNT_KEY
|
115
|
+
self._account_key = account_key
|
116
|
+
# Extract account name from URL for SAS generation
|
117
|
+
self._account_name = self._extract_account_name_from_account_url(
|
118
|
+
account_url
|
119
|
+
)
|
120
|
+
|
121
|
+
else:
|
122
|
+
raise ValueError(
|
123
|
+
"Must provide either connection_string, use_azure_ad=True, or account_key"
|
124
|
+
)
|
125
|
+
|
126
|
+
async def __aenter__(self):
|
127
|
+
"""Enter async context manager for proper cleanup in tests."""
|
128
|
+
await self.client.__aenter__()
|
129
|
+
return self
|
130
|
+
|
131
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
132
|
+
"""Exit async context manager and cleanup resources."""
|
133
|
+
await self.client.__aexit__(exc_type, exc_val, exc_tb)
|
134
|
+
|
135
|
+
@override
|
136
|
+
async def close(self):
|
137
|
+
"""Explicitly close the client. Only needed if not using as context manager."""
|
138
|
+
await self.client.close()
|
139
|
+
|
140
|
+
async def put(
|
141
|
+
self, stream: AsyncGenerator[bytes, None], mime_type: str | None = None
|
142
|
+
) -> str:
|
143
|
+
"""
|
144
|
+
Stores a stream and optional mime type to Azure Blob Storage.
|
145
|
+
|
146
|
+
The storage reference returned is a unique UUID.
|
147
|
+
The mime_type is stored as the blob's ContentType.
|
148
|
+
"""
|
149
|
+
ref = str(uuid.uuid4())
|
150
|
+
|
151
|
+
content_settings = None
|
152
|
+
if mime_type:
|
153
|
+
from azure.storage.blob import ContentSettings
|
154
|
+
|
155
|
+
content_settings = ContentSettings(content_type=mime_type)
|
156
|
+
|
157
|
+
try:
|
158
|
+
container_client = self.client.get_container_client(self.container_name)
|
159
|
+
blob_client = container_client.get_blob_client(ref)
|
160
|
+
|
161
|
+
await blob_client.upload_blob(
|
162
|
+
stream,
|
163
|
+
content_settings=content_settings,
|
164
|
+
overwrite=True,
|
165
|
+
)
|
166
|
+
return ref
|
167
|
+
|
168
|
+
except Exception as e:
|
169
|
+
logger.exception(
|
170
|
+
"failed azure blob upload",
|
171
|
+
ref=ref,
|
172
|
+
container_name=self.container_name,
|
173
|
+
)
|
174
|
+
raise IOError(f"Failed to upload to Azure blob {ref}. Error: {e}") from e
|
175
|
+
|
176
|
+
async def get(self, ref: str) -> tuple[AsyncGenerator[bytes, None], str | None]:
|
177
|
+
"""
|
178
|
+
Retrieves a stream and its mime type from Azure Blob Storage.
|
179
|
+
"""
|
180
|
+
try:
|
181
|
+
container_client = self.client.get_container_client(self.container_name)
|
182
|
+
blob_client = container_client.get_blob_client(ref)
|
183
|
+
|
184
|
+
# Get blob properties for content type
|
185
|
+
properties = await blob_client.get_blob_properties()
|
186
|
+
mime_type = (
|
187
|
+
properties.content_settings.content_type
|
188
|
+
if properties.content_settings
|
189
|
+
else None
|
190
|
+
)
|
191
|
+
|
192
|
+
async def _stream_wrapper():
|
193
|
+
download_stream = await blob_client.download_blob()
|
194
|
+
async for chunk in download_stream.chunks():
|
195
|
+
yield chunk
|
196
|
+
|
197
|
+
return _stream_wrapper(), mime_type
|
198
|
+
except ResourceNotFoundError as e:
|
199
|
+
logger.warning(
|
200
|
+
"azure blob not found",
|
201
|
+
ref=ref,
|
202
|
+
container_name=self.container_name,
|
203
|
+
error=e,
|
204
|
+
)
|
205
|
+
raise FileNotFoundError(f"Azure blob not found: {ref}") from e
|
206
|
+
except Exception as e:
|
207
|
+
logger.exception(
|
208
|
+
"failed azure blob download",
|
209
|
+
ref=ref,
|
210
|
+
container_name=self.container_name,
|
211
|
+
)
|
212
|
+
raise IOError(
|
213
|
+
f"Failed to download from Azure blob {ref}. Error: {e}"
|
214
|
+
) from e
|
215
|
+
|
216
|
+
async def delete(self, ref: str) -> None:
|
217
|
+
"""
|
218
|
+
Deletes a blob from Azure Storage.
|
219
|
+
Does not raise an error if the blob does not exist.
|
220
|
+
"""
|
221
|
+
try:
|
222
|
+
container_client = self.client.get_container_client(self.container_name)
|
223
|
+
blob_client = container_client.get_blob_client(ref)
|
224
|
+
|
225
|
+
await blob_client.delete_blob(delete_snapshots="include")
|
226
|
+
|
227
|
+
except ResourceNotFoundError:
|
228
|
+
logger.debug(
|
229
|
+
"azure blob not found, not raising error",
|
230
|
+
ref=ref,
|
231
|
+
container_name=self.container_name,
|
232
|
+
)
|
233
|
+
except Exception as e:
|
234
|
+
logger.exception(
|
235
|
+
"failed azure blob delete",
|
236
|
+
ref=ref,
|
237
|
+
container_name=self.container_name,
|
238
|
+
)
|
239
|
+
raise IOError(f"Failed to delete Azure blob {ref}. Error: {e}") from e
|
240
|
+
|
241
|
+
async def external_url(self, ref: str) -> str | None:
|
242
|
+
"""
|
243
|
+
Returns a SAS URL to access the Azure blob if we have the capability.
|
244
|
+
|
245
|
+
Supports SAS generation for:
|
246
|
+
- Account Key (Account SAS signed with account key)
|
247
|
+
- Connection String (Account SAS signed with extracted account key)
|
248
|
+
- Azure AD (User Delegation SAS signed with a User Delegation Key)
|
249
|
+
"""
|
250
|
+
|
251
|
+
if not self._account_name:
|
252
|
+
logger.debug(
|
253
|
+
"cannot generate sas url without account name",
|
254
|
+
ref=ref,
|
255
|
+
has_account_name=bool(self._account_name),
|
256
|
+
)
|
257
|
+
return None
|
258
|
+
|
259
|
+
expiry_time = datetime.now(UTC) + timedelta(seconds=self.sas_ttl)
|
260
|
+
|
261
|
+
if self.auth_method.name in ("ACCOUNT_KEY", "CONNECTION_STRING"):
|
262
|
+
if not self._account_key:
|
263
|
+
logger.debug(
|
264
|
+
"cannot generate account-key SAS without account key",
|
265
|
+
ref=ref,
|
266
|
+
has_account_key=bool(self._account_key),
|
267
|
+
)
|
268
|
+
return None
|
269
|
+
|
270
|
+
sas_token = generate_blob_sas(
|
271
|
+
account_name=self._account_name,
|
272
|
+
container_name=self.container_name,
|
273
|
+
blob_name=ref,
|
274
|
+
account_key=self._account_key,
|
275
|
+
permission=BlobSasPermissions(read=True),
|
276
|
+
expiry=expiry_time,
|
277
|
+
)
|
278
|
+
|
279
|
+
elif self.auth_method.name == "AZURE_AD":
|
280
|
+
# Generate a User Delegation SAS signed with a user delegation key
|
281
|
+
start_time = datetime.utcnow()
|
282
|
+
user_delegation_key = await self.client.get_user_delegation_key(
|
283
|
+
key_start_time=start_time, key_expiry_time=expiry_time
|
284
|
+
)
|
285
|
+
sas_token = generate_blob_sas(
|
286
|
+
account_name=self._account_name,
|
287
|
+
container_name=self.container_name,
|
288
|
+
blob_name=ref,
|
289
|
+
user_delegation_key=user_delegation_key,
|
290
|
+
permission=BlobSasPermissions(read=True),
|
291
|
+
expiry=expiry_time,
|
292
|
+
)
|
293
|
+
else:
|
294
|
+
return None
|
295
|
+
|
296
|
+
blob_url = f"{self.client.url}{self.container_name}/{ref}"
|
297
|
+
return f"{blob_url}?{sas_token}"
|
298
|
+
|
299
|
+
@staticmethod
|
300
|
+
def _extract_account_name_from_connection_string(
|
301
|
+
connection_string: str,
|
302
|
+
) -> str | None:
|
303
|
+
try:
|
304
|
+
# Split on ';' and build a dict of key/value pairs
|
305
|
+
parts = dict(
|
306
|
+
part.split("=", 1)
|
307
|
+
for part in connection_string.split(";")
|
308
|
+
if "=" in part
|
309
|
+
)
|
310
|
+
account_name = parts.get("AccountName")
|
311
|
+
return account_name
|
312
|
+
except Exception:
|
313
|
+
return None
|
314
|
+
|
315
|
+
@staticmethod
|
316
|
+
def _extract_account_key_from_connection_string(
|
317
|
+
connection_string: str,
|
318
|
+
) -> str | None:
|
319
|
+
try:
|
320
|
+
parts = dict(
|
321
|
+
part.split("=", 1)
|
322
|
+
for part in connection_string.split(";")
|
323
|
+
if "=" in part
|
324
|
+
)
|
325
|
+
return parts.get("AccountKey")
|
326
|
+
except Exception:
|
327
|
+
return None
|
328
|
+
|
329
|
+
@staticmethod
|
330
|
+
def _extract_account_name_from_account_url(account_url: str) -> str | None:
|
331
|
+
try:
|
332
|
+
parsed = urlparse(account_url)
|
333
|
+
host = parsed.hostname or ""
|
334
|
+
# Standard Azure: https://{account}.blob.core.windows.net
|
335
|
+
if "." in host and not host.startswith("127.") and host != "localhost":
|
336
|
+
return host.split(".")[0]
|
337
|
+
# Azurite style: http://127.0.0.1:10000/{account}
|
338
|
+
path = parsed.path.strip("/")
|
339
|
+
if path:
|
340
|
+
return path.split("/")[0]
|
341
|
+
return None
|
342
|
+
except Exception:
|
343
|
+
return None
|
planar/files/storage/base.py
CHANGED
@@ -59,3 +59,10 @@ class Storage(ABC):
|
|
59
59
|
data_bytes, mime_type = await self.get_bytes(ref)
|
60
60
|
# TODO: Potentially use encoding from mime_type if available?
|
61
61
|
return data_bytes.decode(encoding), mime_type
|
62
|
+
|
63
|
+
async def close(self) -> None:
|
64
|
+
"""
|
65
|
+
Optional cleanup method for storage implementations.
|
66
|
+
Override this if your storage backend needs explicit cleanup.
|
67
|
+
"""
|
68
|
+
pass
|
planar/files/storage/config.py
CHANGED
@@ -1,10 +1,15 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from typing import TYPE_CHECKING, Annotated, Literal
|
4
|
+
|
5
|
+
from pydantic import BaseModel, Field, model_validator
|
4
6
|
|
5
7
|
from .local_directory import LocalDirectoryStorage
|
6
8
|
from .s3 import S3Storage
|
7
9
|
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from .azure_blob import AzureBlobStorage
|
12
|
+
|
8
13
|
|
9
14
|
class LocalDirectoryConfig(BaseModel):
|
10
15
|
backend: Literal["localdir"]
|
@@ -15,19 +20,66 @@ class S3Config(BaseModel):
|
|
15
20
|
backend: Literal["s3"]
|
16
21
|
bucket_name: str
|
17
22
|
region: str
|
18
|
-
access_key:
|
19
|
-
secret_key:
|
20
|
-
endpoint_url:
|
23
|
+
access_key: str | None = None
|
24
|
+
secret_key: str | None = None
|
25
|
+
endpoint_url: str | None = None
|
21
26
|
presigned_url_ttl: int = 3600
|
22
27
|
|
23
28
|
|
29
|
+
class AzureBlobConfig(BaseModel):
|
30
|
+
backend: Literal["azure_blob"]
|
31
|
+
container_name: str
|
32
|
+
|
33
|
+
# Authentication options (mutually exclusive)
|
34
|
+
connection_string: str | None = None # Full connection string
|
35
|
+
account_url: str | None = None # Storage account URL
|
36
|
+
use_azure_ad: bool | None = None # Use DefaultAzureCredential
|
37
|
+
account_key: str | None = None # Storage account key
|
38
|
+
|
39
|
+
# Common settings
|
40
|
+
sas_ttl: int = 3600 # SAS URL expiry time in seconds
|
41
|
+
|
42
|
+
@model_validator(mode="after")
|
43
|
+
def validate_auth_config(self):
|
44
|
+
"""Ensure exactly one valid authentication configuration."""
|
45
|
+
|
46
|
+
# Check if connection_string is provided
|
47
|
+
if self.connection_string:
|
48
|
+
# Connection string is self-contained
|
49
|
+
if self.account_url or self.use_azure_ad or self.account_key:
|
50
|
+
raise ValueError(
|
51
|
+
"When using connection_string, don't provide account_url, use_azure_ad, or account_key"
|
52
|
+
)
|
53
|
+
return self
|
54
|
+
|
55
|
+
# If no connection string, must have account_url
|
56
|
+
if not self.account_url:
|
57
|
+
raise ValueError("Either connection_string or account_url must be provided")
|
58
|
+
|
59
|
+
# With account_url, must have exactly one credential type
|
60
|
+
credential_methods = [
|
61
|
+
self.use_azure_ad is True,
|
62
|
+
self.account_key is not None,
|
63
|
+
]
|
64
|
+
|
65
|
+
if sum(credential_methods) != 1:
|
66
|
+
raise ValueError(
|
67
|
+
"When using account_url, exactly one credential method must be specified: "
|
68
|
+
"use_azure_ad=true or account_key"
|
69
|
+
)
|
70
|
+
|
71
|
+
return self
|
72
|
+
|
73
|
+
|
24
74
|
StorageConfig = Annotated[
|
25
|
-
LocalDirectoryConfig | S3Config,
|
75
|
+
LocalDirectoryConfig | S3Config | AzureBlobConfig,
|
26
76
|
Field(discriminator="backend"),
|
27
77
|
]
|
28
78
|
|
29
79
|
|
30
|
-
def create_from_config(
|
80
|
+
def create_from_config(
|
81
|
+
config: StorageConfig,
|
82
|
+
) -> LocalDirectoryStorage | S3Storage | AzureBlobStorage:
|
31
83
|
"""Creates a storage instance from the given configuration."""
|
32
84
|
if config.backend == "localdir":
|
33
85
|
return LocalDirectoryStorage(config.directory)
|
@@ -40,5 +92,16 @@ def create_from_config(config: StorageConfig) -> LocalDirectoryStorage | S3Stora
|
|
40
92
|
endpoint_url=config.endpoint_url,
|
41
93
|
presigned_url_ttl=config.presigned_url_ttl,
|
42
94
|
)
|
95
|
+
elif config.backend == "azure_blob":
|
96
|
+
from .azure_blob import AzureBlobStorage
|
97
|
+
|
98
|
+
return AzureBlobStorage(
|
99
|
+
container_name=config.container_name,
|
100
|
+
connection_string=config.connection_string,
|
101
|
+
account_url=config.account_url,
|
102
|
+
use_azure_ad=config.use_azure_ad or False,
|
103
|
+
account_key=config.account_key,
|
104
|
+
sas_ttl=config.sas_ttl,
|
105
|
+
)
|
43
106
|
else:
|
44
107
|
raise ValueError(f"Unsupported backend: {config.backend}")
|
planar/files/storage/s3.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
import io
|
3
3
|
import uuid
|
4
|
-
from typing import Any, AsyncGenerator, Dict,
|
4
|
+
from typing import Any, AsyncGenerator, Dict, Tuple
|
5
5
|
|
6
6
|
import boto3
|
7
7
|
from botocore.client import Config as BotoConfig
|
@@ -21,11 +21,11 @@ class S3Storage(Storage):
|
|
21
21
|
self,
|
22
22
|
bucket_name: str,
|
23
23
|
region: str,
|
24
|
-
endpoint_url:
|
25
|
-
access_key_id:
|
26
|
-
secret_access_key:
|
27
|
-
session_token:
|
28
|
-
boto_config:
|
24
|
+
endpoint_url: str | None = None,
|
25
|
+
access_key_id: str | None = None,
|
26
|
+
secret_access_key: str | None = None,
|
27
|
+
session_token: str | None = None, # For temporary credentials
|
28
|
+
boto_config: Dict[str, Any] | None = None, # Additional boto3 client config
|
29
29
|
presigned_url_ttl: int = 3600,
|
30
30
|
):
|
31
31
|
"""
|