planar 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- planar/_version.py +1 -1
- planar/ai/agent.py +155 -283
- planar/ai/agent_base.py +170 -0
- planar/ai/agent_utils.py +7 -0
- planar/ai/pydantic_ai.py +638 -0
- planar/ai/test_agent_serialization.py +1 -1
- planar/app.py +64 -20
- planar/cli.py +39 -27
- planar/config.py +45 -36
- planar/db/db.py +2 -1
- planar/files/storage/azure_blob.py +343 -0
- planar/files/storage/base.py +7 -0
- planar/files/storage/config.py +70 -7
- planar/files/storage/s3.py +6 -6
- planar/files/storage/test_azure_blob.py +435 -0
- planar/logging/formatter.py +17 -4
- planar/logging/test_formatter.py +327 -0
- planar/registry_items.py +2 -1
- planar/routers/agents_router.py +3 -1
- planar/routers/files.py +11 -2
- planar/routers/models.py +14 -1
- planar/routers/test_agents_router.py +1 -1
- planar/routers/test_files_router.py +49 -0
- planar/routers/test_routes_security.py +5 -7
- planar/routers/test_workflow_router.py +270 -3
- planar/routers/workflow.py +95 -36
- planar/rules/models.py +36 -39
- planar/rules/test_data/account_dormancy_management.json +223 -0
- planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
- planar/rules/test_data/applicant_risk_assessment.json +435 -0
- planar/rules/test_data/booking_fraud_detection.json +407 -0
- planar/rules/test_data/cellular_data_rollover_system.json +258 -0
- planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
- planar/rules/test_data/customer_lifetime_value.json +143 -0
- planar/rules/test_data/import_duties_calculator.json +289 -0
- planar/rules/test_data/insurance_prior_authorization.json +443 -0
- planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
- planar/rules/test_data/order_consolidation_system.json +375 -0
- planar/rules/test_data/portfolio_risk_monitor.json +471 -0
- planar/rules/test_data/supply_chain_risk.json +253 -0
- planar/rules/test_data/warehouse_cross_docking.json +237 -0
- planar/rules/test_rules.py +750 -6
- planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
- planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
- planar/scaffold_templates/pyproject.toml.j2 +1 -1
- planar/security/auth_context.py +21 -0
- planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
- planar/security/authorization.py +9 -15
- planar/security/tests/test_auth_middleware.py +162 -0
- planar/sse/proxy.py +4 -9
- planar/test_app.py +92 -1
- planar/test_cli.py +81 -59
- planar/test_config.py +17 -14
- planar/testing/fixtures.py +325 -0
- planar/testing/planar_test_client.py +5 -2
- planar/utils.py +41 -1
- planar/workflows/execution.py +1 -1
- planar/workflows/orchestrator.py +5 -0
- planar/workflows/serialization.py +12 -6
- planar/workflows/step_core.py +3 -1
- planar/workflows/test_serialization.py +9 -1
- {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/METADATA +30 -5
- planar-0.8.0.dist-info/RECORD +166 -0
- planar/.__init__.py.un~ +0 -0
- planar/._version.py.un~ +0 -0
- planar/.app.py.un~ +0 -0
- planar/.cli.py.un~ +0 -0
- planar/.config.py.un~ +0 -0
- planar/.context.py.un~ +0 -0
- planar/.db.py.un~ +0 -0
- planar/.di.py.un~ +0 -0
- planar/.engine.py.un~ +0 -0
- planar/.files.py.un~ +0 -0
- planar/.log_context.py.un~ +0 -0
- planar/.log_metadata.py.un~ +0 -0
- planar/.logging.py.un~ +0 -0
- planar/.object_registry.py.un~ +0 -0
- planar/.otel.py.un~ +0 -0
- planar/.server.py.un~ +0 -0
- planar/.session.py.un~ +0 -0
- planar/.sqlalchemy.py.un~ +0 -0
- planar/.task_local.py.un~ +0 -0
- planar/.test_app.py.un~ +0 -0
- planar/.test_config.py.un~ +0 -0
- planar/.test_object_config.py.un~ +0 -0
- planar/.test_sqlalchemy.py.un~ +0 -0
- planar/.test_utils.py.un~ +0 -0
- planar/.util.py.un~ +0 -0
- planar/.utils.py.un~ +0 -0
- planar/ai/.__init__.py.un~ +0 -0
- planar/ai/._models.py.un~ +0 -0
- planar/ai/.agent.py.un~ +0 -0
- planar/ai/.agent_utils.py.un~ +0 -0
- planar/ai/.events.py.un~ +0 -0
- planar/ai/.files.py.un~ +0 -0
- planar/ai/.models.py.un~ +0 -0
- planar/ai/.providers.py.un~ +0 -0
- planar/ai/.pydantic_ai.py.un~ +0 -0
- planar/ai/.pydantic_ai_agent.py.un~ +0 -0
- planar/ai/.pydantic_ai_provider.py.un~ +0 -0
- planar/ai/.step.py.un~ +0 -0
- planar/ai/.test_agent.py.un~ +0 -0
- planar/ai/.test_agent_serialization.py.un~ +0 -0
- planar/ai/.test_providers.py.un~ +0 -0
- planar/ai/.utils.py.un~ +0 -0
- planar/ai/providers.py +0 -1088
- planar/ai/test_agent.py +0 -1298
- planar/ai/test_providers.py +0 -463
- planar/db/.db.py.un~ +0 -0
- planar/files/.config.py.un~ +0 -0
- planar/files/.local.py.un~ +0 -0
- planar/files/.local_filesystem.py.un~ +0 -0
- planar/files/.model.py.un~ +0 -0
- planar/files/.models.py.un~ +0 -0
- planar/files/.s3.py.un~ +0 -0
- planar/files/.storage.py.un~ +0 -0
- planar/files/.test_files.py.un~ +0 -0
- planar/files/storage/.__init__.py.un~ +0 -0
- planar/files/storage/.base.py.un~ +0 -0
- planar/files/storage/.config.py.un~ +0 -0
- planar/files/storage/.context.py.un~ +0 -0
- planar/files/storage/.local_directory.py.un~ +0 -0
- planar/files/storage/.test_local_directory.py.un~ +0 -0
- planar/files/storage/.test_s3.py.un~ +0 -0
- planar/human/.human.py.un~ +0 -0
- planar/human/.test_human.py.un~ +0 -0
- planar/logging/.__init__.py.un~ +0 -0
- planar/logging/.attributes.py.un~ +0 -0
- planar/logging/.formatter.py.un~ +0 -0
- planar/logging/.logger.py.un~ +0 -0
- planar/logging/.otel.py.un~ +0 -0
- planar/logging/.tracer.py.un~ +0 -0
- planar/modeling/.mixin.py.un~ +0 -0
- planar/modeling/.storage.py.un~ +0 -0
- planar/modeling/orm/.planar_base_model.py.un~ +0 -0
- planar/object_config/.object_config.py.un~ +0 -0
- planar/routers/.__init__.py.un~ +0 -0
- planar/routers/.agents_router.py.un~ +0 -0
- planar/routers/.crud.py.un~ +0 -0
- planar/routers/.decision.py.un~ +0 -0
- planar/routers/.event.py.un~ +0 -0
- planar/routers/.file_attachment.py.un~ +0 -0
- planar/routers/.files.py.un~ +0 -0
- planar/routers/.files_router.py.un~ +0 -0
- planar/routers/.human.py.un~ +0 -0
- planar/routers/.info.py.un~ +0 -0
- planar/routers/.models.py.un~ +0 -0
- planar/routers/.object_config_router.py.un~ +0 -0
- planar/routers/.rule.py.un~ +0 -0
- planar/routers/.test_object_config_router.py.un~ +0 -0
- planar/routers/.test_workflow_router.py.un~ +0 -0
- planar/routers/.workflow.py.un~ +0 -0
- planar/rules/.decorator.py.un~ +0 -0
- planar/rules/.runner.py.un~ +0 -0
- planar/rules/.test_rules.py.un~ +0 -0
- planar/security/.jwt_middleware.py.un~ +0 -0
- planar/sse/.constants.py.un~ +0 -0
- planar/sse/.example.html.un~ +0 -0
- planar/sse/.hub.py.un~ +0 -0
- planar/sse/.model.py.un~ +0 -0
- planar/sse/.proxy.py.un~ +0 -0
- planar/testing/.client.py.un~ +0 -0
- planar/testing/.memory_storage.py.un~ +0 -0
- planar/testing/.planar_test_client.py.un~ +0 -0
- planar/testing/.predictable_tracer.py.un~ +0 -0
- planar/testing/.synchronizable_tracer.py.un~ +0 -0
- planar/testing/.test_memory_storage.py.un~ +0 -0
- planar/testing/.workflow_observer.py.un~ +0 -0
- planar/workflows/.__init__.py.un~ +0 -0
- planar/workflows/.builtin_steps.py.un~ +0 -0
- planar/workflows/.concurrency_tracing.py.un~ +0 -0
- planar/workflows/.context.py.un~ +0 -0
- planar/workflows/.contrib.py.un~ +0 -0
- planar/workflows/.decorators.py.un~ +0 -0
- planar/workflows/.durable_test.py.un~ +0 -0
- planar/workflows/.errors.py.un~ +0 -0
- planar/workflows/.events.py.un~ +0 -0
- planar/workflows/.exceptions.py.un~ +0 -0
- planar/workflows/.execution.py.un~ +0 -0
- planar/workflows/.human.py.un~ +0 -0
- planar/workflows/.lock.py.un~ +0 -0
- planar/workflows/.misc.py.un~ +0 -0
- planar/workflows/.model.py.un~ +0 -0
- planar/workflows/.models.py.un~ +0 -0
- planar/workflows/.notifications.py.un~ +0 -0
- planar/workflows/.orchestrator.py.un~ +0 -0
- planar/workflows/.runtime.py.un~ +0 -0
- planar/workflows/.serialization.py.un~ +0 -0
- planar/workflows/.step.py.un~ +0 -0
- planar/workflows/.step_core.py.un~ +0 -0
- planar/workflows/.sub_workflow_runner.py.un~ +0 -0
- planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
- planar/workflows/.test_concurrency.py.un~ +0 -0
- planar/workflows/.test_concurrency_detection.py.un~ +0 -0
- planar/workflows/.test_human.py.un~ +0 -0
- planar/workflows/.test_lock_timeout.py.un~ +0 -0
- planar/workflows/.test_orchestrator.py.un~ +0 -0
- planar/workflows/.test_race_conditions.py.un~ +0 -0
- planar/workflows/.test_serialization.py.un~ +0 -0
- planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
- planar/workflows/.test_workflow.py.un~ +0 -0
- planar/workflows/.tracing.py.un~ +0 -0
- planar/workflows/.types.py.un~ +0 -0
- planar/workflows/.util.py.un~ +0 -0
- planar/workflows/.utils.py.un~ +0 -0
- planar/workflows/.workflow.py.un~ +0 -0
- planar/workflows/.workflow_wrapper.py.un~ +0 -0
- planar/workflows/.wrappers.py.un~ +0 -0
- planar-0.5.0.dist-info/RECORD +0 -289
- {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/WHEEL +0 -0
- {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/entry_points.txt +0 -0
@@ -11,12 +11,12 @@ storage:
|
|
11
11
|
directory: .files
|
12
12
|
|
13
13
|
sse_hub: true
|
14
|
-
|
15
|
-
cors:
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
14
|
+
security:
|
15
|
+
cors:
|
16
|
+
allow_origins: ["https://app.coplane.com"]
|
17
|
+
allow_credentials: true
|
18
|
+
allow_methods: ["*"]
|
19
|
+
allow_headers: ["*"]
|
20
20
|
|
21
21
|
ai_providers:
|
22
22
|
openai:
|
@@ -17,11 +17,15 @@ storage:
|
|
17
17
|
|
18
18
|
sse_hub: true
|
19
19
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
20
|
+
security:
|
21
|
+
cors:
|
22
|
+
allow_origins: ["https://app.coplane.com"]
|
23
|
+
allow_credentials: true
|
24
|
+
allow_methods: ["*"]
|
25
|
+
allow_headers: ["*"]
|
26
|
+
jwt:
|
27
|
+
client_id: ${WORKOS_CLIENT_ID}
|
28
|
+
org_id: ${WORKOS_ORG_ID}
|
25
29
|
|
26
30
|
ai_providers:
|
27
31
|
openai:
|
planar/security/auth_context.py
CHANGED
@@ -5,6 +5,7 @@ This module provides context variables and utilities for managing the current
|
|
5
5
|
authenticated principal (user) throughout the request lifecycle.
|
6
6
|
"""
|
7
7
|
|
8
|
+
import time
|
8
9
|
from contextlib import contextmanager
|
9
10
|
from contextvars import ContextVar
|
10
11
|
from typing import Any, Iterator
|
@@ -71,6 +72,26 @@ class Principal(BaseModel):
|
|
71
72
|
|
72
73
|
return cls(**principal_data)
|
73
74
|
|
75
|
+
@classmethod
|
76
|
+
def from_service_token(cls, token: str) -> "Principal":
|
77
|
+
"""Create a Principal from a service token."""
|
78
|
+
# TO-DO Potentially lookup token in database to get org_id, org_name, user_first_name, user_last_name, user_email, role, permissions
|
79
|
+
return cls(
|
80
|
+
sub="service_token",
|
81
|
+
iss="service_token",
|
82
|
+
exp=int(time.time()) + 3600,
|
83
|
+
iat=int(time.time()),
|
84
|
+
sid="service_token",
|
85
|
+
jti="service_token",
|
86
|
+
org_id="service_token",
|
87
|
+
org_name="service_token",
|
88
|
+
user_first_name="service_token",
|
89
|
+
user_last_name="service_token",
|
90
|
+
user_email="service_token",
|
91
|
+
role="service_token",
|
92
|
+
permissions=["service_token"],
|
93
|
+
)
|
94
|
+
|
74
95
|
|
75
96
|
# Context variable for the current principal
|
76
97
|
principal_var: ContextVar[Principal | None] = ContextVar("principal", default=None)
|
@@ -4,34 +4,39 @@ from fastapi.responses import JSONResponse
|
|
4
4
|
from starlette.middleware.base import BaseHTTPMiddleware
|
5
5
|
|
6
6
|
from planar.logging import get_logger
|
7
|
-
from planar.security.auth_context import
|
7
|
+
from planar.security.auth_context import (
|
8
|
+
Principal,
|
9
|
+
clear_principal,
|
10
|
+
set_principal,
|
11
|
+
)
|
8
12
|
|
9
13
|
logger = get_logger(__name__)
|
10
14
|
|
11
15
|
BASE_JWKS_URL = "https://auth-api.coplane.com/sso/jwks"
|
12
16
|
EXPECTED_ISSUER = "https://auth-api.coplane.com"
|
17
|
+
SERVICE_TOKEN_HEADER_PREFIX = "Bearer plt_"
|
13
18
|
|
14
19
|
|
15
|
-
class
|
20
|
+
class AuthMiddleware(BaseHTTPMiddleware):
|
16
21
|
def __init__(
|
17
22
|
self,
|
18
23
|
app: FastAPI,
|
19
24
|
client_id: str,
|
20
|
-
org_id: str
|
25
|
+
org_id: str,
|
21
26
|
additional_exclusion_paths: list[str] | None = None,
|
27
|
+
service_token: str | None = None,
|
22
28
|
):
|
23
29
|
super().__init__(app)
|
24
|
-
self.client_id = client_id
|
25
30
|
self.org_id = org_id
|
26
31
|
self.additional_exclusion_paths = additional_exclusion_paths or []
|
32
|
+
self.client = jwt.PyJWKClient(f"{BASE_JWKS_URL}/{client_id}", cache_keys=True)
|
33
|
+
self.service_token = service_token
|
27
34
|
|
28
|
-
def get_signing_key_from_jwt(self,
|
29
|
-
|
30
|
-
jwks_client = jwt.PyJWKClient(jwks_url, cache_keys=True)
|
31
|
-
return jwks_client.get_signing_key_from_jwt(token)
|
35
|
+
def get_signing_key_from_jwt(self, token: str):
|
36
|
+
return self.client.get_signing_key_from_jwt(token)
|
32
37
|
|
33
38
|
def validate_jwt_token(self, token: str):
|
34
|
-
signing_key = self.get_signing_key_from_jwt(
|
39
|
+
signing_key = self.get_signing_key_from_jwt(token)
|
35
40
|
|
36
41
|
payload = jwt.decode(
|
37
42
|
token,
|
@@ -46,7 +51,12 @@ class JWTMiddleware(BaseHTTPMiddleware):
|
|
46
51
|
)
|
47
52
|
|
48
53
|
org_id_from_token = payload.get("org_id")
|
49
|
-
|
54
|
+
|
55
|
+
if (
|
56
|
+
org_id_from_token is None
|
57
|
+
or org_id_from_token == ""
|
58
|
+
or org_id_from_token != self.org_id
|
59
|
+
):
|
50
60
|
raise HTTPException(
|
51
61
|
status_code=401,
|
52
62
|
detail="Invalid organization",
|
@@ -56,18 +66,61 @@ class JWTMiddleware(BaseHTTPMiddleware):
|
|
56
66
|
return payload
|
57
67
|
|
58
68
|
async def dispatch(self, request: Request, call_next):
|
59
|
-
if (
|
60
|
-
|
61
|
-
in [
|
62
|
-
"/docs",
|
63
|
-
"/redoc",
|
64
|
-
"/openapi.json",
|
69
|
+
if request.url.path in (
|
70
|
+
[
|
65
71
|
"/planar/v1/health",
|
66
72
|
]
|
67
|
-
|
73
|
+
+ self.additional_exclusion_paths
|
68
74
|
):
|
69
75
|
return await call_next(request)
|
70
76
|
|
77
|
+
authorization = request.headers.get("Authorization")
|
78
|
+
if authorization and authorization.startswith(SERVICE_TOKEN_HEADER_PREFIX):
|
79
|
+
return await self.dispatch_service_token(request, call_next)
|
80
|
+
else:
|
81
|
+
return await self.dispatch_jwt(request, call_next)
|
82
|
+
|
83
|
+
async def dispatch_service_token(self, request: Request, call_next):
|
84
|
+
authorization = request.headers.get("Authorization")
|
85
|
+
if not authorization or not authorization.startswith(
|
86
|
+
SERVICE_TOKEN_HEADER_PREFIX
|
87
|
+
):
|
88
|
+
return JSONResponse(
|
89
|
+
status_code=401,
|
90
|
+
content={"detail": "Invalid authentication scheme"},
|
91
|
+
headers={"WWW-Authenticate": "Bearer"},
|
92
|
+
)
|
93
|
+
|
94
|
+
token_from_header = authorization.replace("Bearer ", "")
|
95
|
+
if token_from_header != self.service_token:
|
96
|
+
return JSONResponse(
|
97
|
+
status_code=401,
|
98
|
+
content={
|
99
|
+
"detail": "Invalid authentication credentials for service token"
|
100
|
+
},
|
101
|
+
headers={"WWW-Authenticate": "Bearer"},
|
102
|
+
)
|
103
|
+
|
104
|
+
principal_token = None
|
105
|
+
payload = {
|
106
|
+
"sub": "service_token",
|
107
|
+
}
|
108
|
+
# Store payload in request state for backward compatibility
|
109
|
+
request.state.user = payload
|
110
|
+
# Create and set the principal in context
|
111
|
+
principal = Principal.from_service_token(token_from_header)
|
112
|
+
principal_token = set_principal(principal)
|
113
|
+
|
114
|
+
try:
|
115
|
+
response = await call_next(request)
|
116
|
+
finally:
|
117
|
+
# Clean up the principal context
|
118
|
+
if principal_token is not None:
|
119
|
+
clear_principal(principal_token)
|
120
|
+
|
121
|
+
return response
|
122
|
+
|
123
|
+
async def dispatch_jwt(self, request: Request, call_next):
|
71
124
|
principal_token = None
|
72
125
|
try:
|
73
126
|
authorization = request.headers.get("Authorization")
|
planar/security/authorization.py
CHANGED
@@ -357,22 +357,16 @@ def validate_authorization_for(
|
|
357
357
|
entity: CedarEntity | None = None
|
358
358
|
|
359
359
|
match action:
|
360
|
-
case WorkflowAction():
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
case RuleAction():
|
367
|
-
if isinstance(resource_descriptor, RuleResource):
|
368
|
-
entity = CedarEntity.from_rule(resource_descriptor.rule_name)
|
360
|
+
case WorkflowAction() if isinstance(resource_descriptor, WorkflowResource):
|
361
|
+
entity = CedarEntity.from_workflow(resource_descriptor.function_name)
|
362
|
+
case AgentAction() if isinstance(resource_descriptor, AgentResource):
|
363
|
+
entity = CedarEntity.from_agent(resource_descriptor.id)
|
364
|
+
case RuleAction() if isinstance(resource_descriptor, RuleResource):
|
365
|
+
entity = CedarEntity.from_rule(resource_descriptor.rule_name)
|
369
366
|
case _:
|
370
|
-
raise ValueError(
|
371
|
-
|
372
|
-
|
373
|
-
raise ValueError(
|
374
|
-
f"Invalid resource descriptor {type(resource_descriptor).__name__} for action {action}"
|
375
|
-
)
|
367
|
+
raise ValueError(
|
368
|
+
f"Invalid resource descriptor {type(resource_descriptor).__name__} for action {action}"
|
369
|
+
)
|
376
370
|
|
377
371
|
# Get current principal and check authorization on current resource
|
378
372
|
principal: Principal | None = get_current_principal()
|
@@ -0,0 +1,162 @@
|
|
1
|
+
from unittest.mock import Mock, patch
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
from fastapi import FastAPI, HTTPException
|
5
|
+
from fastapi.responses import JSONResponse
|
6
|
+
|
7
|
+
from planar.security.auth_middleware import AuthMiddleware
|
8
|
+
|
9
|
+
|
10
|
+
@pytest.fixture
|
11
|
+
def app():
|
12
|
+
return FastAPI()
|
13
|
+
|
14
|
+
|
15
|
+
@pytest.fixture
|
16
|
+
def auth_middleware(app):
|
17
|
+
return AuthMiddleware(
|
18
|
+
app=app,
|
19
|
+
client_id="test-client-id",
|
20
|
+
org_id="test-org-id",
|
21
|
+
additional_exclusion_paths=["/test/exclude"],
|
22
|
+
service_token="plt_test-service-token",
|
23
|
+
)
|
24
|
+
|
25
|
+
|
26
|
+
class TestAuthMiddleware:
|
27
|
+
def test_org_id_validation_none(self, auth_middleware):
|
28
|
+
"""Test that org_id validation fails when token has None org_id"""
|
29
|
+
with (
|
30
|
+
patch.object(auth_middleware, "get_signing_key_from_jwt"),
|
31
|
+
patch("jwt.decode") as mock_decode,
|
32
|
+
):
|
33
|
+
mock_decode.return_value = {"org_id": None}
|
34
|
+
|
35
|
+
with pytest.raises(HTTPException) as exc_info:
|
36
|
+
auth_middleware.validate_jwt_token("fake-token")
|
37
|
+
|
38
|
+
assert exc_info.value.status_code == 401
|
39
|
+
assert exc_info.value.detail == "Invalid organization"
|
40
|
+
|
41
|
+
def test_org_id_validation_empty_string(self, auth_middleware):
|
42
|
+
"""Test that org_id validation fails when token has empty string org_id"""
|
43
|
+
with (
|
44
|
+
patch.object(auth_middleware, "get_signing_key_from_jwt"),
|
45
|
+
patch("jwt.decode") as mock_decode,
|
46
|
+
):
|
47
|
+
mock_decode.return_value = {"org_id": ""}
|
48
|
+
|
49
|
+
with pytest.raises(HTTPException) as exc_info:
|
50
|
+
auth_middleware.validate_jwt_token("fake-token")
|
51
|
+
|
52
|
+
assert exc_info.value.status_code == 401
|
53
|
+
assert exc_info.value.detail == "Invalid organization"
|
54
|
+
|
55
|
+
def test_org_id_validation_mismatch(self, auth_middleware):
|
56
|
+
"""Test that org_id validation fails when token org_id doesn't match"""
|
57
|
+
with (
|
58
|
+
patch.object(auth_middleware, "get_signing_key_from_jwt"),
|
59
|
+
patch("jwt.decode") as mock_decode,
|
60
|
+
):
|
61
|
+
mock_decode.return_value = {"org_id": "different-org-id"}
|
62
|
+
|
63
|
+
with pytest.raises(HTTPException) as exc_info:
|
64
|
+
auth_middleware.validate_jwt_token("fake-token")
|
65
|
+
|
66
|
+
assert exc_info.value.status_code == 401
|
67
|
+
assert exc_info.value.detail == "Invalid organization"
|
68
|
+
|
69
|
+
def test_org_id_validation_success(self, auth_middleware):
|
70
|
+
"""Test that org_id validation succeeds when token org_id matches"""
|
71
|
+
with (
|
72
|
+
patch.object(auth_middleware, "get_signing_key_from_jwt"),
|
73
|
+
patch("jwt.decode") as mock_decode,
|
74
|
+
):
|
75
|
+
expected_payload = {"org_id": "test-org-id", "user_id": "test-user"}
|
76
|
+
mock_decode.return_value = expected_payload
|
77
|
+
|
78
|
+
result = auth_middleware.validate_jwt_token("fake-token")
|
79
|
+
|
80
|
+
assert result == expected_payload
|
81
|
+
|
82
|
+
@pytest.mark.asyncio
|
83
|
+
async def test_service_token_validation_success(self, auth_middleware):
|
84
|
+
"""Test that service token validation succeeds when token matches"""
|
85
|
+
mock_request = Mock()
|
86
|
+
mock_request.url.path = "/planar/v1/something"
|
87
|
+
mock_request.headers = {"Authorization": "Bearer plt_test-service-token"}
|
88
|
+
mock_call_next = Mock()
|
89
|
+
mock_call_next.return_value = JSONResponse(
|
90
|
+
status_code=200, content={"message": "success"}
|
91
|
+
)
|
92
|
+
|
93
|
+
async def mock_call_next_func(request):
|
94
|
+
return mock_call_next(request)
|
95
|
+
|
96
|
+
result = await auth_middleware.dispatch(mock_request, mock_call_next_func)
|
97
|
+
mock_call_next.assert_called_once_with(mock_request)
|
98
|
+
assert result is not None
|
99
|
+
assert result.status_code == 200
|
100
|
+
|
101
|
+
@pytest.mark.asyncio
|
102
|
+
async def test_service_token_validation_failure(self, auth_middleware):
|
103
|
+
"""Test that service token validation succeeds when token matches"""
|
104
|
+
mock_request = Mock()
|
105
|
+
mock_request.url.path = "/planar/v1/something"
|
106
|
+
mock_request.headers = {"Authorization": "Bearer plt_wrong-token"}
|
107
|
+
mock_call_next = Mock()
|
108
|
+
mock_call_next.return_value = JSONResponse(
|
109
|
+
status_code=200, content={"message": "success"}
|
110
|
+
)
|
111
|
+
|
112
|
+
async def mock_call_next_func(request):
|
113
|
+
return mock_call_next(request)
|
114
|
+
|
115
|
+
result = await auth_middleware.dispatch(mock_request, mock_call_next_func)
|
116
|
+
mock_call_next.assert_not_called()
|
117
|
+
assert result is not None
|
118
|
+
assert result.status_code == 401
|
119
|
+
|
120
|
+
@pytest.mark.asyncio
|
121
|
+
async def test_exclusion_paths_includes_health_and_additional(self, app):
|
122
|
+
"""Test that exclusion paths include health endpoint and additional paths"""
|
123
|
+
middleware = AuthMiddleware(
|
124
|
+
app=app,
|
125
|
+
client_id="test-client-id",
|
126
|
+
org_id="test-org-id",
|
127
|
+
additional_exclusion_paths=["/custom/path", "/another/path"],
|
128
|
+
)
|
129
|
+
|
130
|
+
mock_request = Mock()
|
131
|
+
mock_call_next = Mock()
|
132
|
+
|
133
|
+
async def mock_call_next_func(request):
|
134
|
+
return mock_call_next(request)
|
135
|
+
|
136
|
+
expected_paths = ["/planar/v1/health", "/custom/path", "/another/path"]
|
137
|
+
for path in expected_paths:
|
138
|
+
mock_request.url.path = path
|
139
|
+
mock_call_next.reset_mock()
|
140
|
+
|
141
|
+
await middleware.dispatch(mock_request, mock_call_next_func) # type: ignore
|
142
|
+
|
143
|
+
mock_call_next.assert_called_once_with(mock_request)
|
144
|
+
|
145
|
+
# Test that non-excluded paths are not excluded
|
146
|
+
mock_request.url.path = "/not-excluded"
|
147
|
+
mock_call_next.reset_mock()
|
148
|
+
result = await middleware.dispatch(mock_request, mock_call_next_func)
|
149
|
+
assert result is not None
|
150
|
+
assert result.status_code == 401
|
151
|
+
mock_call_next.assert_not_called()
|
152
|
+
|
153
|
+
def test_required_org_id_parameter(self, app):
|
154
|
+
"""Test that org_id parameter is required (not optional)"""
|
155
|
+
# This test ensures org_id cannot be None based on type hints
|
156
|
+
# The actual enforcement is at the type level, so we just verify
|
157
|
+
# the constructor works with a valid org_id
|
158
|
+
middleware = AuthMiddleware(
|
159
|
+
app=app, client_id="test-client-id", org_id="required-org-id"
|
160
|
+
)
|
161
|
+
|
162
|
+
assert middleware.org_id == "required-org-id"
|
planar/sse/proxy.py
CHANGED
@@ -65,7 +65,6 @@ class SSEProxy:
|
|
65
65
|
self.config = config
|
66
66
|
self.enable_builtin_hub = False
|
67
67
|
self.hub_url = ""
|
68
|
-
self.transport: httpx.AsyncHTTPTransport | None = None
|
69
68
|
self.stream_tasks: WeakSet[Task] = WeakSet()
|
70
69
|
|
71
70
|
if isinstance(sse_hub, str):
|
@@ -127,12 +126,11 @@ class SSEProxy:
|
|
127
126
|
if self.enable_builtin_hub:
|
128
127
|
self.start_builtin_hub()
|
129
128
|
|
130
|
-
self.transport, self.hub_url = parse_hub_url(self.hub_url)
|
131
|
-
forward_url = f"{self.hub_url}/push"
|
132
|
-
|
133
129
|
async def forward():
|
130
|
+
transport, hub_url = parse_hub_url(self.hub_url)
|
131
|
+
forward_url = f"{hub_url}/push"
|
134
132
|
logger.debug("sse event forwarding task started", url=forward_url)
|
135
|
-
async with httpx.AsyncClient(transport=
|
133
|
+
async with httpx.AsyncClient(transport=transport) as client:
|
136
134
|
while True:
|
137
135
|
event = await self.queue.get()
|
138
136
|
logger.debug(
|
@@ -206,11 +204,8 @@ class SSEProxy:
|
|
206
204
|
@asynccontextmanager
|
207
205
|
async def connect(self, query: str = "", headers: dict[str, str] = {}):
|
208
206
|
logger.debug("sseproxy connect called", query=query, headers=headers)
|
209
|
-
if not self.hub_url or not self.transport:
|
210
|
-
raise ValueError("hub_url is not set")
|
211
207
|
|
212
|
-
hub_url = self.hub_url
|
213
|
-
transport = self.transport
|
208
|
+
transport, hub_url = parse_hub_url(self.hub_url)
|
214
209
|
|
215
210
|
client = httpx.AsyncClient(
|
216
211
|
transport=transport, base_url=hub_url, timeout=httpx.Timeout(None)
|
planar/test_app.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1
|
+
from unittest.mock import Mock, patch
|
2
|
+
|
3
|
+
import pytest
|
1
4
|
from dotenv import load_dotenv
|
2
5
|
from fastapi import APIRouter
|
3
|
-
from pydantic import BaseModel
|
6
|
+
from pydantic import BaseModel, ValidationError
|
4
7
|
|
5
8
|
from examples.simple_service.models import (
|
6
9
|
Invoice,
|
7
10
|
)
|
8
11
|
from planar import PlanarApp, sqlite_config
|
12
|
+
from planar.app import setup_auth_middleware
|
13
|
+
from planar.config import Environment, JWTConfig, SecurityConfig
|
9
14
|
|
10
15
|
load_dotenv()
|
11
16
|
|
@@ -49,3 +54,89 @@ def test_register_model_deduplication():
|
|
49
54
|
# Verify that the model in the registry is the Invoice model
|
50
55
|
registered_models = app._object_registry.get_entities()
|
51
56
|
assert any(model.__name__ == "Invoice" for model in registered_models)
|
57
|
+
|
58
|
+
|
59
|
+
class TestJWTSetup:
|
60
|
+
def test_setup_jwt_middleware_production_requires_config(self):
|
61
|
+
"""Test that JWT setup throws ValueError in production without proper config"""
|
62
|
+
mock_app = Mock()
|
63
|
+
mock_app.config.environment = Environment.PROD
|
64
|
+
mock_app.config.security = SecurityConfig()
|
65
|
+
|
66
|
+
with pytest.raises(
|
67
|
+
ValueError,
|
68
|
+
match="Auth middleware is required in production. Please set the JWT config and optionally service token config.",
|
69
|
+
):
|
70
|
+
setup_auth_middleware(mock_app)
|
71
|
+
|
72
|
+
def test_setup_jwt_middleware_production_requires_client_id(self):
|
73
|
+
"""Test that JWT setup throws ValueError in production without client_id"""
|
74
|
+
with pytest.raises(
|
75
|
+
ValidationError,
|
76
|
+
match="Both client_id and org_id required to enable JWT",
|
77
|
+
):
|
78
|
+
JWTConfig(client_id=None, org_id="test-org")
|
79
|
+
|
80
|
+
def test_setup_jwt_middleware_production_requires_org_id(self):
|
81
|
+
"""Test that JWT setup throws ValueError in production without org_id"""
|
82
|
+
with pytest.raises(
|
83
|
+
ValidationError,
|
84
|
+
match="Both client_id and org_id required to enable JWT",
|
85
|
+
):
|
86
|
+
JWTConfig(client_id="test-client-id", org_id=None)
|
87
|
+
|
88
|
+
@patch("planar.app.logger")
|
89
|
+
def test_setup_jwt_middleware_success_with_all_fields(self, mock_logger):
|
90
|
+
"""Test that JWT setup succeeds with all required fields"""
|
91
|
+
mock_app = Mock()
|
92
|
+
mock_app.config.environment = Environment.PROD
|
93
|
+
mock_app.config.security = SecurityConfig(
|
94
|
+
jwt=JWTConfig(
|
95
|
+
client_id="test-client-id",
|
96
|
+
org_id="test-org-id",
|
97
|
+
additional_exclusion_paths=["/test/path"],
|
98
|
+
)
|
99
|
+
)
|
100
|
+
|
101
|
+
setup_auth_middleware(mock_app)
|
102
|
+
|
103
|
+
# Verify middleware was added to app.fastapi
|
104
|
+
mock_app.fastapi.add_middleware.assert_called_once()
|
105
|
+
|
106
|
+
# Check that info log was called
|
107
|
+
mock_logger.info.assert_called_once_with(
|
108
|
+
"Auth middleware enabled",
|
109
|
+
client_id="test-client-id",
|
110
|
+
org_id="test-org-id",
|
111
|
+
additional_exclusion_paths=["/test/path"],
|
112
|
+
)
|
113
|
+
|
114
|
+
@patch("planar.app.logger")
|
115
|
+
def test_setup_jwt_middleware_dev_environment_allows_missing_config(
|
116
|
+
self, mock_logger
|
117
|
+
):
|
118
|
+
"""Test that JWT setup is skipped in dev environment without config"""
|
119
|
+
mock_app = Mock()
|
120
|
+
mock_app.config.environment = Environment.DEV
|
121
|
+
mock_app.config.security = SecurityConfig()
|
122
|
+
|
123
|
+
setup_auth_middleware(mock_app)
|
124
|
+
|
125
|
+
# Verify warning was logged and no middleware added
|
126
|
+
mock_logger.warning.assert_called_once_with("Auth middleware disabled")
|
127
|
+
mock_app.fastapi.add_middleware.assert_not_called()
|
128
|
+
|
129
|
+
@patch("planar.app.logger")
|
130
|
+
def test_setup_jwt_middleware_dev_environment_allows_disabled_jwt(
|
131
|
+
self, mock_logger
|
132
|
+
):
|
133
|
+
"""Test that JWT setup is skipped in dev environment with disabled JWT"""
|
134
|
+
mock_app = Mock()
|
135
|
+
mock_app.config.environment = Environment.DEV
|
136
|
+
mock_app.config.security = SecurityConfig()
|
137
|
+
|
138
|
+
setup_auth_middleware(mock_app)
|
139
|
+
|
140
|
+
# Verify warning was logged and no middleware added
|
141
|
+
mock_logger.warning.assert_called_once_with("Auth middleware disabled")
|
142
|
+
mock_app.fastapi.add_middleware.assert_not_called()
|