planar 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +155 -283
  3. planar/ai/agent_base.py +170 -0
  4. planar/ai/agent_utils.py +7 -0
  5. planar/ai/pydantic_ai.py +638 -0
  6. planar/ai/test_agent_serialization.py +1 -1
  7. planar/app.py +64 -20
  8. planar/cli.py +39 -27
  9. planar/config.py +45 -36
  10. planar/db/db.py +2 -1
  11. planar/files/storage/azure_blob.py +343 -0
  12. planar/files/storage/base.py +7 -0
  13. planar/files/storage/config.py +70 -7
  14. planar/files/storage/s3.py +6 -6
  15. planar/files/storage/test_azure_blob.py +435 -0
  16. planar/logging/formatter.py +17 -4
  17. planar/logging/test_formatter.py +327 -0
  18. planar/registry_items.py +2 -1
  19. planar/routers/agents_router.py +3 -1
  20. planar/routers/files.py +11 -2
  21. planar/routers/models.py +14 -1
  22. planar/routers/test_agents_router.py +1 -1
  23. planar/routers/test_files_router.py +49 -0
  24. planar/routers/test_routes_security.py +5 -7
  25. planar/routers/test_workflow_router.py +270 -3
  26. planar/routers/workflow.py +95 -36
  27. planar/rules/models.py +36 -39
  28. planar/rules/test_data/account_dormancy_management.json +223 -0
  29. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  30. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  31. planar/rules/test_data/booking_fraud_detection.json +407 -0
  32. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  33. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  34. planar/rules/test_data/customer_lifetime_value.json +143 -0
  35. planar/rules/test_data/import_duties_calculator.json +289 -0
  36. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  37. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  38. planar/rules/test_data/order_consolidation_system.json +375 -0
  39. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  40. planar/rules/test_data/supply_chain_risk.json +253 -0
  41. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  42. planar/rules/test_rules.py +750 -6
  43. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  44. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  45. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  46. planar/security/auth_context.py +21 -0
  47. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  48. planar/security/authorization.py +9 -15
  49. planar/security/tests/test_auth_middleware.py +162 -0
  50. planar/sse/proxy.py +4 -9
  51. planar/test_app.py +92 -1
  52. planar/test_cli.py +81 -59
  53. planar/test_config.py +17 -14
  54. planar/testing/fixtures.py +325 -0
  55. planar/testing/planar_test_client.py +5 -2
  56. planar/utils.py +41 -1
  57. planar/workflows/execution.py +1 -1
  58. planar/workflows/orchestrator.py +5 -0
  59. planar/workflows/serialization.py +12 -6
  60. planar/workflows/step_core.py +3 -1
  61. planar/workflows/test_serialization.py +9 -1
  62. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/METADATA +30 -5
  63. planar-0.8.0.dist-info/RECORD +166 -0
  64. planar/.__init__.py.un~ +0 -0
  65. planar/._version.py.un~ +0 -0
  66. planar/.app.py.un~ +0 -0
  67. planar/.cli.py.un~ +0 -0
  68. planar/.config.py.un~ +0 -0
  69. planar/.context.py.un~ +0 -0
  70. planar/.db.py.un~ +0 -0
  71. planar/.di.py.un~ +0 -0
  72. planar/.engine.py.un~ +0 -0
  73. planar/.files.py.un~ +0 -0
  74. planar/.log_context.py.un~ +0 -0
  75. planar/.log_metadata.py.un~ +0 -0
  76. planar/.logging.py.un~ +0 -0
  77. planar/.object_registry.py.un~ +0 -0
  78. planar/.otel.py.un~ +0 -0
  79. planar/.server.py.un~ +0 -0
  80. planar/.session.py.un~ +0 -0
  81. planar/.sqlalchemy.py.un~ +0 -0
  82. planar/.task_local.py.un~ +0 -0
  83. planar/.test_app.py.un~ +0 -0
  84. planar/.test_config.py.un~ +0 -0
  85. planar/.test_object_config.py.un~ +0 -0
  86. planar/.test_sqlalchemy.py.un~ +0 -0
  87. planar/.test_utils.py.un~ +0 -0
  88. planar/.util.py.un~ +0 -0
  89. planar/.utils.py.un~ +0 -0
  90. planar/ai/.__init__.py.un~ +0 -0
  91. planar/ai/._models.py.un~ +0 -0
  92. planar/ai/.agent.py.un~ +0 -0
  93. planar/ai/.agent_utils.py.un~ +0 -0
  94. planar/ai/.events.py.un~ +0 -0
  95. planar/ai/.files.py.un~ +0 -0
  96. planar/ai/.models.py.un~ +0 -0
  97. planar/ai/.providers.py.un~ +0 -0
  98. planar/ai/.pydantic_ai.py.un~ +0 -0
  99. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  100. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  101. planar/ai/.step.py.un~ +0 -0
  102. planar/ai/.test_agent.py.un~ +0 -0
  103. planar/ai/.test_agent_serialization.py.un~ +0 -0
  104. planar/ai/.test_providers.py.un~ +0 -0
  105. planar/ai/.utils.py.un~ +0 -0
  106. planar/ai/providers.py +0 -1088
  107. planar/ai/test_agent.py +0 -1298
  108. planar/ai/test_providers.py +0 -463
  109. planar/db/.db.py.un~ +0 -0
  110. planar/files/.config.py.un~ +0 -0
  111. planar/files/.local.py.un~ +0 -0
  112. planar/files/.local_filesystem.py.un~ +0 -0
  113. planar/files/.model.py.un~ +0 -0
  114. planar/files/.models.py.un~ +0 -0
  115. planar/files/.s3.py.un~ +0 -0
  116. planar/files/.storage.py.un~ +0 -0
  117. planar/files/.test_files.py.un~ +0 -0
  118. planar/files/storage/.__init__.py.un~ +0 -0
  119. planar/files/storage/.base.py.un~ +0 -0
  120. planar/files/storage/.config.py.un~ +0 -0
  121. planar/files/storage/.context.py.un~ +0 -0
  122. planar/files/storage/.local_directory.py.un~ +0 -0
  123. planar/files/storage/.test_local_directory.py.un~ +0 -0
  124. planar/files/storage/.test_s3.py.un~ +0 -0
  125. planar/human/.human.py.un~ +0 -0
  126. planar/human/.test_human.py.un~ +0 -0
  127. planar/logging/.__init__.py.un~ +0 -0
  128. planar/logging/.attributes.py.un~ +0 -0
  129. planar/logging/.formatter.py.un~ +0 -0
  130. planar/logging/.logger.py.un~ +0 -0
  131. planar/logging/.otel.py.un~ +0 -0
  132. planar/logging/.tracer.py.un~ +0 -0
  133. planar/modeling/.mixin.py.un~ +0 -0
  134. planar/modeling/.storage.py.un~ +0 -0
  135. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  136. planar/object_config/.object_config.py.un~ +0 -0
  137. planar/routers/.__init__.py.un~ +0 -0
  138. planar/routers/.agents_router.py.un~ +0 -0
  139. planar/routers/.crud.py.un~ +0 -0
  140. planar/routers/.decision.py.un~ +0 -0
  141. planar/routers/.event.py.un~ +0 -0
  142. planar/routers/.file_attachment.py.un~ +0 -0
  143. planar/routers/.files.py.un~ +0 -0
  144. planar/routers/.files_router.py.un~ +0 -0
  145. planar/routers/.human.py.un~ +0 -0
  146. planar/routers/.info.py.un~ +0 -0
  147. planar/routers/.models.py.un~ +0 -0
  148. planar/routers/.object_config_router.py.un~ +0 -0
  149. planar/routers/.rule.py.un~ +0 -0
  150. planar/routers/.test_object_config_router.py.un~ +0 -0
  151. planar/routers/.test_workflow_router.py.un~ +0 -0
  152. planar/routers/.workflow.py.un~ +0 -0
  153. planar/rules/.decorator.py.un~ +0 -0
  154. planar/rules/.runner.py.un~ +0 -0
  155. planar/rules/.test_rules.py.un~ +0 -0
  156. planar/security/.jwt_middleware.py.un~ +0 -0
  157. planar/sse/.constants.py.un~ +0 -0
  158. planar/sse/.example.html.un~ +0 -0
  159. planar/sse/.hub.py.un~ +0 -0
  160. planar/sse/.model.py.un~ +0 -0
  161. planar/sse/.proxy.py.un~ +0 -0
  162. planar/testing/.client.py.un~ +0 -0
  163. planar/testing/.memory_storage.py.un~ +0 -0
  164. planar/testing/.planar_test_client.py.un~ +0 -0
  165. planar/testing/.predictable_tracer.py.un~ +0 -0
  166. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  167. planar/testing/.test_memory_storage.py.un~ +0 -0
  168. planar/testing/.workflow_observer.py.un~ +0 -0
  169. planar/workflows/.__init__.py.un~ +0 -0
  170. planar/workflows/.builtin_steps.py.un~ +0 -0
  171. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  172. planar/workflows/.context.py.un~ +0 -0
  173. planar/workflows/.contrib.py.un~ +0 -0
  174. planar/workflows/.decorators.py.un~ +0 -0
  175. planar/workflows/.durable_test.py.un~ +0 -0
  176. planar/workflows/.errors.py.un~ +0 -0
  177. planar/workflows/.events.py.un~ +0 -0
  178. planar/workflows/.exceptions.py.un~ +0 -0
  179. planar/workflows/.execution.py.un~ +0 -0
  180. planar/workflows/.human.py.un~ +0 -0
  181. planar/workflows/.lock.py.un~ +0 -0
  182. planar/workflows/.misc.py.un~ +0 -0
  183. planar/workflows/.model.py.un~ +0 -0
  184. planar/workflows/.models.py.un~ +0 -0
  185. planar/workflows/.notifications.py.un~ +0 -0
  186. planar/workflows/.orchestrator.py.un~ +0 -0
  187. planar/workflows/.runtime.py.un~ +0 -0
  188. planar/workflows/.serialization.py.un~ +0 -0
  189. planar/workflows/.step.py.un~ +0 -0
  190. planar/workflows/.step_core.py.un~ +0 -0
  191. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  192. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  193. planar/workflows/.test_concurrency.py.un~ +0 -0
  194. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  195. planar/workflows/.test_human.py.un~ +0 -0
  196. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  197. planar/workflows/.test_orchestrator.py.un~ +0 -0
  198. planar/workflows/.test_race_conditions.py.un~ +0 -0
  199. planar/workflows/.test_serialization.py.un~ +0 -0
  200. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  201. planar/workflows/.test_workflow.py.un~ +0 -0
  202. planar/workflows/.tracing.py.un~ +0 -0
  203. planar/workflows/.types.py.un~ +0 -0
  204. planar/workflows/.util.py.un~ +0 -0
  205. planar/workflows/.utils.py.un~ +0 -0
  206. planar/workflows/.workflow.py.un~ +0 -0
  207. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  208. planar/workflows/.wrappers.py.un~ +0 -0
  209. planar-0.5.0.dist-info/RECORD +0 -289
  210. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/WHEEL +0 -0
  211. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/entry_points.txt +0 -0
