planar 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +67 -30
  3. planar/ai/pydantic_ai.py +570 -0
  4. planar/ai/pydantic_ai_agent.py +329 -0
  5. planar/ai/test_agent.py +2 -2
  6. planar/app.py +64 -20
  7. planar/cli.py +39 -27
  8. planar/config.py +45 -36
  9. planar/db/db.py +2 -1
  10. planar/files/storage/azure_blob.py +343 -0
  11. planar/files/storage/base.py +7 -0
  12. planar/files/storage/config.py +70 -7
  13. planar/files/storage/s3.py +6 -6
  14. planar/files/storage/test_azure_blob.py +435 -0
  15. planar/logging/formatter.py +17 -4
  16. planar/logging/test_formatter.py +327 -0
  17. planar/registry_items.py +2 -1
  18. planar/routers/agents_router.py +3 -1
  19. planar/routers/files.py +11 -2
  20. planar/routers/models.py +14 -1
  21. planar/routers/test_files_router.py +49 -0
  22. planar/routers/test_routes_security.py +5 -7
  23. planar/routers/test_workflow_router.py +270 -3
  24. planar/routers/workflow.py +95 -36
  25. planar/rules/models.py +36 -39
  26. planar/rules/test_data/account_dormancy_management.json +223 -0
  27. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  28. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  29. planar/rules/test_data/booking_fraud_detection.json +407 -0
  30. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  31. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  32. planar/rules/test_data/customer_lifetime_value.json +143 -0
  33. planar/rules/test_data/import_duties_calculator.json +289 -0
  34. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  35. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  36. planar/rules/test_data/order_consolidation_system.json +375 -0
  37. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  38. planar/rules/test_data/supply_chain_risk.json +253 -0
  39. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  40. planar/rules/test_rules.py +750 -6
  41. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  42. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  43. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  44. planar/security/auth_context.py +21 -0
  45. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  46. planar/security/authorization.py +9 -15
  47. planar/security/tests/test_auth_middleware.py +162 -0
  48. planar/sse/proxy.py +4 -9
  49. planar/test_app.py +92 -1
  50. planar/test_cli.py +81 -59
  51. planar/test_config.py +17 -14
  52. planar/testing/fixtures.py +325 -0
  53. planar/testing/planar_test_client.py +5 -2
  54. planar/utils.py +41 -1
  55. planar/workflows/execution.py +1 -1
  56. planar/workflows/orchestrator.py +5 -0
  57. planar/workflows/serialization.py +12 -6
  58. planar/workflows/step_core.py +3 -1
  59. planar/workflows/test_serialization.py +9 -1
  60. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/METADATA +30 -5
  61. planar-0.7.0.dist-info/RECORD +169 -0
  62. planar/.__init__.py.un~ +0 -0
  63. planar/._version.py.un~ +0 -0
  64. planar/.app.py.un~ +0 -0
  65. planar/.cli.py.un~ +0 -0
  66. planar/.config.py.un~ +0 -0
  67. planar/.context.py.un~ +0 -0
  68. planar/.db.py.un~ +0 -0
  69. planar/.di.py.un~ +0 -0
  70. planar/.engine.py.un~ +0 -0
  71. planar/.files.py.un~ +0 -0
  72. planar/.log_context.py.un~ +0 -0
  73. planar/.log_metadata.py.un~ +0 -0
  74. planar/.logging.py.un~ +0 -0
  75. planar/.object_registry.py.un~ +0 -0
  76. planar/.otel.py.un~ +0 -0
  77. planar/.server.py.un~ +0 -0
  78. planar/.session.py.un~ +0 -0
  79. planar/.sqlalchemy.py.un~ +0 -0
  80. planar/.task_local.py.un~ +0 -0
  81. planar/.test_app.py.un~ +0 -0
  82. planar/.test_config.py.un~ +0 -0
  83. planar/.test_object_config.py.un~ +0 -0
  84. planar/.test_sqlalchemy.py.un~ +0 -0
  85. planar/.test_utils.py.un~ +0 -0
  86. planar/.util.py.un~ +0 -0
  87. planar/.utils.py.un~ +0 -0
  88. planar/ai/.__init__.py.un~ +0 -0
  89. planar/ai/._models.py.un~ +0 -0
  90. planar/ai/.agent.py.un~ +0 -0
  91. planar/ai/.agent_utils.py.un~ +0 -0
  92. planar/ai/.events.py.un~ +0 -0
  93. planar/ai/.files.py.un~ +0 -0
  94. planar/ai/.models.py.un~ +0 -0
  95. planar/ai/.providers.py.un~ +0 -0
  96. planar/ai/.pydantic_ai.py.un~ +0 -0
  97. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  98. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  99. planar/ai/.step.py.un~ +0 -0
  100. planar/ai/.test_agent.py.un~ +0 -0
  101. planar/ai/.test_agent_serialization.py.un~ +0 -0
  102. planar/ai/.test_providers.py.un~ +0 -0
  103. planar/ai/.utils.py.un~ +0 -0
  104. planar/db/.db.py.un~ +0 -0
  105. planar/files/.config.py.un~ +0 -0
  106. planar/files/.local.py.un~ +0 -0
  107. planar/files/.local_filesystem.py.un~ +0 -0
  108. planar/files/.model.py.un~ +0 -0
  109. planar/files/.models.py.un~ +0 -0
  110. planar/files/.s3.py.un~ +0 -0
  111. planar/files/.storage.py.un~ +0 -0
  112. planar/files/.test_files.py.un~ +0 -0
  113. planar/files/storage/.__init__.py.un~ +0 -0
  114. planar/files/storage/.base.py.un~ +0 -0
  115. planar/files/storage/.config.py.un~ +0 -0
  116. planar/files/storage/.context.py.un~ +0 -0
  117. planar/files/storage/.local_directory.py.un~ +0 -0
  118. planar/files/storage/.test_local_directory.py.un~ +0 -0
  119. planar/files/storage/.test_s3.py.un~ +0 -0
  120. planar/human/.human.py.un~ +0 -0
  121. planar/human/.test_human.py.un~ +0 -0
  122. planar/logging/.__init__.py.un~ +0 -0
  123. planar/logging/.attributes.py.un~ +0 -0
  124. planar/logging/.formatter.py.un~ +0 -0
  125. planar/logging/.logger.py.un~ +0 -0
  126. planar/logging/.otel.py.un~ +0 -0
  127. planar/logging/.tracer.py.un~ +0 -0
  128. planar/modeling/.mixin.py.un~ +0 -0
  129. planar/modeling/.storage.py.un~ +0 -0
  130. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  131. planar/object_config/.object_config.py.un~ +0 -0
  132. planar/routers/.__init__.py.un~ +0 -0
  133. planar/routers/.agents_router.py.un~ +0 -0
  134. planar/routers/.crud.py.un~ +0 -0
  135. planar/routers/.decision.py.un~ +0 -0
  136. planar/routers/.event.py.un~ +0 -0
  137. planar/routers/.file_attachment.py.un~ +0 -0
  138. planar/routers/.files.py.un~ +0 -0
  139. planar/routers/.files_router.py.un~ +0 -0
  140. planar/routers/.human.py.un~ +0 -0
  141. planar/routers/.info.py.un~ +0 -0
  142. planar/routers/.models.py.un~ +0 -0
  143. planar/routers/.object_config_router.py.un~ +0 -0
  144. planar/routers/.rule.py.un~ +0 -0
  145. planar/routers/.test_object_config_router.py.un~ +0 -0
  146. planar/routers/.test_workflow_router.py.un~ +0 -0
  147. planar/routers/.workflow.py.un~ +0 -0
  148. planar/rules/.decorator.py.un~ +0 -0
  149. planar/rules/.runner.py.un~ +0 -0
  150. planar/rules/.test_rules.py.un~ +0 -0
  151. planar/security/.jwt_middleware.py.un~ +0 -0
  152. planar/sse/.constants.py.un~ +0 -0
  153. planar/sse/.example.html.un~ +0 -0
  154. planar/sse/.hub.py.un~ +0 -0
  155. planar/sse/.model.py.un~ +0 -0
  156. planar/sse/.proxy.py.un~ +0 -0
  157. planar/testing/.client.py.un~ +0 -0
  158. planar/testing/.memory_storage.py.un~ +0 -0
  159. planar/testing/.planar_test_client.py.un~ +0 -0
  160. planar/testing/.predictable_tracer.py.un~ +0 -0
  161. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  162. planar/testing/.test_memory_storage.py.un~ +0 -0
  163. planar/testing/.workflow_observer.py.un~ +0 -0
  164. planar/workflows/.__init__.py.un~ +0 -0
  165. planar/workflows/.builtin_steps.py.un~ +0 -0
  166. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  167. planar/workflows/.context.py.un~ +0 -0
  168. planar/workflows/.contrib.py.un~ +0 -0
  169. planar/workflows/.decorators.py.un~ +0 -0
  170. planar/workflows/.durable_test.py.un~ +0 -0
  171. planar/workflows/.errors.py.un~ +0 -0
  172. planar/workflows/.events.py.un~ +0 -0
  173. planar/workflows/.exceptions.py.un~ +0 -0
  174. planar/workflows/.execution.py.un~ +0 -0
  175. planar/workflows/.human.py.un~ +0 -0
  176. planar/workflows/.lock.py.un~ +0 -0
  177. planar/workflows/.misc.py.un~ +0 -0
  178. planar/workflows/.model.py.un~ +0 -0
  179. planar/workflows/.models.py.un~ +0 -0
  180. planar/workflows/.notifications.py.un~ +0 -0
  181. planar/workflows/.orchestrator.py.un~ +0 -0
  182. planar/workflows/.runtime.py.un~ +0 -0
  183. planar/workflows/.serialization.py.un~ +0 -0
  184. planar/workflows/.step.py.un~ +0 -0
  185. planar/workflows/.step_core.py.un~ +0 -0
  186. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  187. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  188. planar/workflows/.test_concurrency.py.un~ +0 -0
  189. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  190. planar/workflows/.test_human.py.un~ +0 -0
  191. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  192. planar/workflows/.test_orchestrator.py.un~ +0 -0
  193. planar/workflows/.test_race_conditions.py.un~ +0 -0
  194. planar/workflows/.test_serialization.py.un~ +0 -0
  195. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  196. planar/workflows/.test_workflow.py.un~ +0 -0
  197. planar/workflows/.tracing.py.un~ +0 -0
  198. planar/workflows/.types.py.un~ +0 -0
  199. planar/workflows/.util.py.un~ +0 -0
  200. planar/workflows/.utils.py.un~ +0 -0
  201. planar/workflows/.workflow.py.un~ +0 -0
  202. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  203. planar/workflows/.wrappers.py.un~ +0 -0
  204. planar-0.5.0.dist-info/RECORD +0 -289
  205. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/WHEEL +0 -0
  206. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,329 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Any, Type, cast
