planar-0.5.0-py3-none-any.whl → planar-0.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +67 -30
  3. planar/ai/pydantic_ai.py +570 -0
  4. planar/ai/pydantic_ai_agent.py +329 -0
  5. planar/ai/test_agent.py +2 -2
  6. planar/app.py +64 -20
  7. planar/cli.py +39 -27
  8. planar/config.py +45 -36
  9. planar/db/db.py +2 -1
  10. planar/files/storage/azure_blob.py +343 -0
  11. planar/files/storage/base.py +7 -0
  12. planar/files/storage/config.py +70 -7
  13. planar/files/storage/s3.py +6 -6
  14. planar/files/storage/test_azure_blob.py +435 -0
  15. planar/logging/formatter.py +17 -4
  16. planar/logging/test_formatter.py +327 -0
  17. planar/registry_items.py +2 -1
  18. planar/routers/agents_router.py +3 -1
  19. planar/routers/files.py +11 -2
  20. planar/routers/models.py +14 -1
  21. planar/routers/test_files_router.py +49 -0
  22. planar/routers/test_routes_security.py +5 -7
  23. planar/routers/test_workflow_router.py +270 -3
  24. planar/routers/workflow.py +95 -36
  25. planar/rules/models.py +36 -39
  26. planar/rules/test_data/account_dormancy_management.json +223 -0
  27. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  28. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  29. planar/rules/test_data/booking_fraud_detection.json +407 -0
  30. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  31. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  32. planar/rules/test_data/customer_lifetime_value.json +143 -0
  33. planar/rules/test_data/import_duties_calculator.json +289 -0
  34. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  35. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  36. planar/rules/test_data/order_consolidation_system.json +375 -0
  37. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  38. planar/rules/test_data/supply_chain_risk.json +253 -0
  39. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  40. planar/rules/test_rules.py +750 -6
  41. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  42. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  43. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  44. planar/security/auth_context.py +21 -0
  45. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  46. planar/security/authorization.py +9 -15
  47. planar/security/tests/test_auth_middleware.py +162 -0
  48. planar/sse/proxy.py +4 -9
  49. planar/test_app.py +92 -1
  50. planar/test_cli.py +81 -59
  51. planar/test_config.py +17 -14
  52. planar/testing/fixtures.py +325 -0
  53. planar/testing/planar_test_client.py +5 -2
  54. planar/utils.py +41 -1
  55. planar/workflows/execution.py +1 -1
  56. planar/workflows/orchestrator.py +5 -0
  57. planar/workflows/serialization.py +12 -6
  58. planar/workflows/step_core.py +3 -1
  59. planar/workflows/test_serialization.py +9 -1
  60. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/METADATA +30 -5
  61. planar-0.7.0.dist-info/RECORD +169 -0
  62. planar/.__init__.py.un~ +0 -0
  63. planar/._version.py.un~ +0 -0
  64. planar/.app.py.un~ +0 -0
  65. planar/.cli.py.un~ +0 -0
  66. planar/.config.py.un~ +0 -0
  67. planar/.context.py.un~ +0 -0
  68. planar/.db.py.un~ +0 -0
  69. planar/.di.py.un~ +0 -0
  70. planar/.engine.py.un~ +0 -0
  71. planar/.files.py.un~ +0 -0
  72. planar/.log_context.py.un~ +0 -0
  73. planar/.log_metadata.py.un~ +0 -0
  74. planar/.logging.py.un~ +0 -0
  75. planar/.object_registry.py.un~ +0 -0
  76. planar/.otel.py.un~ +0 -0
  77. planar/.server.py.un~ +0 -0
  78. planar/.session.py.un~ +0 -0
  79. planar/.sqlalchemy.py.un~ +0 -0
  80. planar/.task_local.py.un~ +0 -0
  81. planar/.test_app.py.un~ +0 -0
  82. planar/.test_config.py.un~ +0 -0
  83. planar/.test_object_config.py.un~ +0 -0
  84. planar/.test_sqlalchemy.py.un~ +0 -0
  85. planar/.test_utils.py.un~ +0 -0
  86. planar/.util.py.un~ +0 -0
  87. planar/.utils.py.un~ +0 -0
  88. planar/ai/.__init__.py.un~ +0 -0
  89. planar/ai/._models.py.un~ +0 -0
  90. planar/ai/.agent.py.un~ +0 -0
  91. planar/ai/.agent_utils.py.un~ +0 -0
  92. planar/ai/.events.py.un~ +0 -0
  93. planar/ai/.files.py.un~ +0 -0
  94. planar/ai/.models.py.un~ +0 -0
  95. planar/ai/.providers.py.un~ +0 -0
  96. planar/ai/.pydantic_ai.py.un~ +0 -0
  97. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  98. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  99. planar/ai/.step.py.un~ +0 -0
  100. planar/ai/.test_agent.py.un~ +0 -0
  101. planar/ai/.test_agent_serialization.py.un~ +0 -0
  102. planar/ai/.test_providers.py.un~ +0 -0
  103. planar/ai/.utils.py.un~ +0 -0
  104. planar/db/.db.py.un~ +0 -0
  105. planar/files/.config.py.un~ +0 -0
  106. planar/files/.local.py.un~ +0 -0
  107. planar/files/.local_filesystem.py.un~ +0 -0
  108. planar/files/.model.py.un~ +0 -0
  109. planar/files/.models.py.un~ +0 -0
  110. planar/files/.s3.py.un~ +0 -0
  111. planar/files/.storage.py.un~ +0 -0
  112. planar/files/.test_files.py.un~ +0 -0
  113. planar/files/storage/.__init__.py.un~ +0 -0
  114. planar/files/storage/.base.py.un~ +0 -0
  115. planar/files/storage/.config.py.un~ +0 -0
  116. planar/files/storage/.context.py.un~ +0 -0
  117. planar/files/storage/.local_directory.py.un~ +0 -0
  118. planar/files/storage/.test_local_directory.py.un~ +0 -0
  119. planar/files/storage/.test_s3.py.un~ +0 -0
  120. planar/human/.human.py.un~ +0 -0
  121. planar/human/.test_human.py.un~ +0 -0
  122. planar/logging/.__init__.py.un~ +0 -0
  123. planar/logging/.attributes.py.un~ +0 -0
  124. planar/logging/.formatter.py.un~ +0 -0
  125. planar/logging/.logger.py.un~ +0 -0
  126. planar/logging/.otel.py.un~ +0 -0
  127. planar/logging/.tracer.py.un~ +0 -0
  128. planar/modeling/.mixin.py.un~ +0 -0
  129. planar/modeling/.storage.py.un~ +0 -0
  130. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  131. planar/object_config/.object_config.py.un~ +0 -0
  132. planar/routers/.__init__.py.un~ +0 -0
  133. planar/routers/.agents_router.py.un~ +0 -0
  134. planar/routers/.crud.py.un~ +0 -0
  135. planar/routers/.decision.py.un~ +0 -0
  136. planar/routers/.event.py.un~ +0 -0
  137. planar/routers/.file_attachment.py.un~ +0 -0
  138. planar/routers/.files.py.un~ +0 -0
  139. planar/routers/.files_router.py.un~ +0 -0
  140. planar/routers/.human.py.un~ +0 -0
  141. planar/routers/.info.py.un~ +0 -0
  142. planar/routers/.models.py.un~ +0 -0
  143. planar/routers/.object_config_router.py.un~ +0 -0
  144. planar/routers/.rule.py.un~ +0 -0
  145. planar/routers/.test_object_config_router.py.un~ +0 -0
  146. planar/routers/.test_workflow_router.py.un~ +0 -0
  147. planar/routers/.workflow.py.un~ +0 -0
  148. planar/rules/.decorator.py.un~ +0 -0
  149. planar/rules/.runner.py.un~ +0 -0
  150. planar/rules/.test_rules.py.un~ +0 -0
  151. planar/security/.jwt_middleware.py.un~ +0 -0
  152. planar/sse/.constants.py.un~ +0 -0
  153. planar/sse/.example.html.un~ +0 -0
  154. planar/sse/.hub.py.un~ +0 -0
  155. planar/sse/.model.py.un~ +0 -0
  156. planar/sse/.proxy.py.un~ +0 -0
  157. planar/testing/.client.py.un~ +0 -0
  158. planar/testing/.memory_storage.py.un~ +0 -0
  159. planar/testing/.planar_test_client.py.un~ +0 -0
  160. planar/testing/.predictable_tracer.py.un~ +0 -0
  161. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  162. planar/testing/.test_memory_storage.py.un~ +0 -0
  163. planar/testing/.workflow_observer.py.un~ +0 -0
  164. planar/workflows/.__init__.py.un~ +0 -0
  165. planar/workflows/.builtin_steps.py.un~ +0 -0
  166. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  167. planar/workflows/.context.py.un~ +0 -0
  168. planar/workflows/.contrib.py.un~ +0 -0
  169. planar/workflows/.decorators.py.un~ +0 -0
  170. planar/workflows/.durable_test.py.un~ +0 -0
  171. planar/workflows/.errors.py.un~ +0 -0
  172. planar/workflows/.events.py.un~ +0 -0
  173. planar/workflows/.exceptions.py.un~ +0 -0
  174. planar/workflows/.execution.py.un~ +0 -0
  175. planar/workflows/.human.py.un~ +0 -0
  176. planar/workflows/.lock.py.un~ +0 -0
  177. planar/workflows/.misc.py.un~ +0 -0
  178. planar/workflows/.model.py.un~ +0 -0
  179. planar/workflows/.models.py.un~ +0 -0
  180. planar/workflows/.notifications.py.un~ +0 -0
  181. planar/workflows/.orchestrator.py.un~ +0 -0
  182. planar/workflows/.runtime.py.un~ +0 -0
  183. planar/workflows/.serialization.py.un~ +0 -0
  184. planar/workflows/.step.py.un~ +0 -0
  185. planar/workflows/.step_core.py.un~ +0 -0
  186. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  187. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  188. planar/workflows/.test_concurrency.py.un~ +0 -0
  189. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  190. planar/workflows/.test_human.py.un~ +0 -0
  191. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  192. planar/workflows/.test_orchestrator.py.un~ +0 -0
  193. planar/workflows/.test_race_conditions.py.un~ +0 -0
  194. planar/workflows/.test_serialization.py.un~ +0 -0
  195. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  196. planar/workflows/.test_workflow.py.un~ +0 -0
  197. planar/workflows/.tracing.py.un~ +0 -0
  198. planar/workflows/.types.py.un~ +0 -0
  199. planar/workflows/.util.py.un~ +0 -0
  200. planar/workflows/.utils.py.un~ +0 -0
  201. planar/workflows/.workflow.py.un~ +0 -0
  202. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  203. planar/workflows/.wrappers.py.un~ +0 -0
  204. planar-0.5.0.dist-info/RECORD +0 -289
  205. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/WHEEL +0 -0
  206. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/entry_points.txt +0 -0