planar/scaffold_templates/planar.dev.yaml.j2 CHANGED
@@ -11,12 +11,12 @@ storage:
11
11
  directory: .files
12
12
 
13
13
  sse_hub: true
14
-
15
- cors:
16
- allow_origins: "^(?:https://(?:[a-zA-Z0-9-]+\\.)+coplane\\.(dev|com)|http://127.0.0.1:3000)$"
17
- allow_credentials: true
18
- allow_methods: ["*"]
19
- allow_headers: ["*"]
14
+ security:
15
+ cors:
16
+ allow_origins: ["https://app.coplane.com"]
17
+ allow_credentials: true
18
+ allow_methods: ["*"]
19
+ allow_headers: ["*"]
20
20
 
21
21
  ai_providers:
22
22
  openai:
planar/scaffold_templates/planar.prod.yaml.j2 CHANGED
@@ -17,11 +17,15 @@ storage:
17
17
 
18
18
  sse_hub: true
19
19
 
20
- cors:
21
- allow_origins: "^(?:https://(?:[a-zA-Z0-9-]+\\.)+coplane\\.(dev|com)|http://127.0.0.1:3000)$"
22
- allow_credentials: true
23
- allow_methods: ["*"]
24
- allow_headers: ["*"]
20
+ security:
21
+ cors:
22
+ allow_origins: ["https://app.coplane.com"]
23
+ allow_credentials: true
24
+ allow_methods: ["*"]
25
+ allow_headers: ["*"]
26
+ jwt:
27
+ client_id: ${WORKOS_CLIENT_ID}
28
+ org_id: ${WORKOS_ORG_ID}
25
29
 