4
+
5
+ from pydantic import BaseModel
6
+ from pydantic_ai import models
7
+
8
+ from planar.ai.agent import AgentBase
9
+ from planar.ai.agent_utils import (
10
+ AgentEventType,
11
+ ToolCallResult,
12
+ create_tool_definition,
13
+ extract_files_from_model,
14
+ get_agent_config,
15
+ render_template,
16
+ )
17
+ from planar.ai.models import (
18
+ AgentRunResult,
19
+ AssistantMessage,
20
+ ModelMessage,
21
+ SystemMessage,
22
+ ToolDefinition,
23
+ ToolMessage,
24
+ ToolResponse,
25
+ UserMessage,
26
+ )
27
+ from planar.ai.providers import ModelSpec
28
+ from planar.ai.pydantic_ai import ModelRunResponse, model_run
29
+ from planar.logging import get_logger
30
+ from planar.utils import utc_now
31
+ from planar.workflows.models import StepType
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class Agent[
38
+ TInput: BaseModel | str,
39
+ TOutput: BaseModel | str,
40
+ ](AgentBase[TInput, TOutput]):
41
+ model: models.KnownModelName | models.Model = "openai:gpt-4o"
42
+
43
+ async def run_step(
44
+ self,
45
+ input_value: TInput,
46
+ ) -> AgentRunResult[TOutput]:
47
+ """Execute the agent with the provided inputs.
48
+
49
+ Args:
50
+ input_value: The primary input value to the agent, can be a string or Pydantic model
51
+ **kwargs: Alternative way to pass inputs as keyword arguments
52
+
53
+ Returns:
54
+ AgentRunResult containing the agent's response
55
+ """
56
+ event_emitter = self.event_emitter
57
+ logger.debug(
58
+ "agent run_step called", agent_name=self.name, input_type=type(input_value)
59
+ )
60
+ result = None
61
+
62
+ config = await get_agent_config(self.name, self.to_config())
63
+ logger.debug("agent using config", agent_name=self.name, config=config)
64
+
65
+ input_map: dict[str, str | dict[str, Any]] = {}
66
+
67
+ files = extract_files_from_model(input_value)
68
+ logger.debug(
69
+ "extracted files from input for agent",
70
+ num_files=len(files),
71
+ agent_name=self.name,
72
+ )
73
+ match input_value:
74
+ case BaseModel():
75
+ if self.input_type and not isinstance(input_value, self.input_type):
76
+ logger.warning(
77
+ "input value type mismatch for agent",
78
+ agent_name=self.name,
79
+ expected_type=self.input_type,
80
+ got_type=type(input_value),
81
+ )
82
+ raise ValueError(
83
+ f"Input value must be of type {self.input_type}, but got {type(input_value)}"
84
+ )
85
+ input_map["input"] = cast(BaseModel, input_value).model_dump()
86
+ case str():
87
+ input_map["input"] = input_value
88
+ case _:
89
+ logger.warning(
90
+ "unexpected input value type for agent",
91
+ agent_name=self.name,
92
+ type=type(input_value),
93
+ )
94
+ raise ValueError(f"Unexpected input value type: {type(input_value)}")
95
+
96
+ # Add built-in variables
97
+ # TODO: Make deterministic or step
98
+ built_in_vars = {
99
+ "datetime_now": utc_now().isoformat(),
100
+ "date_today": utc_now().date().isoformat(),
101
+ }
102
+ input_map.update(built_in_vars)
103
+
104
+ # Format the prompts with the provided arguments using Jinja templates
105
+ try:
106
+ formatted_system_prompt = (
107
+ render_template(config.system_prompt, input_map)
108
+ if config.system_prompt
109
+ else ""
110
+ )
111
+ formatted_user_prompt = (
112
+ render_template(config.user_prompt, input_map)
113
+ if config.user_prompt
114
+ else ""
115
+ )
116
+ except ValueError as e:
117
+ logger.exception("error formatting prompts for agent", agent_name=self.name)
118
+ raise ValueError(f"Missing required parameter for prompt formatting: {e}")
119
+
120
+ # Get the LLM provider and model
121
+ if isinstance(self.model, str):
122
+ model = models.infer_model(self.model)
123
+ else:
124
+ model = self.model
125
+
126
+ # Apply model parameters if specified
127
+ model_settings = None
128
+ if config.model_parameters:
129
+ model_settings = config.model_parameters
130
+
131
+ # Prepare structured messages
132
+ messages: list[ModelMessage] = []
133
+ if formatted_system_prompt:
134
+ messages.append(SystemMessage(content=formatted_system_prompt))
135
+
136
+ if formatted_user_prompt:
137
+ messages.append(UserMessage(content=formatted_user_prompt, files=files))
138
+
139
+ # Prepare tools if provided
140
+ tool_definitions = None
141
+ if self.tools:
142
+ tool_definitions = [create_tool_definition(tool) for tool in self.tools]
143
+
144
+ # Determine output type for the agent call
145
+ # Pass the Pydantic model type if output_type is a subclass of BaseModel,
146
+ # otherwise pass None (indicating string output is expected).
147
+ output_type: Type[BaseModel] | None = None
148
+ # Use issubclass safely by checking if output_type is a type first
149
+ if inspect.isclass(self.output_type) and issubclass(
150
+ self.output_type, BaseModel
151
+ ):
152
+ output_type = cast(Type[BaseModel], self.output_type)
153
+
154
+ # Execute the LLM call
155
+ max_turns = config.max_turns
156
+
157
+ # We use this inner function to pass "model" and "event_emitter",
158
+ # which are not serializable as step parameters.
159
+ async def agent_run_step(
160
+ model_spec: ModelSpec,
161
+ messages: list[ModelMessage],
162
+ turns_left: int,
163
+ tools: list[ToolDefinition] | None = None,
164
+ output_type: Type[BaseModel] | None = None,
165
+ ):
166
+ logger.debug(
167
+ "agent running",
168
+ agent_name=self.name,
169
+ model=model_spec,
170
+ model_settings=model_settings,
171
+ output_type=output_type,
172
+ )
173
+ if output_type is None:
174
+ return await model_run(
175
+ model=model,
176
+ max_extra_turns=turns_left,
177
+ model_settings=model_settings,
178
+ messages=messages,
179
+ tools=tools or [],
180
+ event_handler=cast(Any, event_emitter),
181
+ )
182
+ else:
183
+ return await model_run(
184
+ model=model,
185
+ max_extra_turns=turns_left,
186
+ model_settings=model_settings,
187
+ messages=messages,
188
+ output_type=output_type,
189
+ tools=tools or [],
190
+ event_handler=cast(Any, event_emitter),
191
+ )
192
+
193
+ model_spec = ModelSpec(
194
+ model_id=str(model),
195
+ parameters=config.model_parameters,
196
+ )
197
+ result = None
198
+ logger.debug(
199
+ "agent performing multi-turn completion with tools",
200
+ agent_name=self.name,
201
+ max_turns=max_turns,
202
+ )
203
+ turns_left = max_turns
204
+ while turns_left > 0:
205
+ turns_left -= 1
206
+ logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
207
+
208
+ # Get model response
209
+ run_response = await self.as_step_if_durable(
210
+ agent_run_step,
211
+ step_type=StepType.AGENT,
212
+ return_type=ModelRunResponse[output_type or str],
213
+ )(
214
+ model_spec=model_spec,
215
+ messages=messages,
216
+ turns_left=turns_left,
217
+ output_type=output_type,
218
+ tools=tool_definitions or [],
219
+ )
220
+ response = run_response.response
221
+ turns_left -= run_response.extra_turns_used
222
+
223
+ # Emit response event if event_emitter is provided
224
+ if event_emitter:
225
+ event_emitter.emit(AgentEventType.RESPONSE, response.content)
226
+
227
+ # If no tool calls or last turn, return content
228
+ if not response.tool_calls or turns_left == 0:
229
+ logger.debug(
230
+ "agent completion: no tool calls or last turn",
231
+ agent_name=self.name,
232
+ has_content=response.content is not None,
233
+ )
234
+ result = response.content
235
+ break
236
+
237
+ # Process tool calls
238
+ logger.debug(
239
+ "agent received tool calls",
240
+ agent_name=self.name,
241
+ num_tool_calls=len(response.tool_calls),
242
+ )
243
+ assistant_message = AssistantMessage(
244
+ content=None,
245
+ tool_calls=response.tool_calls,
246
+ )
247
+ messages.append(assistant_message)
248
+
249
+ # Execute each tool and add tool responses to messages
250
+ for tool_call_idx, tool_call in enumerate(response.tool_calls):
251
+ logger.debug(
252
+ "agent processing tool call",
253
+ agent_name=self.name,
254
+ tool_call_index=tool_call_idx + 1,
255
+ tool_call_id=tool_call.id,
256
+ tool_call_name=tool_call.name,
257
+ )
258
+ # Find the matching tool function
259
+ tool_fn = next(
260
+ (t for t in self.tools if t.__name__ == tool_call.name),
261
+ None,
262
+ )
263
+
264
+ if not tool_fn:
265
+ tool_result = f"Error: Tool '{tool_call.name}' not found."
266
+ logger.warning(
267
+ "tool not found for agent",
268
+ tool_name=tool_call.name,
269
+ agent_name=self.name,
270
+ )
271
+ else:
272
+ # Execute the tool with the provided arguments
273
+ tool_result = await self.as_step_if_durable(
274
+ tool_fn,
275
+ step_type=StepType.TOOL_CALL,
276
+ )(**tool_call.arguments)
277
+ logger.info(
278
+ "tool executed by agent",
279
+ tool_name=tool_call.name,
280
+ agent_name=self.name,
281
+ result_type=type(tool_result),
282
+ )
283
+
284
+ # Create a tool response
285
+ tool_response = ToolResponse(
286
+ tool_call_id=tool_call.id or "call_1", content=str(tool_result)
287
+ )
288
+
289
+ # Emit tool response event if event_emitter is provided
290
+ if event_emitter:
291
+ event_emitter.emit(
292
+ AgentEventType.TOOL_RESPONSE,
293
+ ToolCallResult(
294
+ tool_call_id=tool_call.id or "call_1",
295
+ tool_call_name=tool_call.name,
296
+ content=tool_result,
297
+ ),
298
+ )
299
+
300
+ tool_message = ToolMessage(
301
+ content=tool_response.content,
302
+ tool_call_id=tool_response.tool_call_id or "call_1",
303
+ )
304
+ messages.append(tool_message)
305
+
306
+ # Continue to next turn
307
+
308
+ if result is None:
309
+ logger.warning(
310
+ "agent completed tool interactions but result is none",
311
+ agent_name=self.name,
312
+ expected_type=self.output_type,
313
+ )
314
+ raise ValueError(
315
+ f"Expected result of type {self.output_type} but got none after tool interactions."
316
+ )
317
+
318
+ if event_emitter:
319
+ event_emitter.emit(AgentEventType.COMPLETED, result)
320
+
321
+ logger.info(
322
+ "agent completed",
323
+ agent_name=self.name,
324
+ final_result_type=type(result),
325
+ )
326
+ return AgentRunResult[TOutput](output=cast(TOutput, result))
327
+
328
+ def get_model_str(self) -> str:
329
+ return str(self.model)
planar/ai/test_agent.py CHANGED
@@ -435,7 +435,7 @@ async def test_agent_with_input_validation(
435
435
  async def test_agent_with_tools(
436
436
  mock_providers,
437
437
  client: PlanarTestClient,
438
- app,
438
+ app: PlanarApp,
439
439
  ):
440
440
  """Test agent with tools for multi-turn conversations."""
441
441
  openai_mock, anthropic_mock = mock_providers
@@ -475,7 +475,7 @@ async def test_agent_with_tools(
475
475
  ):
476
476
  # Start and execute the workflow
477
477
  wf = await tools_workflow.start("complex problem")
478
- result = await execute(wf)
478
+ result = await app.orchestrator.wait_for_completion(wf.id)
479
479
 
480
480
  # Verify the result
481
481
  assert isinstance(result, dict)
planar/app.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import asyncio
2
+ import signal
2
3
  from asyncio import CancelledError
3
4
  from contextlib import asynccontextmanager
5
+ from types import FrameType
4
6
  from typing import Any, Callable, Coroutine, Type
5
7
 
6
8
  from fastapi import APIRouter, FastAPI, HTTPException, Request
@@ -11,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine
11
13
  from typing_extensions import TypeVar
12
14
 
13
15
  from planar.ai import Agent
14
- from planar.config import PlanarConfig, load_environment_aware_config
16
+ from planar.config import Environment, PlanarConfig, load_environment_aware_config
15
17
  from planar.db import DatabaseManager
16
18
  from planar.files.storage.base import Storage
17
19
  from planar.files.storage.config import create_from_config
@@ -30,8 +32,8 @@ from planar.routers.entity_router import create_entities_router
30
32
  from planar.routers.object_config_router import create_object_config_router
31
33
  from planar.routers.rule import create_rule_router
32
34
  from planar.rules.decorator import RULE_REGISTRY
35
+ from planar.security.auth_middleware import AuthMiddleware
33
36
  from planar.security.authorization import PolicyService, policy_service_context
34
- from planar.security.jwt_middleware import JWTMiddleware
35
37
  from planar.session import config_var, session_context
36
38
  from planar.sse.proxy import SSEProxy
37
39
  from planar.workflows import (
@@ -92,7 +94,7 @@ class PlanarApp:
92
94
  setup_orchestrator_middleware(self)
93
95
  setup_workflow_notification_middleware(self)
94
96
  setup_tracer_middleware(self)
95
- setup_jwt_middleware(self)
97
+ setup_auth_middleware(self)
96
98
  setup_http_exception_handler(self)
97
99
  setup_authorization_policy_service(self)
98
100
 
@@ -202,6 +204,27 @@ class PlanarApp:
202
204
 
203
205
  @asynccontextmanager
204
206
  async def _lifespan(self, app: FastAPI):
207
+ # We manually capture SIGINT/SIGTERM to trigger our own graceful shutdown.
208
+ # This is necessary because long-lived connections, such as from the SSE
209
+ # proxy, can cause uvicorn's default graceful shutdown to hang, preventing
210
+ # the lifespan shutdown logic (after the yield) from ever being reached.
211
+ # Our handler starts the shutdown of these components and then chains to the
212
+ # original uvicorn handler to allow it to proceed with its own shutdown.
213
+ original_handlers = {
214
+ signal.SIGINT: signal.getsignal(signal.SIGINT),
215
+ signal.SIGTERM: signal.getsignal(signal.SIGTERM),
216
+ }
217
+
218
+ def terminate_now(signum: int, frame: FrameType | None = None):
219
+ asyncio.create_task(self.graceful_shutdown())
220
+ handler = original_handlers.get(signal.Signals(signum))
221
+ if callable(handler):
222
+ handler(signum, frame)
223
+
224
+ signal.signal(signal.SIGINT, terminate_now)
225
+ signal.signal(signal.SIGTERM, terminate_now)
226
+
227
+ # Begin the normal lifespan logic
205
228
  self.db_manager.connect()
206
229
  await self.db_manager.migrate(
207
230
  self.config.use_alembic if self.config.use_alembic is not None else True
@@ -240,6 +263,10 @@ class PlanarApp:
240
263
  config_var.reset(config_tok)
241
264
 
242
265
  await self.db_manager.disconnect()
266
+
267
+ if self.storage:
268
+ await self.storage.close()
269
+
243
270
  logger.info("stopping sse")
244
271
  await self.stop_sse()
245
272
  logger.info("lifespan completed")
@@ -435,15 +462,15 @@ def setup_http_exception_handler(app: PlanarApp):
435
462
 
436
463
  def setup_cors_middleware(app: PlanarApp):
437
464
  opts = {
438
- "allow_headers": app.config.cors.allow_headers,
439
- "allow_methods": app.config.cors.allow_methods,
440
- "allow_credentials": app.config.cors.allow_credentials,
465
+ "allow_headers": app.config.security.cors.allow_headers,
466
+ "allow_methods": app.config.security.cors.allow_methods,
467
+ "allow_credentials": app.config.security.cors.allow_credentials,
441
468
  }
442
469
 
443
- if isinstance(app.config.cors.allow_origins, str):
444
- opts["allow_origin_regex"] = app.config.cors.allow_origins
470
+ if isinstance(app.config.security.cors.allow_origins, str):
471
+ opts["allow_origin_regex"] = app.config.security.cors.allow_origins
445
472
  else:
446
- opts["allow_origins"] = app.config.cors.allow_origins
473
+ opts["allow_origins"] = app.config.security.cors.allow_origins
447
474
 
448
475
  app.fastapi.add_middleware(
449
476
  CORSMiddleware,
@@ -451,32 +478,49 @@ def setup_cors_middleware(app: PlanarApp):
451
478
  )
452
479
 
453
480
 
454
- def setup_jwt_middleware(app: PlanarApp):
455
- if app.config.jwt and app.config.jwt.enabled and app.config.jwt.client_id:
456
- client_id = app.config.jwt.client_id
457
- org_id = app.config.jwt.org_id
458
- additional_exclusion_paths = app.config.jwt.additional_exclusion_paths
481
+ def setup_auth_middleware(app: PlanarApp):
482
+ if (
483
+ app.config.security
484
+ and app.config.security.jwt
485
+ and app.config.security.jwt.client_id
486
+ and app.config.security.jwt.org_id
487
+ ):
488
+ client_id = app.config.security.jwt.client_id
489
+ org_id = app.config.security.jwt.org_id
490
+ additional_exclusion_paths = app.config.security.jwt.additional_exclusion_paths
459
491
  app.fastapi.add_middleware(
460
- JWTMiddleware, # type: ignore
492
+ AuthMiddleware, # type: ignore
461
493
  client_id,
462
494
  org_id,
463
495
  additional_exclusion_paths,
496
+ service_token=app.config.security.service_token.token
497
+ if app.config.security.service_token
498
+ and app.config.security.service_token.token
499
+ else None,
464
500
  )
465
501
  logger.info(
466
- "jwt middleware enabled",
502
+ "Auth middleware enabled",
467
503
  client_id=client_id,
468
504
  org_id=org_id,
469
505
  additional_exclusion_paths=additional_exclusion_paths,
470
506
  )
507
+ elif app.config.environment == Environment.PROD:
508
+ raise ValueError(
509
+ "Auth middleware is required in production. Please set the JWT config and optionally service token config."
510
+ )
471
511
  else:
472
- logger.warning("JWT middleware disabled")
512
+ logger.warning("Auth middleware disabled")
473
513
 
474
514
 
475
515
  def setup_authorization_policy_service(app: PlanarApp):
476
- if app.config.authz and app.config.authz.enabled:
516
+ if (
517
+ app.config.security
518
+ and app.config.security.authz
519
+ and app.config.security.authz.enabled
520
+ ):
477
521
  app.policy_service = PolicyService(
478
- policy_file_path=app.config.authz.policy_file
479
- if app.config.authz.policy_file
522
+ policy_file_path=app.config.security.authz.policy_file
523
+ if app.config.security.authz.policy_file
480
524
  else None
481
525
  )
482
526
  logger.info(
planar/cli.py CHANGED
@@ -13,28 +13,6 @@ from planar.config import Environment
13
13
  app = typer.Typer(help="Planar CLI tool")
14
14
 
15
15
 
16
- class PlanarServer(uvicorn.Server):
17
- """Intercept SIGINT/SIGTERM to trigger early shutdown on the app."""
18
-
19
- def __init__(self, config: uvicorn.Config, app_import_string: str):
20
- super().__init__(config)
21
- self.app_import_string = app_import_string
22
-
23
- def handle_exit(self, sig, frame):
24
- # Import the PlanarApp instance and fire its early-shutdown hook
25
- import asyncio
26
- import importlib
27
-
28
- module_name, var_name = self.app_import_string.split(":")
29
- app_module = importlib.import_module(module_name)
30
- planar_app = getattr(app_module, var_name, None)
31
- if planar_app and hasattr(planar_app, "graceful_shutdown"):
32
- asyncio.create_task(planar_app.graceful_shutdown())
33
-
34
- # Continue with Uvicorn's normal shutdown procedure
35
- super().handle_exit(sig, frame)
36
-
37
-
38
16
  def find_default_app_path() -> Path:
39
17
  """Checks for default app file paths (app.py, then main.py)."""
40
18
  for filename in ["app.py", "main.py"]:
@@ -94,9 +72,25 @@ def dev_command(
94
72
  "--script",
95
73
  help="Run as a script with 'uv run' instead of starting a server",
96
74
  ),
75
+ ssl_keyfile: str | None = typer.Option(
76
+ None, "--ssl-keyfile", help="Path to SSL key file"
77
+ ),
78
+ ssl_certfile: str | None = typer.Option(
79
+ None, "--ssl-certfile", help="Path to SSL cert file"
80
+ ),
97
81
  ):
98
82
  """Run Planar in development mode"""
99
- run_command(Environment.DEV, port, host, config, path, app_name, script)
83
+ run_command(
84
+ Environment.DEV,
85
+ port,
86
+ host,
87
+ config,
88
+ path,
89
+ app_name,
90
+ script,
91
+ ssl_keyfile,
92
+ ssl_certfile,
93
+ )
100
94
 
101
95
 
102
96
  @app.command("prod")
@@ -119,9 +113,25 @@ def prod_command(
119
113
  "--script",
120
114
  help="Run as a script with 'uv run' instead of starting a server",
121
115
  ),
116
+ ssl_keyfile: str | None = typer.Option(
117
+ None, "--ssl-keyfile", help="Path to SSL key file"
118
+ ),
119
+ ssl_certfile: str | None = typer.Option(
120
+ None, "--ssl-certfile", help="Path to SSL cert file"
121
+ ),
122
122
  ):
123
123
  """Run Planar in production mode"""
124
- run_command(Environment.PROD, port, host, config, path, app_name, script)
124
+ run_command(
125
+ Environment.PROD,
126
+ port,
127
+ host,
128
+ config,
129
+ path,
130
+ app_name,
131
+ script,
132
+ ssl_keyfile,
133
+ ssl_certfile,
134
+ )
125
135
 
126
136
 
127
137
  def run_command(
@@ -132,6 +142,8 @@ def run_command(
132
142
  path: Path | None,
133
143
  app_name: str,
134
144
  script: bool = False,
145
+ ssl_keyfile: str | None = None,
146
+ ssl_certfile: str | None = None,
135
147
  ):
136
148
  """Common logic for both dev and prod commands"""
137
149
  os.environ["PLANAR_ENV"] = env.value
@@ -188,15 +200,15 @@ def run_command(
188
200
  typer.echo(f"Starting Planar in {env.value} mode")
189
201
 
190
202
  try:
191
- config = uvicorn.Config(
203
+ uvicorn.run(
192
204
  app_import_string,
193
205
  host=host or ("127.0.0.1" if env == Environment.DEV else "0.0.0.0"),
194
206
  port=port or 8000,
195
207
  reload=True if env == Environment.DEV else False,
196
208
  timeout_graceful_shutdown=4,
209
+ ssl_keyfile=ssl_keyfile,
210
+ ssl_certfile=ssl_certfile,
197
211
  )
198
-
199
- PlanarServer(config, app_import_string).run()
200
212
  except Exception as e:
201
213
  # Provide more context on import errors
202
214
  if isinstance(e, (ImportError, AttributeError)):