planar/config.py CHANGED
@@ -5,13 +5,14 @@ import os
5
5
  import sys
6
6
  from enum import Enum
7
7
  from pathlib import Path
8
- from typing import Annotated, Any, Dict, Literal, Optional
8
+ from typing import Annotated, Any, Dict, Literal
9
9
 
10
10
  import boto3
11
11
  import yaml
12
12
  from dotenv import load_dotenv
13
13
  from pydantic import (
14
14
  BaseModel,
15
+ ConfigDict,
15
16
  Field,
16
17
  HttpUrl,
17
18
  SecretStr,
@@ -45,8 +46,8 @@ class LogLevel(str, Enum):
45
46
 
46
47
  class LoggerConfig(BaseModel):
47
48
  level: LogLevel = LogLevel.INFO
48
- propagate: Optional[bool] = False
49
- file: Optional[str] = None
49
+ propagate: bool | None = False
50
+ file: str | None = None
50
51
 
51
52
 
52
53
  class SQLiteConfig(BaseModel):
@@ -66,11 +67,11 @@ class PostgreSQLConfig(BaseModel):
66
67
  driver: Literal["postgresql", "postgresql+asyncpg"] = (
67
68
  "postgresql+asyncpg" # Allow async PostgreSQL
68
69
  )
69
- host: Optional[str] = None
70
- port: Optional[int] = None
71
- user: Optional[str] = None
72
- password: Optional[str] = None
73
- db: Optional[str]
70
+ host: str | None = None
71
+ port: int | None = None
72
+ user: str | None = None
73
+ password: str | None = None
74
+ db: str | None
74
75
 
75
76
  def connection_url(self) -> URL:
76
77
  driver = self.driver
@@ -92,15 +93,15 @@ class OpenAIConfig(BaseModel):
92
93
  """Configuration for OpenAI provider."""
93
94
 
94
95
  api_key: SecretStr
95
- base_url: Optional[str] = None
96
- organization: Optional[str] = None
96
+ base_url: str | None = None
97
+ organization: str | None = None
97
98
 
98
99
 
99
100
  class AnthropicConfig(BaseModel):
100
101
  """Configuration for Anthropic provider."""
101
102
 
102
103
  api_key: SecretStr
103
- base_url: Optional[str] = None
104
+ base_url: str | None = None
104
105
 
105
106
 
106
107
  class GeminiConfig(BaseModel):
@@ -112,9 +113,9 @@ class GeminiConfig(BaseModel):
112
113
  class AIProvidersConfig(BaseModel):
113
114
  """Configuration for AI providers."""
114
115
 
115
- openai: Optional[OpenAIConfig] = None
116
- anthropic: Optional[AnthropicConfig] = None
117
- gemini: Optional[GeminiConfig] = None
116
+ openai: OpenAIConfig | None = None
117
+ anthropic: AnthropicConfig | None = None
118
+ gemini: GeminiConfig | None = None
118
119
 
119
120
 
120
121
  DatabaseConfig = Annotated[
@@ -124,7 +125,7 @@ DatabaseConfig = Annotated[
124
125
 
125
126
  class AppConfig(BaseModel):
126
127
  db_connection: str
127
- max_db_conflict_retries: Optional[int] = None
128
+ max_db_conflict_retries: int | None = None
128
129
 
129
130
 
130
131
  def default_storage_config() -> StorageConfig:
@@ -162,31 +163,27 @@ PROD_CORS_CONFIG = CorsConfig(
162
163
 
163
164
 
164
165
  class JWTConfig(BaseModel):
165
- enabled: bool = False
166
166
  client_id: str | None = None
167
167
  org_id: str | None = None
168
168
  additional_exclusion_paths: list[str] | None = Field(default_factory=list)
169
169
 
170
170
  @model_validator(mode="after")
171
171
  def validate_client_id(cls, instance):
172
- if instance.enabled and not instance.client_id:
173
- raise ValueError("client_id is required when JWT is enabled")
174
- if instance.client_id and not instance.enabled:
175
- raise ValueError(
176
- "You cannot specify a client_id without enabling JWT - did you mean to set enabled=True?"
177
- )
172
+ if not instance.client_id or not instance.org_id:
173
+ raise ValueError("Both client_id and org_id required to enable JWT")
178
174
  return instance
179
175
 
180
176
 
181
- JWT_DISABLED_CONFIG = JWTConfig(enabled=False)
177
+ # Coplane ORG JWT config
182
178
  JWT_COPLANE_CONFIG = JWTConfig(
183
- enabled=True, client_id="client_01JSJHJP9Q8GZDK5Y856FEHTB0", org_id=None
179
+ client_id="client_01JSJHJPKG09TMSK6NHJP0S180",
180
+ org_id="org_01JY4QP57Y7H4EQ7HT3BGN7TNK",
184
181
  )
185
182
 
186
183
 
187
184
  class OtelConfig(BaseModel):
188
185
  collector_endpoint: HttpUrl
189
- resource_attributes: Optional[dict[str, str]] = None
186
+ resource_attributes: dict[str, str] | None = None
190
187
 
191
188
 
192
189
  def install_otel_provider(otel_config: OtelConfig):
@@ -206,19 +203,31 @@ class AuthzConfig(BaseModel):
206
203
  policy_file: str | None = None
207
204
 
208
205
 
206
+ class ServiceTokenConfig(BaseModel):
207
+ token: str | None = Field(None, min_length=1)
208
+
209
+
210
+ class SecurityConfig(BaseModel):
211
+ cors: CorsConfig = PROD_CORS_CONFIG
212
+ jwt: JWTConfig | None = None
213
+ service_token: ServiceTokenConfig | None = None
214
+ authz: AuthzConfig | None = None
215
+
216
+
209
217
  class PlanarConfig(BaseModel):
210
218
  db_connections: Dict[str, DatabaseConfig | str]
211
219
  app: AppConfig
212
- ai_providers: Optional[AIProvidersConfig] = None
213
- storage: Optional[StorageConfig] = default_storage_config()
220
+ ai_providers: AIProvidersConfig | None = None
221
+ storage: StorageConfig | None = default_storage_config()
214
222
  sse_hub: str | bool = False
215
- cors: CorsConfig = PROD_CORS_CONFIG
216
223
  environment: Environment = Environment.DEV
217
- jwt: JWTConfig | None = None
218
- logging: Optional[dict[str, LoggerConfig]] = None
224
+ security: SecurityConfig = SecurityConfig()
225
+ logging: dict[str, LoggerConfig] | None = None
219
226
  use_alembic: bool | None = True
220
- otel: Optional[OtelConfig] = None
221
- authz: AuthzConfig | None = None
227
+ otel: OtelConfig | None = None
228
+
229
+ # forbid extra keys in the config to prevent accidental misconfiguration
230
+ model_config = ConfigDict(extra="forbid")
222
231
 
223
232
  @model_validator(mode="after")
224
233
  def validate_db_connection_reference(cls, instance):
@@ -467,14 +476,14 @@ def load_environment_aware_config[ConfigClass]() -> PlanarConfig:
467
476
 
468
477
  if env == "dev":
469
478
  base_config = sqlite_config(db_path="planar_dev.db")
470
- base_config.cors = LOCAL_CORS_CONFIG
479
+ base_config.security = SecurityConfig(cors=LOCAL_CORS_CONFIG)
471
480
  base_config.environment = Environment.DEV
472
- base_config.jwt = JWT_DISABLED_CONFIG
473
481
  else:
474
482
  base_config = sqlite_config(db_path="planar.db")
475
- base_config.cors = PROD_CORS_CONFIG
476
483
  base_config.environment = Environment.PROD
477
- base_config.jwt = JWT_COPLANE_CONFIG
484
+ base_config.security = SecurityConfig(
485
+ cors=PROD_CORS_CONFIG, jwt=JWT_COPLANE_CONFIG
486
+ )
478
487
 
479
488
  # Convert base config to dict for merging
480
489
  # Use by_alias=False to work with Python field names before validation
planar/db/db.py CHANGED
@@ -176,7 +176,8 @@ class DatabaseManager:
176
176
 
177
177
  def _create_sqlite_engine(self, url: URL) -> AsyncEngine:
178
178
  # in practice this high timeout is only use
179
- timeout = int(str(url.query.get("timeout", 10)))
179
+ timeout = int(str(url.query.get("timeout", 60)))
180
+ logger.info("Setting up SQLite engine with timeout", timeout=timeout)
180
181
 
181
182
  engine = create_async_engine(
182
183
  url,
@@ -0,0 +1,343 @@
1
+ import uuid
2
+ from datetime import UTC, datetime, timedelta
3
+ from enum import Enum
4
+ from typing import AsyncGenerator, override
5
+ from urllib.parse import urlparse
6
+
7
+ from azure.core.exceptions import ResourceNotFoundError
8
+ from azure.storage.blob import BlobSasPermissions, generate_blob_sas
9
+ from azure.storage.blob.aio import BlobServiceClient
10
+
11
+ from planar.logging import get_logger
12
+
13
+ from .base import Storage
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
+ class AzureAuthMethod(Enum):
19
+ CONNECTION_STRING = "connection_string"
20
+ ACCOUNT_KEY = "account_key"
21
+ AZURE_AD = "azure_ad"
22
+
23
+
24
+ class AzureBlobStorage(Storage):
25
+ """Stores files and mime types in Azure Blob Storage using the async SDK."""
26
+
27
+ def __init__(
28
+ self,
29
+ container_name: str,
30
+ connection_string: str | None = None,
31
+ account_url: str | None = None,
32
+ use_azure_ad: bool = False,
33
+ account_key: str | None = None,
34
+ sas_ttl: int = 3600,
35
+ ):
36
+ """
37
+ Initializes AzureBlobStorage.
38
+
39
+ Args:
40
+ container_name: The name of the Azure Storage container.
41
+ connection_string: Full connection string (includes all credentials).
42
+ account_url: Storage account URL (e.g., 'https://<account>.blob.core.windows.net').
43
+ use_azure_ad: Whether to use DefaultAzureCredential for Azure AD auth.
44
+ account_key: Storage account key (used with account_url).
45
+ sas_ttl: Time in seconds for which SAS URLs are valid.
46
+ """
47
+ # Import Azure dependencies when the class is instantiated
48
+ try:
49
+ from azure.storage.blob.aio import BlobServiceClient
50
+ except ImportError as e:
51
+ raise ImportError(
52
+ "Azure storage dependencies are not installed. "
53
+ "Install with: pip install planar[azure]"
54
+ ) from e
55
+
56
+ self.container_name = container_name
57
+ self.sas_ttl = sas_ttl
58
+ self.client: "BlobServiceClient"
59
+
60
+ self.auth_method: AzureAuthMethod
61
+ self._account_name: str | None = None
62
+ self._account_key: str | None = None
63
+
64
+ from azure.storage.blob._shared.policies_async import ExponentialRetry
65
+
66
+ client_kwargs = {
67
+ "connection_timeout": 10,
68
+ "read_timeout": 40,
69
+ "retry_policy": ExponentialRetry(
70
+ retry_total=2,
71
+ ),
72
+ }
73
+
74
+ # Initialize BlobServiceClient based on auth method
75
+ if connection_string:
76
+ self.client = BlobServiceClient.from_connection_string(
77
+ connection_string,
78
+ **client_kwargs,
79
+ )
80
+ self.auth_method = AzureAuthMethod.CONNECTION_STRING
81
+ # Extract account name and key from the connection string for SAS signing
82
+ self._account_name = self._extract_account_name_from_connection_string(
83
+ connection_string
84
+ )
85
+ self._account_key = self._extract_account_key_from_connection_string(
86
+ connection_string
87
+ )
88
+
89
+ elif use_azure_ad:
90
+ if not account_url:
91
+ raise ValueError(
92
+ "account_url is required when using Azure AD authentication"
93
+ )
94
+ from azure.identity.aio import DefaultAzureCredential
95
+
96
+ credential = DefaultAzureCredential()
97
+ self.client = BlobServiceClient(
98
+ account_url=account_url, credential=credential, **client_kwargs
99
+ )
100
+ self.auth_method = AzureAuthMethod.AZURE_AD
101
+ self._account_key = None
102
+ self._account_name = self._extract_account_name_from_account_url(
103
+ account_url
104
+ )
105
+
106
+ elif account_key:
107
+ if not account_url:
108
+ raise ValueError(
109
+ "account_url is required when using account key authentication"
110
+ )
111
+ self.client = BlobServiceClient(
112
+ account_url=account_url, credential=account_key, **client_kwargs
113
+ )
114
+ self.auth_method = AzureAuthMethod.ACCOUNT_KEY
115
+ self._account_key = account_key
116
+ # Extract account name from URL for SAS generation
117
+ self._account_name = self._extract_account_name_from_account_url(
118
+ account_url
119
+ )
120
+
121
+ else:
122
+ raise ValueError(
123
+ "Must provide either connection_string, use_azure_ad=True, or account_key"
124
+ )
125
+
126
+ async def __aenter__(self):
127
+ """Enter async context manager for proper cleanup in tests."""
128
+ await self.client.__aenter__()
129
+ return self
130
+
131
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
132
+ """Exit async context manager and cleanup resources."""
133
+ await self.client.__aexit__(exc_type, exc_val, exc_tb)
134
+
135
+ @override
136
+ async def close(self):
137
+ """Explicitly close the client. Only needed if not using as context manager."""
138
+ await self.client.close()
139
+
140
+ async def put(
141
+ self, stream: AsyncGenerator[bytes, None], mime_type: str | None = None
142
+ ) -> str:
143
+ """
144
+ Stores a stream and optional mime type to Azure Blob Storage.
145
+
146
+ The storage reference returned is a unique UUID.
147
+ The mime_type is stored as the blob's ContentType.
148
+ """
149
+ ref = str(uuid.uuid4())
150
+
151
+ content_settings = None
152
+ if mime_type:
153
+ from azure.storage.blob import ContentSettings
154
+
155
+ content_settings = ContentSettings(content_type=mime_type)
156
+
157
+ try:
158
+ container_client = self.client.get_container_client(self.container_name)
159
+ blob_client = container_client.get_blob_client(ref)
160
+
161
+ await blob_client.upload_blob(
162
+ stream,
163
+ content_settings=content_settings,
164
+ overwrite=True,
165
+ )
166
+ return ref
167
+
168
+ except Exception as e:
169
+ logger.exception(
170
+ "failed azure blob upload",
171
+ ref=ref,
172
+ container_name=self.container_name,
173
+ )
174
+ raise IOError(f"Failed to upload to Azure blob {ref}. Error: {e}") from e
175
+
176
+ async def get(self, ref: str) -> tuple[AsyncGenerator[bytes, None], str | None]:
177
+ """
178
+ Retrieves a stream and its mime type from Azure Blob Storage.
179
+ """
180
+ try:
181
+ container_client = self.client.get_container_client(self.container_name)
182
+ blob_client = container_client.get_blob_client(ref)
183
+
184
+ # Get blob properties for content type
185
+ properties = await blob_client.get_blob_properties()
186
+ mime_type = (
187
+ properties.content_settings.content_type
188
+ if properties.content_settings
189
+ else None
190
+ )
191
+
192
+ async def _stream_wrapper():
193
+ download_stream = await blob_client.download_blob()
194
+ async for chunk in download_stream.chunks():
195
+ yield chunk
196
+
197
+ return _stream_wrapper(), mime_type
198
+ except ResourceNotFoundError as e:
199
+ logger.warning(
200
+ "azure blob not found",
201
+ ref=ref,
202
+ container_name=self.container_name,
203
+ error=e,
204
+ )
205
+ raise FileNotFoundError(f"Azure blob not found: {ref}") from e
206
+ except Exception as e:
207
+ logger.exception(
208
+ "failed azure blob download",
209
+ ref=ref,
210
+ container_name=self.container_name,
211
+ )
212
+ raise IOError(
213
+ f"Failed to download from Azure blob {ref}. Error: {e}"
214
+ ) from e
215
+
216
+ async def delete(self, ref: str) -> None:
217
+ """
218
+ Deletes a blob from Azure Storage.
219
+ Does not raise an error if the blob does not exist.
220
+ """
221
+ try:
222
+ container_client = self.client.get_container_client(self.container_name)
223
+ blob_client = container_client.get_blob_client(ref)
224
+
225
+ await blob_client.delete_blob(delete_snapshots="include")
226
+
227
+ except ResourceNotFoundError:
228
+ logger.debug(
229
+ "azure blob not found, not raising error",
230
+ ref=ref,
231
+ container_name=self.container_name,
232
+ )
233
+ except Exception as e:
234
+ logger.exception(
235
+ "failed azure blob delete",
236
+ ref=ref,
237
+ container_name=self.container_name,
238
+ )
239
+ raise IOError(f"Failed to delete Azure blob {ref}. Error: {e}") from e
240
+
241
+ async def external_url(self, ref: str) -> str | None:
242
+ """
243
+ Returns a SAS URL to access the Azure blob if we have the capability.
244
+
245
+ Supports SAS generation for:
246
+ - Account Key (Account SAS signed with account key)
247
+ - Connection String (Account SAS signed with extracted account key)
248
+ - Azure AD (User Delegation SAS signed with a User Delegation Key)
249
+ """
250
+
251
+ if not self._account_name:
252
+ logger.debug(
253
+ "cannot generate sas url without account name",
254
+ ref=ref,
255
+ has_account_name=bool(self._account_name),
256
+ )
257
+ return None
258
+
259
+ expiry_time = datetime.now(UTC) + timedelta(seconds=self.sas_ttl)
260
+
261
+ if self.auth_method.name in ("ACCOUNT_KEY", "CONNECTION_STRING"):
262
+ if not self._account_key:
263
+ logger.debug(
264
+ "cannot generate account-key SAS without account key",
265
+ ref=ref,
266
+ has_account_key=bool(self._account_key),
267
+ )
268
+ return None
269
+
270
+ sas_token = generate_blob_sas(
271
+ account_name=self._account_name,
272
+ container_name=self.container_name,
273
+ blob_name=ref,
274
+ account_key=self._account_key,
275
+ permission=BlobSasPermissions(read=True),
276
+ expiry=expiry_time,
277
+ )
278
+
279
+ elif self.auth_method.name == "AZURE_AD":
280
+ # Generate a User Delegation SAS signed with a user delegation key
281
+ start_time = datetime.utcnow()
282
+ user_delegation_key = await self.client.get_user_delegation_key(
283
+ key_start_time=start_time, key_expiry_time=expiry_time
284
+ )
285
+ sas_token = generate_blob_sas(
286
+ account_name=self._account_name,
287
+ container_name=self.container_name,
288
+ blob_name=ref,
289
+ user_delegation_key=user_delegation_key,
290
+ permission=BlobSasPermissions(read=True),
291
+ expiry=expiry_time,
292
+ )
293
+ else:
294
+ return None
295
+
296
+ blob_url = f"{self.client.url}{self.container_name}/{ref}"
297
+ return f"{blob_url}?{sas_token}"
298
+
299
+ @staticmethod
300
+ def _extract_account_name_from_connection_string(
301
+ connection_string: str,
302
+ ) -> str | None:
303
+ try:
304
+ # Split on ';' and build a dict of key/value pairs
305
+ parts = dict(
306
+ part.split("=", 1)
307
+ for part in connection_string.split(";")
308
+ if "=" in part
309
+ )
310
+ account_name = parts.get("AccountName")
311
+ return account_name
312
+ except Exception:
313
+ return None
314
+
315
+ @staticmethod
316
+ def _extract_account_key_from_connection_string(
317
+ connection_string: str,
318
+ ) -> str | None:
319
+ try:
320
+ parts = dict(
321
+ part.split("=", 1)
322
+ for part in connection_string.split(";")
323
+ if "=" in part
324
+ )
325
+ return parts.get("AccountKey")
326
+ except Exception:
327
+ return None
328
+
329
+ @staticmethod
330
+ def _extract_account_name_from_account_url(account_url: str) -> str | None:
331
+ try:
332
+ parsed = urlparse(account_url)
333
+ host = parsed.hostname or ""
334
+ # Standard Azure: https://{account}.blob.core.windows.net
335
+ if "." in host and not host.startswith("127.") and host != "localhost":
336
+ return host.split(".")[0]
337
+ # Azurite style: http://127.0.0.1:10000/{account}
338
+ path = parsed.path.strip("/")
339
+ if path:
340
+ return path.split("/")[0]
341
+ return None
342
+ except Exception:
343
+ return None
@@ -59,3 +59,10 @@ class Storage(ABC):
59
59
  data_bytes, mime_type = await self.get_bytes(ref)
60
60
  # TODO: Potentially use encoding from mime_type if available?
61
61
  return data_bytes.decode(encoding), mime_type
62
+
63
+ async def close(self) -> None:
64
+ """
65
+ Optional cleanup method for storage implementations.
66
+ Override this if your storage backend needs explicit cleanup.
67
+ """
68
+ pass
@@ -1,10 +1,15 @@
1
- from typing import Annotated, Literal, Optional
1
+ from __future__ import annotations
2
2
 
3
- from pydantic import BaseModel, Field
3
+ from typing import TYPE_CHECKING, Annotated, Literal
4
+
5
+ from pydantic import BaseModel, Field, model_validator
4
6
 
5
7
  from .local_directory import LocalDirectoryStorage
6
8
  from .s3 import S3Storage
7
9
 
10
+ if TYPE_CHECKING:
11
+ from .azure_blob import AzureBlobStorage
12
+
8
13
 
9
14
  class LocalDirectoryConfig(BaseModel):
10
15
  backend: Literal["localdir"]
@@ -15,19 +20,66 @@ class S3Config(BaseModel):
15
20
  backend: Literal["s3"]
16
21
  bucket_name: str
17
22
  region: str
18
- access_key: Optional[str] = None
19
- secret_key: Optional[str] = None
20
- endpoint_url: Optional[str] = None
23
+ access_key: str | None = None
24
+ secret_key: str | None = None
25
+ endpoint_url: str | None = None
21
26
  presigned_url_ttl: int = 3600
22
27
 
23
28
 
29
+ class AzureBlobConfig(BaseModel):
30
+ backend: Literal["azure_blob"]
31
+ container_name: str
32
+
33
+ # Authentication options (mutually exclusive)
34
+ connection_string: str | None = None # Full connection string
35
+ account_url: str | None = None # Storage account URL
36
+ use_azure_ad: bool | None = None # Use DefaultAzureCredential
37
+ account_key: str | None = None # Storage account key
38
+
39
+ # Common settings
40
+ sas_ttl: int = 3600 # SAS URL expiry time in seconds
41
+
42
+ @model_validator(mode="after")
43
+ def validate_auth_config(self):
44
+ """Ensure exactly one valid authentication configuration."""
45
+
46
+ # Check if connection_string is provided
47
+ if self.connection_string:
48
+ # Connection string is self-contained
49
+ if self.account_url or self.use_azure_ad or self.account_key:
50
+ raise ValueError(
51
+ "When using connection_string, don't provide account_url, use_azure_ad, or account_key"
52
+ )
53
+ return self
54
+
55
+ # If no connection string, must have account_url
56
+ if not self.account_url:
57
+ raise ValueError("Either connection_string or account_url must be provided")
58
+
59
+ # With account_url, must have exactly one credential type
60
+ credential_methods = [
61
+ self.use_azure_ad is True,
62
+ self.account_key is not None,
63
+ ]
64
+
65
+ if sum(credential_methods) != 1:
66
+ raise ValueError(
67
+ "When using account_url, exactly one credential method must be specified: "
68
+ "use_azure_ad=true or account_key"
69
+ )
70
+
71
+ return self
72
+
73
+
24
74
  StorageConfig = Annotated[
25
- LocalDirectoryConfig | S3Config,
75
+ LocalDirectoryConfig | S3Config | AzureBlobConfig,
26
76
  Field(discriminator="backend"),
27
77
  ]
28
78
 
29
79
 
30
- def create_from_config(config: StorageConfig) -> LocalDirectoryStorage | S3Storage:
80
+ def create_from_config(
81
+ config: StorageConfig,
82
+ ) -> LocalDirectoryStorage | S3Storage | AzureBlobStorage:
31
83
  """Creates a storage instance from the given configuration."""
32
84
  if config.backend == "localdir":
33
85
  return LocalDirectoryStorage(config.directory)
@@ -40,5 +92,16 @@ def create_from_config(config: StorageConfig) -> LocalDirectoryStorage | S3Stora
40
92
  endpoint_url=config.endpoint_url,
41
93
  presigned_url_ttl=config.presigned_url_ttl,
42
94
  )
95
+ elif config.backend == "azure_blob":
96
+ from .azure_blob import AzureBlobStorage
97
+
98
+ return AzureBlobStorage(
99
+ container_name=config.container_name,
100
+ connection_string=config.connection_string,
101
+ account_url=config.account_url,
102
+ use_azure_ad=config.use_azure_ad or False,
103
+ account_key=config.account_key,
104
+ sas_ttl=config.sas_ttl,
105
+ )
43
106
  else:
44
107
  raise ValueError(f"Unsupported backend: {config.backend}")
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
  import io
3
3
  import uuid
4
- from typing import Any, AsyncGenerator, Dict, Optional, Tuple
4
+ from typing import Any, AsyncGenerator, Dict, Tuple
5
5
 
6
6
  import boto3
7
7
  from botocore.client import Config as BotoConfig
@@ -21,11 +21,11 @@ class S3Storage(Storage):
21
21
  self,
22
22
  bucket_name: str,
23
23
  region: str,
24
- endpoint_url: Optional[str] = None,
25
- access_key_id: Optional[str] = None,
26
- secret_access_key: Optional[str] = None,
27
- session_token: Optional[str] = None, # For temporary credentials
28
- boto_config: Optional[Dict[str, Any]] = None, # Additional boto3 client config
24
+ endpoint_url: str | None = None,
25
+ access_key_id: str | None = None,
26
+ secret_access_key: str | None = None,
27
+ session_token: str | None = None, # For temporary credentials
28
+ boto_config: Dict[str, Any] | None = None, # Additional boto3 client config
29
29
  presigned_url_ttl: int = 3600,
30
30
  ):
31
31
  """