26
30
  ai_providers:
27
31
  openai:
planar/scaffold_templates/pyproject.toml.j2 CHANGED
@@ -3,7 +3,7 @@ name = "{{ name }}"
3
3
  version = "0.1.0"
4
4
  requires-python = ">=3.12"
5
5
  dependencies = [
6
- "planar>=0.1.0",
6
+ "planar>=0.6.0",
7
7
  ]
8
8
 
9
9
  [[tool.uv.index]]
planar/security/auth_context.py CHANGED
@@ -5,6 +5,7 @@ This module provides context variables and utilities for managing the current
5
5
  authenticated principal (user) throughout the request lifecycle.
6
6
  """
7
7
 
8
+ import time
8
9
  from contextlib import contextmanager
9
10
  from contextvars import ContextVar
10
11
  from typing import Any, Iterator
@@ -71,6 +72,26 @@ class Principal(BaseModel):
71
72
 
72
73
  return cls(**principal_data)
73
74
 
75
+ @classmethod
76
+ def from_service_token(cls, token: str) -> "Principal":
77
+ """Create a Principal from a service token."""
78
+ # TO-DO Potentially lookup token in database to get org_id, org_name, user_first_name, user_last_name, user_email, role, permissions
79
+ return cls(
80
+ sub="service_token",
81
+ iss="service_token",
82
+ exp=int(time.time()) + 3600,
83
+ iat=int(time.time()),
84
+ sid="service_token",
85
+ jti="service_token",
86
+ org_id="service_token",
87
+ org_name="service_token",
88
+ user_first_name="service_token",
89
+ user_last_name="service_token",
90
+ user_email="service_token",
91
+ role="service_token",
92
+ permissions=["service_token"],
93
+ )
94
+
74
95
 
75
96
  # Context variable for the current principal
76
97
  principal_var: ContextVar[Principal | None] = ContextVar("principal", default=None)
planar/security/{jwt_middleware.py → auth_middleware.py} RENAMED
@@ -4,34 +4,39 @@ from fastapi.responses import JSONResponse
4
4
  from starlette.middleware.base import BaseHTTPMiddleware
5
5
 
6
6
  from planar.logging import get_logger
7
- from planar.security.auth_context import Principal, clear_principal, set_principal
7
+ from planar.security.auth_context import (
8
+ Principal,
9
+ clear_principal,
10
+ set_principal,
11
+ )
8
12
 
9
13
  logger = get_logger(__name__)
10
14
 
11
15
  BASE_JWKS_URL = "https://auth-api.coplane.com/sso/jwks"
12
16
  EXPECTED_ISSUER = "https://auth-api.coplane.com"
17
+ SERVICE_TOKEN_HEADER_PREFIX = "Bearer plt_"
13
18
 
14
19
 
15
- class JWTMiddleware(BaseHTTPMiddleware):
20
+ class AuthMiddleware(BaseHTTPMiddleware):
16
21
  def __init__(
17
22
  self,
18
23
  app: FastAPI,
19
24
  client_id: str,
20
- org_id: str | None = None,
25
+ org_id: str,
21
26
  additional_exclusion_paths: list[str] | None = None,
27
+ service_token: str | None = None,
22
28
  ):
23
29
  super().__init__(app)
24
- self.client_id = client_id
25
30
  self.org_id = org_id
26
31
  self.additional_exclusion_paths = additional_exclusion_paths or []
32
+ self.client = jwt.PyJWKClient(f"{BASE_JWKS_URL}/{client_id}", cache_keys=True)
33
+ self.service_token = service_token
27
34
 
28
- def get_signing_key_from_jwt(self, client_id: str, token: str):
29
- jwks_url = f"{BASE_JWKS_URL}/{client_id}"
30
- jwks_client = jwt.PyJWKClient(jwks_url, cache_keys=True)
31
- return jwks_client.get_signing_key_from_jwt(token)
35
+ def get_signing_key_from_jwt(self, token: str):
36
+ return self.client.get_signing_key_from_jwt(token)
32
37
 
33
38
  def validate_jwt_token(self, token: str):
34
- signing_key = self.get_signing_key_from_jwt(self.client_id, token)
39
+ signing_key = self.get_signing_key_from_jwt(token)
35
40
 
36
41
  payload = jwt.decode(
37
42
  token,
@@ -46,7 +51,12 @@ class JWTMiddleware(BaseHTTPMiddleware):
46
51
  )
47
52
 
48
53
  org_id_from_token = payload.get("org_id")
49
- if self.org_id and org_id_from_token != self.org_id:
54
+
55
+ if (
56
+ org_id_from_token is None
57
+ or org_id_from_token == ""
58
+ or org_id_from_token != self.org_id
59
+ ):
50
60
  raise HTTPException(
51
61
  status_code=401,
52
62
  detail="Invalid organization",
@@ -56,18 +66,61 @@ class JWTMiddleware(BaseHTTPMiddleware):
56
66
  return payload
57
67
 
58
68
  async def dispatch(self, request: Request, call_next):
59
- if (
60
- request.url.path
61
- in [
62
- "/docs",
63
- "/redoc",
64
- "/openapi.json",
69
+ if request.url.path in (
70
+ [
65
71
  "/planar/v1/health",
66
72
  ]
67
- or request.url.path in self.additional_exclusion_paths
73
+ + self.additional_exclusion_paths
68
74
  ):
69
75
  return await call_next(request)
70
76
 
77
+ authorization = request.headers.get("Authorization")
78
+ if authorization and authorization.startswith(SERVICE_TOKEN_HEADER_PREFIX):
79
+ return await self.dispatch_service_token(request, call_next)
80
+ else:
81
+ return await self.dispatch_jwt(request, call_next)
82
+
83
+ async def dispatch_service_token(self, request: Request, call_next):
84
+ authorization = request.headers.get("Authorization")
85
+ if not authorization or not authorization.startswith(
86
+ SERVICE_TOKEN_HEADER_PREFIX
87
+ ):
88
+ return JSONResponse(
89
+ status_code=401,
90
+ content={"detail": "Invalid authentication scheme"},
91
+ headers={"WWW-Authenticate": "Bearer"},
92
+ )
93
+
94
+ token_from_header = authorization.replace("Bearer ", "")
95
+ if token_from_header != self.service_token:
96
+ return JSONResponse(
97
+ status_code=401,
98
+ content={
99
+ "detail": "Invalid authentication credentials for service token"
100
+ },
101
+ headers={"WWW-Authenticate": "Bearer"},
102
+ )
103
+
104
+ principal_token = None
105
+ payload = {
106
+ "sub": "service_token",
107
+ }
108
+ # Store payload in request state for backward compatibility
109
+ request.state.user = payload
110
+ # Create and set the principal in context
111
+ principal = Principal.from_service_token(token_from_header)
112
+ principal_token = set_principal(principal)
113
+
114
+ try:
115
+ response = await call_next(request)
116
+ finally:
117
+ # Clean up the principal context
118
+ if principal_token is not None:
119
+ clear_principal(principal_token)
120
+
121
+ return response
122
+
123
+ async def dispatch_jwt(self, request: Request, call_next):
71
124
  principal_token = None
72
125
  try:
73
126
  authorization = request.headers.get("Authorization")
planar/security/authorization.py CHANGED
@@ -357,22 +357,16 @@ def validate_authorization_for(
357
357
  entity: CedarEntity | None = None
358
358
 
359
359
  match action:
360
- case WorkflowAction():
361
- if isinstance(resource_descriptor, WorkflowResource):
362
- entity = CedarEntity.from_workflow(resource_descriptor.function_name)
363
- case AgentAction():
364
- if isinstance(resource_descriptor, AgentResource):
365
- entity = CedarEntity.from_agent(resource_descriptor.id)
366
- case RuleAction():
367
- if isinstance(resource_descriptor, RuleResource):
368
- entity = CedarEntity.from_rule(resource_descriptor.rule_name)
360
+ case WorkflowAction() if isinstance(resource_descriptor, WorkflowResource):
361
+ entity = CedarEntity.from_workflow(resource_descriptor.function_name)
362
+ case AgentAction() if isinstance(resource_descriptor, AgentResource):
363
+ entity = CedarEntity.from_agent(resource_descriptor.id)
364
+ case RuleAction() if isinstance(resource_descriptor, RuleResource):
365
+ entity = CedarEntity.from_rule(resource_descriptor.rule_name)
369
366
  case _:
370
- raise ValueError(f"Invalid action type: {action}")
371
-
372
- if not entity:
373
- raise ValueError(
374
- f"Invalid resource descriptor {type(resource_descriptor).__name__} for action {action}"
375
- )
367
+ raise ValueError(
368
+ f"Invalid resource descriptor {type(resource_descriptor).__name__} for action {action}"
369
+ )
376
370
 
377
371
  # Get current principal and check authorization on current resource
378
372
  principal: Principal | None = get_current_principal()
planar/security/tests/test_auth_middleware.py ADDED
@@ -0,0 +1,162 @@
1
+ from unittest.mock import Mock, patch
2
+
3
+ import pytest
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.responses import JSONResponse
6
+
7
+ from planar.security.auth_middleware import AuthMiddleware
8
+
9
+
10
+ @pytest.fixture
11
+ def app():
12
+ return FastAPI()
13
+
14
+
15
+ @pytest.fixture
16
+ def auth_middleware(app):
17
+ return AuthMiddleware(
18
+ app=app,
19
+ client_id="test-client-id",
20
+ org_id="test-org-id",
21
+ additional_exclusion_paths=["/test/exclude"],
22
+ service_token="plt_test-service-token",
23
+ )
24
+
25
+
26
+ class TestAuthMiddleware:
27
+ def test_org_id_validation_none(self, auth_middleware):
28
+ """Test that org_id validation fails when token has None org_id"""
29
+ with (
30
+ patch.object(auth_middleware, "get_signing_key_from_jwt"),
31
+ patch("jwt.decode") as mock_decode,
32
+ ):
33
+ mock_decode.return_value = {"org_id": None}
34
+
35
+ with pytest.raises(HTTPException) as exc_info:
36
+ auth_middleware.validate_jwt_token("fake-token")
37
+
38
+ assert exc_info.value.status_code == 401
39
+ assert exc_info.value.detail == "Invalid organization"
40
+
41
+ def test_org_id_validation_empty_string(self, auth_middleware):
42
+ """Test that org_id validation fails when token has empty string org_id"""
43
+ with (
44
+ patch.object(auth_middleware, "get_signing_key_from_jwt"),
45
+ patch("jwt.decode") as mock_decode,
46
+ ):
47
+ mock_decode.return_value = {"org_id": ""}
48
+
49
+ with pytest.raises(HTTPException) as exc_info:
50
+ auth_middleware.validate_jwt_token("fake-token")
51
+
52
+ assert exc_info.value.status_code == 401
53
+ assert exc_info.value.detail == "Invalid organization"
54
+
55
+ def test_org_id_validation_mismatch(self, auth_middleware):
56
+ """Test that org_id validation fails when token org_id doesn't match"""
57
+ with (
58
+ patch.object(auth_middleware, "get_signing_key_from_jwt"),
59
+ patch("jwt.decode") as mock_decode,
60
+ ):
61
+ mock_decode.return_value = {"org_id": "different-org-id"}
62
+
63
+ with pytest.raises(HTTPException) as exc_info:
64
+ auth_middleware.validate_jwt_token("fake-token")
65
+
66
+ assert exc_info.value.status_code == 401
67
+ assert exc_info.value.detail == "Invalid organization"
68
+
69
+ def test_org_id_validation_success(self, auth_middleware):
70
+ """Test that org_id validation succeeds when token org_id matches"""
71
+ with (
72
+ patch.object(auth_middleware, "get_signing_key_from_jwt"),
73
+ patch("jwt.decode") as mock_decode,
74
+ ):
75
+ expected_payload = {"org_id": "test-org-id", "user_id": "test-user"}
76
+ mock_decode.return_value = expected_payload
77
+
78
+ result = auth_middleware.validate_jwt_token("fake-token")
79
+
80
+ assert result == expected_payload
81
+
82
+ @pytest.mark.asyncio
83
+ async def test_service_token_validation_success(self, auth_middleware):
84
+ """Test that service token validation succeeds when token matches"""
85
+ mock_request = Mock()
86
+ mock_request.url.path = "/planar/v1/something"
87
+ mock_request.headers = {"Authorization": "Bearer plt_test-service-token"}
88
+ mock_call_next = Mock()
89
+ mock_call_next.return_value = JSONResponse(
90
+ status_code=200, content={"message": "success"}
91
+ )
92
+
93
+ async def mock_call_next_func(request):
94
+ return mock_call_next(request)
95
+
96
+ result = await auth_middleware.dispatch(mock_request, mock_call_next_func)
97
+ mock_call_next.assert_called_once_with(mock_request)
98
+ assert result is not None
99
+ assert result.status_code == 200
100
+
101
+ @pytest.mark.asyncio
102
+ async def test_service_token_validation_failure(self, auth_middleware):
103
+ """Test that service token validation succeeds when token matches"""
104
+ mock_request = Mock()
105
+ mock_request.url.path = "/planar/v1/something"
106
+ mock_request.headers = {"Authorization": "Bearer plt_wrong-token"}
107
+ mock_call_next = Mock()
108
+ mock_call_next.return_value = JSONResponse(
109
+ status_code=200, content={"message": "success"}
110
+ )
111
+
112
+ async def mock_call_next_func(request):
113
+ return mock_call_next(request)
114
+
115
+ result = await auth_middleware.dispatch(mock_request, mock_call_next_func)
116
+ mock_call_next.assert_not_called()
117
+ assert result is not None
118
+ assert result.status_code == 401
119
+
120
+ @pytest.mark.asyncio
121
+ async def test_exclusion_paths_includes_health_and_additional(self, app):
122
+ """Test that exclusion paths include health endpoint and additional paths"""
123
+ middleware = AuthMiddleware(
124
+ app=app,
125
+ client_id="test-client-id",
126
+ org_id="test-org-id",
127
+ additional_exclusion_paths=["/custom/path", "/another/path"],
128
+ )
129
+
130
+ mock_request = Mock()
131
+ mock_call_next = Mock()
132
+
133
+ async def mock_call_next_func(request):
134
+ return mock_call_next(request)
135
+
136
+ expected_paths = ["/planar/v1/health", "/custom/path", "/another/path"]
137
+ for path in expected_paths:
138
+ mock_request.url.path = path
139
+ mock_call_next.reset_mock()
140
+
141
+ await middleware.dispatch(mock_request, mock_call_next_func) # type: ignore
142
+
143
+ mock_call_next.assert_called_once_with(mock_request)
144
+
145
+ # Test that non-excluded paths are not excluded
146
+ mock_request.url.path = "/not-excluded"
147
+ mock_call_next.reset_mock()
148
+ result = await middleware.dispatch(mock_request, mock_call_next_func)
149
+ assert result is not None
150
+ assert result.status_code == 401
151
+ mock_call_next.assert_not_called()
152
+
153
+ def test_required_org_id_parameter(self, app):
154
+ """Test that org_id parameter is required (not optional)"""
155
+ # This test ensures org_id cannot be None based on type hints
156
+ # The actual enforcement is at the type level, so we just verify
157
+ # the constructor works with a valid org_id
158
+ middleware = AuthMiddleware(
159
+ app=app, client_id="test-client-id", org_id="required-org-id"
160
+ )
161
+
162
+ assert middleware.org_id == "required-org-id"
planar/sse/proxy.py CHANGED
@@ -65,7 +65,6 @@ class SSEProxy:
65
65
  self.config = config
66
66
  self.enable_builtin_hub = False
67
67
  self.hub_url = ""
68
- self.transport: httpx.AsyncHTTPTransport | None = None
69
68
  self.stream_tasks: WeakSet[Task] = WeakSet()
70
69
 
71
70
  if isinstance(sse_hub, str):
@@ -127,12 +126,11 @@ class SSEProxy:
127
126
  if self.enable_builtin_hub:
128
127
  self.start_builtin_hub()
129
128
 
130
- self.transport, self.hub_url = parse_hub_url(self.hub_url)
131
- forward_url = f"{self.hub_url}/push"
132
-
133
129
  async def forward():
130
+ transport, hub_url = parse_hub_url(self.hub_url)
131
+ forward_url = f"{hub_url}/push"
134
132
  logger.debug("sse event forwarding task started", url=forward_url)
135
- async with httpx.AsyncClient(transport=self.transport) as client:
133
+ async with httpx.AsyncClient(transport=transport) as client:
136
134
  while True:
137
135
  event = await self.queue.get()
138
136
  logger.debug(
@@ -206,11 +204,8 @@ class SSEProxy:
206
204
  @asynccontextmanager
207
205
  async def connect(self, query: str = "", headers: dict[str, str] = {}):
208
206
  logger.debug("sseproxy connect called", query=query, headers=headers)
209
- if not self.hub_url or not self.transport:
210
- raise ValueError("hub_url is not set")
211
207
 
212
- hub_url = self.hub_url
213
- transport = self.transport
208
+ transport, hub_url = parse_hub_url(self.hub_url)
214
209
 
215
210
  client = httpx.AsyncClient(
216
211
  transport=transport, base_url=hub_url, timeout=httpx.Timeout(None)
planar/test_app.py CHANGED
@@ -1,11 +1,16 @@
1
+ from unittest.mock import Mock, patch
2
+
3
+ import pytest
1
4
  from dotenv import load_dotenv
2
5
  from fastapi import APIRouter
3
- from pydantic import BaseModel
6
+ from pydantic import BaseModel, ValidationError
4
7
 
5
8
  from examples.simple_service.models import (
6
9
  Invoice,
7
10
  )
8
11
  from planar import PlanarApp, sqlite_config
12
+ from planar.app import setup_auth_middleware
13
+ from planar.config import Environment, JWTConfig, SecurityConfig
9
14
 
10
15
  load_dotenv()
11
16
 
@@ -49,3 +54,89 @@ def test_register_model_deduplication():
49
54
  # Verify that the model in the registry is the Invoice model
50
55
  registered_models = app._object_registry.get_entities()
51
56
  assert any(model.__name__ == "Invoice" for model in registered_models)
57
+
58
+
59
+ class TestJWTSetup:
60
+ def test_setup_jwt_middleware_production_requires_config(self):
61
+ """Test that JWT setup throws ValueError in production without proper config"""
62
+ mock_app = Mock()
63
+ mock_app.config.environment = Environment.PROD
64
+ mock_app.config.security = SecurityConfig()
65
+
66
+ with pytest.raises(
67
+ ValueError,
68
+ match="Auth middleware is required in production. Please set the JWT config and optionally service token config.",
69
+ ):
70
+ setup_auth_middleware(mock_app)
71
+
72
+ def test_setup_jwt_middleware_production_requires_client_id(self):
73
+ """Test that JWT setup throws ValueError in production without client_id"""
74
+ with pytest.raises(
75
+ ValidationError,
76
+ match="Both client_id and org_id required to enable JWT",
77
+ ):
78
+ JWTConfig(client_id=None, org_id="test-org")
79
+
80
+ def test_setup_jwt_middleware_production_requires_org_id(self):
81
+ """Test that JWT setup throws ValueError in production without org_id"""
82
+ with pytest.raises(
83
+ ValidationError,
84
+ match="Both client_id and org_id required to enable JWT",
85
+ ):
86
+ JWTConfig(client_id="test-client-id", org_id=None)
87
+
88
+ @patch("planar.app.logger")
89
+ def test_setup_jwt_middleware_success_with_all_fields(self, mock_logger):
90
+ """Test that JWT setup succeeds with all required fields"""
91
+ mock_app = Mock()
92
+ mock_app.config.environment = Environment.PROD
93
+ mock_app.config.security = SecurityConfig(
94
+ jwt=JWTConfig(
95
+ client_id="test-client-id",
96
+ org_id="test-org-id",
97
+ additional_exclusion_paths=["/test/path"],
98
+ )
99
+ )
100
+
101
+ setup_auth_middleware(mock_app)
102
+
103
+ # Verify middleware was added to app.fastapi
104
+ mock_app.fastapi.add_middleware.assert_called_once()
105
+
106
+ # Check that info log was called
107
+ mock_logger.info.assert_called_once_with(
108
+ "Auth middleware enabled",
109
+ client_id="test-client-id",
110
+ org_id="test-org-id",
111
+ additional_exclusion_paths=["/test/path"],
112
+ )
113
+
114
+ @patch("planar.app.logger")
115
+ def test_setup_jwt_middleware_dev_environment_allows_missing_config(
116
+ self, mock_logger
117
+ ):
118
+ """Test that JWT setup is skipped in dev environment without config"""
119
+ mock_app = Mock()
120
+ mock_app.config.environment = Environment.DEV
121
+ mock_app.config.security = SecurityConfig()
122
+
123
+ setup_auth_middleware(mock_app)
124
+
125
+ # Verify warning was logged and no middleware added
126
+ mock_logger.warning.assert_called_once_with("Auth middleware disabled")
127
+ mock_app.fastapi.add_middleware.assert_not_called()
128
+
129
+ @patch("planar.app.logger")
130
+ def test_setup_jwt_middleware_dev_environment_allows_disabled_jwt(
131
+ self, mock_logger
132
+ ):
133
+ """Test that JWT setup is skipped in dev environment with disabled JWT"""
134
+ mock_app = Mock()
135
+ mock_app.config.environment = Environment.DEV
136
+ mock_app.config.security = SecurityConfig()
137
+
138
+ setup_auth_middleware(mock_app)
139
+
140
+ # Verify warning was logged and no middleware added
141
+ mock_logger.warning.assert_called_once_with("Auth middleware disabled")
142
+ mock_app.fastapi.add_middleware.assert_not_called()