nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480) hide show
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1257 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ import os
20
+ import typing
21
+ from abc import ABC
22
+ from abc import abstractmethod
23
+ from collections.abc import Awaitable
24
+ from collections.abc import Callable
25
+ from contextlib import asynccontextmanager
26
+ from pathlib import Path
27
+
28
+ import httpx
29
+ from authlib.common.errors import AuthlibBaseError as OAuthError
30
+ from fastapi import Body
31
+ from fastapi import FastAPI
32
+ from fastapi import HTTPException
33
+ from fastapi import Request
34
+ from fastapi import Response
35
+ from fastapi import UploadFile
36
+ from fastapi.middleware.cors import CORSMiddleware
37
+ from fastapi.responses import StreamingResponse
38
+ from pydantic import BaseModel
39
+ from pydantic import Field
40
+ from starlette.websockets import WebSocket
41
+
42
+ from nat.builder.function import Function
43
+ from nat.builder.workflow_builder import WorkflowBuilder
44
+ from nat.data_models.api_server import ChatRequest
45
+ from nat.data_models.api_server import ChatResponse
46
+ from nat.data_models.api_server import ChatResponseChunk
47
+ from nat.data_models.api_server import ResponseIntermediateStep
48
+ from nat.data_models.config import Config
49
+ from nat.data_models.object_store import KeyAlreadyExistsError
50
+ from nat.data_models.object_store import NoSuchKeyError
51
+ from nat.eval.config import EvaluationRunOutput
52
+ from nat.eval.evaluate import EvaluationRun
53
+ from nat.eval.evaluate import EvaluationRunConfig
54
+ from nat.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
55
+ from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
56
+ from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
57
+ from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerateResponse
58
+ from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerationStatusResponse
59
+ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
60
+ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
61
+ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
62
+ from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
63
+ from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler
64
+ from nat.front_ends.fastapi.response_helpers import generate_single_response
65
+ from nat.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
66
+ from nat.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
67
+ from nat.front_ends.fastapi.step_adaptor import StepAdaptor
68
+ from nat.front_ends.fastapi.utils import get_config_file_path
69
+ from nat.object_store.models import ObjectStoreItem
70
+ from nat.runtime.loader import load_workflow
71
+ from nat.runtime.session import SessionManager
72
+
73
+ logger = logging.getLogger(__name__)
74
+
75
+ _DASK_AVAILABLE = False
76
+
77
+ try:
78
+ from nat.front_ends.fastapi.job_store import JobInfo
79
+ from nat.front_ends.fastapi.job_store import JobStatus
80
+ from nat.front_ends.fastapi.job_store import JobStore
81
+ _DASK_AVAILABLE = True
82
+ except ImportError:
83
+ JobInfo = None
84
+ JobStatus = None
85
+ JobStore = None
86
+
87
+
88
+ class FastApiFrontEndPluginWorkerBase(ABC):
89
+
90
+ def __init__(self, config: Config):
91
+ self._config = config
92
+
93
+ assert isinstance(config.general.front_end,
94
+ FastApiFrontEndConfig), ("Front end config is not FastApiFrontEndConfig")
95
+
96
+ self._front_end_config = config.general.front_end
97
+ self._dask_available = False
98
+ self._job_store = None
99
+ self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
100
+ self._scheduler_address = os.environ.get("NAT_DASK_SCHEDULER_ADDRESS")
101
+ self._db_url = os.environ.get("NAT_JOB_STORE_DB_URL")
102
+ self._config_file_path = get_config_file_path()
103
+
104
+ if self._scheduler_address is not None:
105
+ if not _DASK_AVAILABLE:
106
+ raise RuntimeError("Dask is not available, please install it to use the FastAPI front end with Dask.")
107
+
108
+ if self._db_url is None:
109
+ raise RuntimeError(
110
+ "NAT_JOB_STORE_DB_URL must be set when using Dask (configure a persistent JobStore database).")
111
+
112
+ try:
113
+ self._job_store = JobStore(scheduler_address=self._scheduler_address, db_url=self._db_url)
114
+ self._dask_available = True
115
+ logger.debug("Connected to Dask scheduler at %s", self._scheduler_address)
116
+ except Exception as e:
117
+ raise RuntimeError(f"Failed to connect to Dask scheduler at {self._scheduler_address}: {e}") from e
118
+ else:
119
+ logger.debug("No Dask scheduler address provided, running without Dask support.")
120
+
121
+ @property
122
+ def config(self) -> Config:
123
+ return self._config
124
+
125
+ @property
126
+ def front_end_config(self) -> FastApiFrontEndConfig:
127
+ return self._front_end_config
128
+
129
+ def build_app(self) -> FastAPI:
130
+
131
+ # Create the FastAPI app and configure it
132
+ @asynccontextmanager
133
+ async def lifespan(starting_app: FastAPI):
134
+
135
+ logger.debug("Starting NAT server from process %s", os.getpid())
136
+
137
+ async with WorkflowBuilder.from_config(self.config) as builder:
138
+
139
+ await self.configure(starting_app, builder)
140
+
141
+ yield
142
+
143
+ logger.debug("Closing NAT server from process %s", os.getpid())
144
+
145
+ nat_app = FastAPI(lifespan=lifespan)
146
+
147
+ # Configure app CORS.
148
+ self.set_cors_config(nat_app)
149
+
150
+ @nat_app.middleware("http")
151
+ async def authentication_log_filter(request: Request, call_next: Callable[[Request], Awaitable[Response]]):
152
+ return await self._suppress_authentication_logs(request, call_next)
153
+
154
+ return nat_app
155
+
156
+ def set_cors_config(self, nat_app: FastAPI) -> None:
157
+ """
158
+ Set the cross origin resource sharing configuration.
159
+ """
160
+ cors_kwargs = {}
161
+
162
+ if self.front_end_config.cors.allow_origins is not None:
163
+ cors_kwargs["allow_origins"] = self.front_end_config.cors.allow_origins
164
+
165
+ if self.front_end_config.cors.allow_origin_regex is not None:
166
+ cors_kwargs["allow_origin_regex"] = self.front_end_config.cors.allow_origin_regex
167
+
168
+ if self.front_end_config.cors.allow_methods is not None:
169
+ cors_kwargs["allow_methods"] = self.front_end_config.cors.allow_methods
170
+
171
+ if self.front_end_config.cors.allow_headers is not None:
172
+ cors_kwargs["allow_headers"] = self.front_end_config.cors.allow_headers
173
+
174
+ if self.front_end_config.cors.allow_credentials is not None:
175
+ cors_kwargs["allow_credentials"] = self.front_end_config.cors.allow_credentials
176
+
177
+ if self.front_end_config.cors.expose_headers is not None:
178
+ cors_kwargs["expose_headers"] = self.front_end_config.cors.expose_headers
179
+
180
+ if self.front_end_config.cors.max_age is not None:
181
+ cors_kwargs["max_age"] = self.front_end_config.cors.max_age
182
+
183
+ nat_app.add_middleware(
184
+ CORSMiddleware,
185
+ **cors_kwargs,
186
+ )
187
+
188
+ async def _suppress_authentication_logs(self, request: Request,
189
+ call_next: Callable[[Request], Awaitable[Response]]) -> Response:
190
+ """
191
+ Intercepts authentication request and supreses logs that contain sensitive data.
192
+ """
193
+ from nat.utils.log_utils import LogFilter
194
+
195
+ logs_to_suppress: list[str] = []
196
+
197
+ if (self.front_end_config.oauth2_callback_path):
198
+ logs_to_suppress.append(self.front_end_config.oauth2_callback_path)
199
+
200
+ logging.getLogger("uvicorn.access").addFilter(LogFilter(logs_to_suppress))
201
+ try:
202
+ response = await call_next(request)
203
+ finally:
204
+ logging.getLogger("uvicorn.access").removeFilter(LogFilter(logs_to_suppress))
205
+
206
+ return response
207
+
208
+ @abstractmethod
209
+ async def configure(self, app: FastAPI, builder: WorkflowBuilder):
210
+ pass
211
+
212
+ @abstractmethod
213
+ def get_step_adaptor(self) -> StepAdaptor:
214
+ pass
215
+
216
+
217
+ class RouteInfo(BaseModel):
218
+
219
+ function_name: str | None
220
+
221
+
222
+ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
223
+
224
+ def __init__(self, config: Config):
225
+ super().__init__(config)
226
+
227
+ self._outstanding_flows: dict[str, FlowState] = {}
228
+ self._outstanding_flows_lock = asyncio.Lock()
229
+
230
+ def get_step_adaptor(self) -> StepAdaptor:
231
+
232
+ return StepAdaptor(self.front_end_config.step_adaptor)
233
+
234
+ async def configure(self, app: FastAPI, builder: WorkflowBuilder):
235
+
236
+ # Do things like setting the base URL and global configuration options
237
+ app.root_path = self.front_end_config.root_path
238
+
239
+ await self.add_routes(app, builder)
240
+
241
+ async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
242
+
243
+ await self.add_default_route(app, SessionManager(await builder.build()))
244
+ await self.add_evaluate_route(app, SessionManager(await builder.build()))
245
+ await self.add_static_files_route(app, builder)
246
+ await self.add_authorization_route(app)
247
+ await self.add_mcp_client_tool_list_route(app, builder)
248
+
249
+ for ep in self.front_end_config.endpoints:
250
+
251
+ entry_workflow = await builder.build(entry_function=ep.function_name)
252
+
253
+ await self.add_route(app, endpoint=ep, session_manager=SessionManager(entry_workflow))
254
+
255
+ async def add_default_route(self, app: FastAPI, session_manager: SessionManager):
256
+
257
+ await self.add_route(app, self.front_end_config.workflow, session_manager)
258
+
259
+ async def add_evaluate_route(self, app: FastAPI, session_manager: SessionManager):
260
+ """Add the evaluate endpoint to the FastAPI app."""
261
+
262
+ response_500 = {
263
+ "description": "Internal Server Error",
264
+ "content": {
265
+ "application/json": {
266
+ "example": {
267
+ "detail": "Internal server error occurred"
268
+ }
269
+ }
270
+ },
271
+ }
272
+
273
+ # TODO: Find another way to limit the number of concurrent evaluations
274
+ async def run_evaluation(scheduler_address: str,
275
+ db_url: str,
276
+ workflow_config_file_path: str,
277
+ job_id: str,
278
+ eval_config_file: str,
279
+ reps: int):
280
+ """Background task to run the evaluation."""
281
+ job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
282
+
283
+ try:
284
+ # We have two config files, one for the workflow and one for the evaluation
285
+ # Create EvaluationRunConfig using the CLI defaults
286
+ eval_config = EvaluationRunConfig(config_file=Path(eval_config_file), dataset=None, reps=reps)
287
+
288
+ # Create a new EvaluationRun with the evaluation-specific config
289
+ await job_store.update_status(job_id, JobStatus.RUNNING)
290
+ eval_runner = EvaluationRun(eval_config)
291
+
292
+ async with load_workflow(workflow_config_file_path) as local_session_manager:
293
+ output: EvaluationRunOutput = await eval_runner.run_and_evaluate(
294
+ session_manager=local_session_manager, job_id=job_id)
295
+
296
+ if output.workflow_interrupted:
297
+ await job_store.update_status(job_id, JobStatus.INTERRUPTED)
298
+ else:
299
+ parent_dir = os.path.dirname(output.workflow_output_file) if output.workflow_output_file else None
300
+
301
+ await job_store.update_status(job_id, JobStatus.SUCCESS, output_path=str(parent_dir))
302
+ except Exception as e:
303
+ logger.exception("Error in evaluation job %s", job_id)
304
+ await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
305
+
306
+ async def start_evaluation(request: EvaluateRequest, http_request: Request):
307
+ """Handle evaluation requests."""
308
+
309
+ async with session_manager.session(http_connection=http_request):
310
+
311
+ # if job_id is present and already exists return the job info
312
+ # There is a race condition between this check and the actual job submission, however if the client is
313
+ # supplying their own job_ids, then it is their responsibility to ensure that the job_id is unique.
314
+ if request.job_id:
315
+ job_status = await self._job_store.get_status(request.job_id)
316
+ if job_status != JobStatus.NOT_FOUND:
317
+ return EvaluateResponse(job_id=request.job_id, status=job_status)
318
+
319
+ job_id = self._job_store.ensure_job_id(request.job_id)
320
+
321
+ await self._job_store.submit_job(job_id=job_id,
322
+ config_file=request.config_file,
323
+ expiry_seconds=request.expiry_seconds,
324
+ job_fn=run_evaluation,
325
+ job_args=[
326
+ self._scheduler_address,
327
+ self._db_url,
328
+ self._config_file_path,
329
+ job_id,
330
+ request.config_file,
331
+ request.reps
332
+ ])
333
+
334
+ logger.info("Submitted evaluation job %s with config %s", job_id, request.config_file)
335
+
336
+ return EvaluateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
337
+
338
+ def translate_job_to_response(job: "JobInfo") -> EvaluateStatusResponse:
339
+ """Translate a JobInfo object to an EvaluateStatusResponse."""
340
+ return EvaluateStatusResponse(job_id=job.job_id,
341
+ status=job.status,
342
+ config_file=str(job.config_file),
343
+ error=job.error,
344
+ output_path=str(job.output_path),
345
+ created_at=job.created_at,
346
+ updated_at=job.updated_at,
347
+ expires_at=self._job_store.get_expires_at(job))
348
+
349
+ async def get_job_status(job_id: str, http_request: Request) -> EvaluateStatusResponse:
350
+ """Get the status of an evaluation job."""
351
+ logger.info("Getting status for job %s", job_id)
352
+
353
+ async with session_manager.session(http_connection=http_request):
354
+
355
+ job = await self._job_store.get_job(job_id)
356
+ if not job:
357
+ logger.warning("Job %s not found", job_id)
358
+ raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
359
+ logger.info("Found job %s with status %s", job_id, job.status)
360
+ return translate_job_to_response(job)
361
+
362
+ async def get_last_job_status(http_request: Request) -> EvaluateStatusResponse:
363
+ """Get the status of the last created evaluation job."""
364
+ logger.info("Getting last job status")
365
+
366
+ async with session_manager.session(http_connection=http_request):
367
+
368
+ job = await self._job_store.get_last_job()
369
+ if not job:
370
+ logger.warning("No jobs found when requesting last job status")
371
+ raise HTTPException(status_code=404, detail="No jobs found")
372
+ logger.info("Found last job %s with status %s", job.job_id, job.status)
373
+ return translate_job_to_response(job)
374
+
375
+ async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
376
+ """Get all jobs, optionally filtered by status."""
377
+
378
+ async with session_manager.session(http_connection=http_request):
379
+
380
+ if status is None:
381
+ logger.info("Getting all jobs")
382
+ jobs = await self._job_store.get_all_jobs()
383
+ else:
384
+ logger.info("Getting jobs with status %s", status)
385
+ jobs = await self._job_store.get_jobs_by_status(JobStatus(status))
386
+
387
+ logger.info("Found %d jobs", len(jobs))
388
+ return [translate_job_to_response(job) for job in jobs]
389
+
390
+ if self.front_end_config.evaluate.path:
391
+ if self._dask_available:
392
+ # Add last job endpoint first (most specific)
393
+ app.add_api_route(
394
+ path=f"{self.front_end_config.evaluate.path}/job/last",
395
+ endpoint=get_last_job_status,
396
+ methods=["GET"],
397
+ response_model=EvaluateStatusResponse,
398
+ description="Get the status of the last created evaluation job",
399
+ responses={
400
+ 404: {
401
+ "description": "No jobs found"
402
+ }, 500: response_500
403
+ },
404
+ )
405
+
406
+ # Add specific job endpoint (least specific)
407
+ app.add_api_route(
408
+ path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
409
+ endpoint=get_job_status,
410
+ methods=["GET"],
411
+ response_model=EvaluateStatusResponse,
412
+ description="Get the status of an evaluation job",
413
+ responses={
414
+ 404: {
415
+ "description": "Job not found"
416
+ }, 500: response_500
417
+ },
418
+ )
419
+
420
+ # Add jobs endpoint with optional status query parameter
421
+ app.add_api_route(
422
+ path=f"{self.front_end_config.evaluate.path}/jobs",
423
+ endpoint=get_jobs,
424
+ methods=["GET"],
425
+ response_model=list[EvaluateStatusResponse],
426
+ description="Get all jobs, optionally filtered by status",
427
+ responses={500: response_500},
428
+ )
429
+
430
+ # Add HTTP endpoint for evaluation
431
+ app.add_api_route(
432
+ path=self.front_end_config.evaluate.path,
433
+ endpoint=start_evaluation,
434
+ methods=[self.front_end_config.evaluate.method],
435
+ response_model=EvaluateResponse,
436
+ description=self.front_end_config.evaluate.description,
437
+ responses={500: response_500},
438
+ )
439
+ else:
440
+ logger.warning("Dask is not available, evaluation endpoints will not be added.")
441
+
442
+ async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
443
+
444
+ if not self.front_end_config.object_store:
445
+ logger.debug("No object store configured, skipping static files route")
446
+ return
447
+
448
+ object_store_client = await builder.get_object_store_client(self.front_end_config.object_store)
449
+
450
+ def sanitize_path(path: str) -> str:
451
+ sanitized_path = os.path.normpath(path.strip("/"))
452
+ if sanitized_path == ".":
453
+ raise HTTPException(status_code=400, detail="Invalid file path.")
454
+ filename = os.path.basename(sanitized_path)
455
+ if not filename:
456
+ raise HTTPException(status_code=400, detail="Filename cannot be empty.")
457
+ return sanitized_path
458
+
459
+ # Upload static files to the object store; if key is present, it will fail with 409 Conflict
460
+ async def add_static_file(file_path: str, file: UploadFile):
461
+ sanitized_file_path = sanitize_path(file_path)
462
+ file_data = await file.read()
463
+
464
+ try:
465
+ await object_store_client.put_object(sanitized_file_path,
466
+ ObjectStoreItem(data=file_data, content_type=file.content_type))
467
+ except KeyAlreadyExistsError as e:
468
+ raise HTTPException(status_code=409, detail=str(e)) from e
469
+
470
+ return {"filename": sanitized_file_path}
471
+
472
+ # Upsert static files to the object store; if key is present, it will overwrite the file
473
+ async def upsert_static_file(file_path: str, file: UploadFile):
474
+ sanitized_file_path = sanitize_path(file_path)
475
+ file_data = await file.read()
476
+
477
+ await object_store_client.upsert_object(sanitized_file_path,
478
+ ObjectStoreItem(data=file_data, content_type=file.content_type))
479
+
480
+ return {"filename": sanitized_file_path}
481
+
482
+ # Get static files from the object store
483
+ async def get_static_file(file_path: str):
484
+
485
+ try:
486
+ file_data = await object_store_client.get_object(file_path)
487
+ except NoSuchKeyError as e:
488
+ raise HTTPException(status_code=404, detail=str(e)) from e
489
+
490
+ filename = file_path.split("/")[-1]
491
+
492
+ async def reader():
493
+ yield file_data.data
494
+
495
+ return StreamingResponse(reader(),
496
+ media_type=file_data.content_type,
497
+ headers={"Content-Disposition": f"attachment; filename={filename}"})
498
+
499
+ async def delete_static_file(file_path: str):
500
+ try:
501
+ await object_store_client.delete_object(file_path)
502
+ except NoSuchKeyError as e:
503
+ raise HTTPException(status_code=404, detail=str(e)) from e
504
+
505
+ return Response(status_code=204)
506
+
507
+ # Add the static files route to the FastAPI app
508
+ app.add_api_route(
509
+ path="/static/{file_path:path}",
510
+ endpoint=add_static_file,
511
+ methods=["POST"],
512
+ description="Upload a static file to the object store",
513
+ )
514
+
515
+ app.add_api_route(
516
+ path="/static/{file_path:path}",
517
+ endpoint=upsert_static_file,
518
+ methods=["PUT"],
519
+ description="Upsert a static file to the object store",
520
+ )
521
+
522
+ app.add_api_route(
523
+ path="/static/{file_path:path}",
524
+ endpoint=get_static_file,
525
+ methods=["GET"],
526
+ description="Get a static file from the object store",
527
+ )
528
+
529
+ app.add_api_route(
530
+ path="/static/{file_path:path}",
531
+ endpoint=delete_static_file,
532
+ methods=["DELETE"],
533
+ description="Delete a static file from the object store",
534
+ )
535
+
536
+ async def add_route(self,
537
+ app: FastAPI,
538
+ endpoint: FastApiFrontEndConfig.EndpointBase,
539
+ session_manager: SessionManager):
540
+
541
+ workflow = session_manager.workflow
542
+
543
+ GenerateBodyType = workflow.input_schema
544
+ GenerateStreamResponseType = workflow.streaming_output_schema
545
+ GenerateSingleResponseType = workflow.single_output_schema
546
+
547
+ if self._dask_available:
548
+ # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
549
+ # Consider prefixing these with "nat_" to avoid conflicts
550
+
551
+ class AsyncGenerateRequest(GenerateBodyType):
552
+ job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
553
+ sync_timeout: int = Field(
554
+ default=0,
555
+ ge=0,
556
+ le=300,
557
+ description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
558
+ "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
559
+ expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
560
+ ge=JobStore.MIN_EXPIRY,
561
+ le=JobStore.MAX_EXPIRY,
562
+ description="Optional time (in seconds) before the job expires. "
563
+ "Clamped between 600 (10 min) and 86400 (24h).")
564
+
565
+ # Ensure that the input is in the body. POD types are treated as query parameters
566
+ if (not issubclass(GenerateBodyType, BaseModel)):
567
+ GenerateBodyType = typing.Annotated[GenerateBodyType, Body()]
568
+ else:
569
+ logger.info("Expecting generate request payloads in the following format: %s",
570
+ GenerateBodyType.model_fields)
571
+
572
+ response_500 = {
573
+ "description": "Internal Server Error",
574
+ "content": {
575
+ "application/json": {
576
+ "example": {
577
+ "detail": "Internal server error occurred"
578
+ }
579
+ }
580
+ },
581
+ }
582
+
583
+ def get_single_endpoint(result_type: type | None):
584
+
585
+ async def get_single(response: Response, request: Request):
586
+
587
+ response.headers["Content-Type"] = "application/json"
588
+
589
+ async with session_manager.session(http_connection=request,
590
+ user_authentication_callback=self._http_flow_handler.authenticate):
591
+
592
+ return await generate_single_response(None, session_manager, result_type=result_type)
593
+
594
+ return get_single
595
+
596
+ def get_streaming_endpoint(streaming: bool, result_type: type | None, output_type: type | None):
597
+
598
+ async def get_stream(request: Request):
599
+
600
+ async with session_manager.session(http_connection=request,
601
+ user_authentication_callback=self._http_flow_handler.authenticate):
602
+
603
+ return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
604
+ content=generate_streaming_response_as_str(
605
+ None,
606
+ session_manager=session_manager,
607
+ streaming=streaming,
608
+ step_adaptor=self.get_step_adaptor(),
609
+ result_type=result_type,
610
+ output_type=output_type))
611
+
612
+ return get_stream
613
+
614
+ def get_streaming_raw_endpoint(streaming: bool, result_type: type | None, output_type: type | None):
615
+
616
+ async def get_stream(filter_steps: str | None = None):
617
+
618
+ return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
619
+ content=generate_streaming_response_full_as_str(
620
+ None,
621
+ session_manager=session_manager,
622
+ streaming=streaming,
623
+ result_type=result_type,
624
+ output_type=output_type,
625
+ filter_steps=filter_steps))
626
+
627
+ return get_stream
628
+
629
+ def post_single_endpoint(request_type: type, result_type: type | None):
630
+
631
+ async def post_single(response: Response, request: Request, payload: request_type):
632
+
633
+ response.headers["Content-Type"] = "application/json"
634
+
635
+ async with session_manager.session(http_connection=request,
636
+ user_authentication_callback=self._http_flow_handler.authenticate):
637
+
638
+ return await generate_single_response(payload, session_manager, result_type=result_type)
639
+
640
+ return post_single
641
+
642
+ def post_streaming_endpoint(request_type: type,
643
+ streaming: bool,
644
+ result_type: type | None,
645
+ output_type: type | None):
646
+
647
+ async def post_stream(request: Request, payload: request_type):
648
+
649
+ async with session_manager.session(http_connection=request,
650
+ user_authentication_callback=self._http_flow_handler.authenticate):
651
+
652
+ return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
653
+ content=generate_streaming_response_as_str(
654
+ payload,
655
+ session_manager=session_manager,
656
+ streaming=streaming,
657
+ step_adaptor=self.get_step_adaptor(),
658
+ result_type=result_type,
659
+ output_type=output_type))
660
+
661
+ return post_stream
662
+
663
+ def post_streaming_raw_endpoint(request_type: type,
664
+ streaming: bool,
665
+ result_type: type | None,
666
+ output_type: type | None):
667
+ """
668
+ Stream raw intermediate steps without any step adaptor translations.
669
+ """
670
+
671
+ async def post_stream(payload: request_type, filter_steps: str | None = None):
672
+
673
+ return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
674
+ content=generate_streaming_response_full_as_str(
675
+ payload,
676
+ session_manager=session_manager,
677
+ streaming=streaming,
678
+ result_type=result_type,
679
+ output_type=output_type,
680
+ filter_steps=filter_steps))
681
+
682
+ return post_stream
683
+
684
+ def post_openai_api_compatible_endpoint(request_type: type):
685
+ """
686
+ OpenAI-compatible endpoint that handles both streaming and non-streaming
687
+ based on the 'stream' parameter in the request.
688
+ """
689
+
690
+ async def post_openai_api_compatible(response: Response, request: Request, payload: request_type):
691
+ # Check if streaming is requested
692
+
693
+ response.headers["Content-Type"] = "application/json"
694
+ stream_requested = getattr(payload, 'stream', False)
695
+
696
+ async with session_manager.session(http_connection=request):
697
+ if stream_requested:
698
+
699
+ # Return streaming response
700
+ return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
701
+ content=generate_streaming_response_as_str(
702
+ payload,
703
+ session_manager=session_manager,
704
+ streaming=True,
705
+ step_adaptor=self.get_step_adaptor(),
706
+ result_type=ChatResponseChunk,
707
+ output_type=ChatResponseChunk))
708
+
709
+ return await generate_single_response(payload, session_manager, result_type=ChatResponse)
710
+
711
+ return post_openai_api_compatible
712
+
713
+ def _job_status_to_response(job: "JobInfo") -> AsyncGenerationStatusResponse:
714
+ job_output = job.output
715
+ if job_output is not None:
716
+ try:
717
+ job_output = json.loads(job_output)
718
+ except json.JSONDecodeError:
719
+ logger.error("Failed to parse job output as JSON: %s", job_output)
720
+ job_output = {"error": "Output parsing failed"}
721
+
722
+ return AsyncGenerationStatusResponse(job_id=job.job_id,
723
+ status=job.status,
724
+ error=job.error,
725
+ output=job_output,
726
+ created_at=job.created_at,
727
+ updated_at=job.updated_at,
728
+ expires_at=self._job_store.get_expires_at(job))
729
+
730
+ async def run_generation(scheduler_address: str,
731
+ db_url: str,
732
+ config_file_path: str,
733
+ job_id: str,
734
+ payload: typing.Any):
735
+ """Background task to run the workflow."""
736
+ job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
737
+ try:
738
+ async with load_workflow(config_file_path) as local_session_manager:
739
+ result = await generate_single_response(
740
+ payload, local_session_manager, result_type=local_session_manager.workflow.single_output_schema)
741
+
742
+ await job_store.update_status(job_id, JobStatus.SUCCESS, output=result)
743
+ except Exception as e:
744
+ logger.exception("Error in async job %s", job_id)
745
+ await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
746
+
747
+ def post_async_generation(request_type: type):
748
+
749
+ async def start_async_generation(
750
+ request: request_type, response: Response,
751
+ http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
752
+ """Handle async generation requests."""
753
+
754
+ async with session_manager.session(http_connection=http_request):
755
+
756
+ # if job_id is present and already exists return the job info
757
+ if request.job_id:
758
+ job = await self._job_store.get_job(request.job_id)
759
+ if job:
760
+ return AsyncGenerateResponse(job_id=job.job_id, status=job.status)
761
+
762
+ job_id = self._job_store.ensure_job_id(request.job_id)
763
+ (_, job) = await self._job_store.submit_job(job_id=job_id,
764
+ expiry_seconds=request.expiry_seconds,
765
+ job_fn=run_generation,
766
+ sync_timeout=request.sync_timeout,
767
+ job_args=[
768
+ self._scheduler_address,
769
+ self._db_url,
770
+ self._config_file_path,
771
+ job_id,
772
+ request.model_dump(mode="json")
773
+ ])
774
+
775
+ if job is not None:
776
+ response.status_code = 200
777
+ return _job_status_to_response(job)
778
+
779
+ response.status_code = 202
780
+ return AsyncGenerateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
781
+
782
+ return start_async_generation
783
+
784
+ async def get_async_job_status(job_id: str, http_request: Request) -> AsyncGenerationStatusResponse:
785
+ """Get the status of an async job."""
786
+ logger.info("Getting status for job %s", job_id)
787
+
788
+ async with session_manager.session(http_connection=http_request):
789
+
790
+ job = await self._job_store.get_job(job_id)
791
+ if job is None:
792
+ logger.warning("Job %s not found", job_id)
793
+ raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
794
+
795
+ logger.info("Found job %s with status %s", job_id, job.status)
796
+ return _job_status_to_response(job)
797
+
798
async def websocket_endpoint(websocket: WebSocket):
    """Serve a WebSocket connection for this workflow endpoint.

    If the client passes ``?session=<id>``, that id is injected into the ASGI scope's
    headers as a ``nat-session`` cookie. A ``nat-session`` cookie already present in the
    request takes precedence and is left untouched.

    Args:
        websocket: The incoming WebSocket connection.
    """
    # Universal cookie handling: works for both cross-origin and same-origin connections
    session_id = websocket.query_params.get("session")
    if session_id:
        headers = list(websocket.scope.get("headers", []))
        cookie_header = f"nat-session={session_id}"

        # Locate the (first) cookie header, if any.
        cookie_index = None
        for i, (header_name, _) in enumerate(headers):
            if header_name == b"cookie":
                cookie_index = i
                break

        if cookie_index is None:
            # No cookies at all: add a fresh cookie header.
            headers.append((b"cookie", cookie_header.encode()))
            logger.info("WebSocket: Added new session cookie header: %s", session_id[:10] + "...")
        else:
            name, value = headers[cookie_index]
            # Compare exact cookie names rather than substrings, so e.g. "my-nat-session=..."
            # is not mistaken for "nat-session=...". latin-1 decoding never fails on raw header
            # bytes (unlike the default UTF-8 decode).
            cookie_names = {part.split("=", 1)[0].strip() for part in value.decode("latin-1").split(";")}
            if "nat-session" in cookie_names:
                logger.info("WebSocket: Session cookie already present in headers (same-origin)")
            else:
                # Append to the raw bytes so existing cookie values are preserved byte-for-byte
                # (cross-origin case).
                headers[cookie_index] = (name, value + b"; " + cookie_header.encode())
                logger.info("WebSocket: Added session cookie to existing cookie header: %s",
                            session_id[:10] + "...")

        # Update the websocket scope with the modified headers
        websocket.scope["headers"] = headers

    async with WebSocketMessageHandler(websocket, session_manager, self.get_step_adaptor()) as handler:

        flow_handler = WebSocketAuthenticationFlowHandler(self._add_flow, self._remove_flow, handler)

        # The flow handler is constructed with the message handler, and the message handler also
        # needs the flow handler, so it is attached after both exist.
        handler.set_flow_handler(flow_handler)

        await handler.run()
842
+
843
+ if (endpoint.websocket_path):
844
+ app.add_websocket_route(endpoint.websocket_path, websocket_endpoint)
845
+
846
+ if (endpoint.path):
847
+
848
+ if (endpoint.method == "GET"):
849
+
850
+ app.add_api_route(
851
+ path=endpoint.path,
852
+ endpoint=get_single_endpoint(result_type=GenerateSingleResponseType),
853
+ methods=[endpoint.method],
854
+ response_model=GenerateSingleResponseType,
855
+ description=endpoint.description,
856
+ responses={500: response_500},
857
+ )
858
+
859
+ app.add_api_route(
860
+ path=f"{endpoint.path}/stream",
861
+ endpoint=get_streaming_endpoint(streaming=True,
862
+ result_type=GenerateStreamResponseType,
863
+ output_type=GenerateStreamResponseType),
864
+ methods=[endpoint.method],
865
+ response_model=GenerateStreamResponseType,
866
+ description=endpoint.description,
867
+ responses={500: response_500},
868
+ )
869
+
870
+ app.add_api_route(
871
+ path=f"{endpoint.path}/full",
872
+ endpoint=get_streaming_raw_endpoint(streaming=True,
873
+ result_type=GenerateStreamResponseType,
874
+ output_type=GenerateStreamResponseType),
875
+ methods=[endpoint.method],
876
+ description="Stream raw intermediate steps without any step adaptor translations.\n"
877
+ "Use filter_steps query parameter to filter steps by type (comma-separated list) or\
878
+ set to 'none' to suppress all intermediate steps.",
879
+ )
880
+
881
+ elif (endpoint.method == "POST"):
882
+
883
+ app.add_api_route(
884
+ path=endpoint.path,
885
+ endpoint=post_single_endpoint(request_type=GenerateBodyType,
886
+ result_type=GenerateSingleResponseType),
887
+ methods=[endpoint.method],
888
+ response_model=GenerateSingleResponseType,
889
+ description=endpoint.description,
890
+ responses={500: response_500},
891
+ )
892
+
893
+ app.add_api_route(
894
+ path=f"{endpoint.path}/stream",
895
+ endpoint=post_streaming_endpoint(request_type=GenerateBodyType,
896
+ streaming=True,
897
+ result_type=GenerateStreamResponseType,
898
+ output_type=GenerateStreamResponseType),
899
+ methods=[endpoint.method],
900
+ response_model=GenerateStreamResponseType,
901
+ description=endpoint.description,
902
+ responses={500: response_500},
903
+ )
904
+
905
+ app.add_api_route(
906
+ path=f"{endpoint.path}/full",
907
+ endpoint=post_streaming_raw_endpoint(request_type=GenerateBodyType,
908
+ streaming=True,
909
+ result_type=GenerateStreamResponseType,
910
+ output_type=GenerateStreamResponseType),
911
+ methods=[endpoint.method],
912
+ response_model=GenerateStreamResponseType,
913
+ description="Stream raw intermediate steps without any step adaptor translations.\n"
914
+ "Use filter_steps query parameter to filter steps by type (comma-separated list) or \
915
+ set to 'none' to suppress all intermediate steps.",
916
+ responses={500: response_500},
917
+ )
918
+
919
+ if self._dask_available:
920
+ app.add_api_route(
921
+ path=f"{endpoint.path}/async",
922
+ endpoint=post_async_generation(request_type=AsyncGenerateRequest),
923
+ methods=[endpoint.method],
924
+ response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
925
+ description="Start an async generate job",
926
+ responses={500: response_500},
927
+ )
928
+ else:
929
+ logger.warning("Dask is not available, async generation endpoints will not be added.")
930
+ else:
931
+ raise ValueError(f"Unsupported method {endpoint.method}")
932
+
933
+ if self._dask_available:
934
+ app.add_api_route(
935
+ path=f"{endpoint.path}/async/job/{{job_id}}",
936
+ endpoint=get_async_job_status,
937
+ methods=["GET"],
938
+ response_model=AsyncGenerationStatusResponse,
939
+ description="Get the status of an async job",
940
+ responses={
941
+ 404: {
942
+ "description": "Job not found"
943
+ }, 500: response_500
944
+ },
945
+ )
946
+
947
+ if (endpoint.openai_api_path):
948
+ if (endpoint.method == "GET"):
949
+
950
+ app.add_api_route(
951
+ path=endpoint.openai_api_path,
952
+ endpoint=get_single_endpoint(result_type=ChatResponse),
953
+ methods=[endpoint.method],
954
+ response_model=ChatResponse,
955
+ description=endpoint.description,
956
+ responses={500: response_500},
957
+ )
958
+
959
+ app.add_api_route(
960
+ path=f"{endpoint.openai_api_path}/stream",
961
+ endpoint=get_streaming_endpoint(streaming=True,
962
+ result_type=ChatResponseChunk,
963
+ output_type=ChatResponseChunk),
964
+ methods=[endpoint.method],
965
+ response_model=ChatResponseChunk,
966
+ description=endpoint.description,
967
+ responses={500: response_500},
968
+ )
969
+
970
+ elif (endpoint.method == "POST"):
971
+
972
+ # Check if OpenAI v1 compatible endpoint is configured
973
+ openai_v1_path = getattr(endpoint, 'openai_api_v1_path', None)
974
+
975
+ # Always create legacy endpoints for backward compatibility (unless they conflict with v1 path)
976
+ if not openai_v1_path or openai_v1_path != endpoint.openai_api_path:
977
+ # <openai_api_path> = non-streaming (legacy behavior)
978
+ app.add_api_route(
979
+ path=endpoint.openai_api_path,
980
+ endpoint=post_single_endpoint(request_type=ChatRequest, result_type=ChatResponse),
981
+ methods=[endpoint.method],
982
+ response_model=ChatResponse,
983
+ description=endpoint.description,
984
+ responses={500: response_500},
985
+ )
986
+
987
+ # <openai_api_path>/stream = streaming (legacy behavior)
988
+ app.add_api_route(
989
+ path=f"{endpoint.openai_api_path}/stream",
990
+ endpoint=post_streaming_endpoint(request_type=ChatRequest,
991
+ streaming=True,
992
+ result_type=ChatResponseChunk,
993
+ output_type=ChatResponseChunk),
994
+ methods=[endpoint.method],
995
+ response_model=ChatResponseChunk | ResponseIntermediateStep,
996
+ description=endpoint.description,
997
+ responses={500: response_500},
998
+ )
999
+
1000
+ # Create OpenAI v1 compatible endpoint if configured
1001
+ if openai_v1_path:
1002
+ # OpenAI v1 Compatible Mode: Create single endpoint that handles both streaming and non-streaming
1003
+ app.add_api_route(
1004
+ path=openai_v1_path,
1005
+ endpoint=post_openai_api_compatible_endpoint(request_type=ChatRequest),
1006
+ methods=[endpoint.method],
1007
+ response_model=ChatResponse | ChatResponseChunk,
1008
+ description=f"{endpoint.description} (OpenAI Chat Completions API compatible)",
1009
+ responses={500: response_500},
1010
+ )
1011
+
1012
+ else:
1013
+ raise ValueError(f"Unsupported method {endpoint.method}")
1014
+
1015
async def add_authorization_route(self, app: FastAPI):
    """Register the OAuth2 Authorization Code Grant redirect (callback) route.

    The route is added only when ``front_end_config.oauth2_callback_path`` is set.

    Args:
        app: The FastAPI application to register the route on.
    """

    from fastapi.responses import HTMLResponse

    from nat.front_ends.fastapi.html_snippets.auth_code_grant_success import AUTH_REDIRECT_SUCCESS_HTML

    invalid_state_message = "Invalid state. Please restart the authentication process."

    async def redirect_uri(request: Request):
        """
        Handle the redirect URI for OAuth2 authentication.

        Looks up the outstanding flow by the ``state`` query parameter and exchanges the
        authorization code for a token. The flow's future is then resolved with the token, or
        with a ``RuntimeError`` describing the failure.

        Args:
            request: The FastAPI request object containing query parameters.

        Returns:
            HTMLResponse: The success page once the callback has been processed, or a 400
            response when ``state`` is missing or unknown, or its flow has already finished.
        """
        state = request.query_params.get("state")

        # Hold the lock only for the lookup. The token exchange below is a network call and
        # must not block other flows from being registered or removed.
        async with self._outstanding_flows_lock:
            flow_state = self._outstanding_flows.get(state) if state else None

        if flow_state is None or flow_state.future.done():
            # Unknown state, or a duplicate/late callback for a flow that already completed
            # or was cancelled; resolving its future again would raise InvalidStateError.
            return HTMLResponse(content=invalid_state_message, status_code=400)

        config = flow_state.config
        verifier = flow_state.verifier
        client = flow_state.client

        token = None
        error: Exception | None = None
        try:
            token = await client.fetch_token(url=config.token_url,
                                             authorization_response=str(request.url),
                                             code_verifier=verifier,
                                             state=state)
        except OAuthError as e:
            error = RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})")
        except httpx.HTTPError as e:
            error = RuntimeError(f"Network error during token fetch: {e}")
        except Exception as e:
            error = RuntimeError(f"Authentication failed: {e}")

        # The flow may have been cancelled while the token request was in flight.
        if flow_state.future.done():
            logger.warning("OAuth2 callback completed after its flow was already finished; result discarded")
        elif error is None:
            flow_state.future.set_result(token)
        else:
            flow_state.future.set_exception(error)

        # NOTE: the success page is returned even when the token exchange failed; the failure
        # is reported through the flow's future instead.
        return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
                            status_code=200,
                            headers={
                                "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache"
                            })

    if (self.front_end_config.oauth2_callback_path):
        # Add the redirect URI route
        app.add_api_route(
            path=self.front_end_config.oauth2_callback_path,
            endpoint=redirect_uri,
            methods=["GET"],
            description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
1069
+
1070
async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
    """Add the MCP client tool list endpoint (``GET /mcp/client/tool/list``) to the FastAPI app.

    Args:
        app: The FastAPI application to register the route on.
        builder: Workflow builder whose ``mcp_client`` function groups are inspected on each request.
    """
    from typing import Any

    from pydantic import BaseModel

    class MCPToolInfo(BaseModel):
        name: str
        description: str
        server: str
        available: bool

    class MCPClientToolListResponse(BaseModel):
        mcp_clients: list[dict[str, Any]]

    async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
        """
        Get the list of MCP tools from all MCP clients in the workflow configuration.

        For every ``mcp_client`` function group, this probes the server with ``get_tools()``
        to determine session health. It then reports every tool configured on the client side,
        marking a tool as available only when the session is healthy and the server exposes
        the tool.

        Raises:
            HTTPException: 500 when collecting the information fails outside the per-group
                error handling (e.g. a group has no MCP client).
        """
        mcp_clients_info = []

        try:
            # NOTE(review): reads a private attribute of the builder.
            function_groups = builder._function_groups

            # Find MCP client function groups
            for group_name, configured_group in function_groups.items():
                if configured_group.config.type != "mcp_client":
                    continue

                from nat.plugins.mcp.client_config import MCPClientConfig

                config = configured_group.config
                assert isinstance(config, MCPClientConfig)

                # Reuse the existing MCP client session stored on the function group instance
                group_instance = configured_group.instance

                client = group_instance.mcp_client
                if client is None:
                    raise RuntimeError(f"MCP client not found for group {group_name}")

                try:
                    session_healthy = False
                    server_tools: dict[str, Any] = {}

                    try:
                        server_tools = await client.get_tools()
                        session_healthy = True
                    except Exception as e:
                        logger.exception("Failed to connect to MCP server %s: %s", client.server_name, e)

                    # Configured client-side tools as (short_name, function) pairs. Each short
                    # name is derived from its own key, so names and functions can never be paired
                    # up incorrectly.
                    configured_tools: list[tuple[str, Function]] = []
                    try:
                        # Pass a no-op filter function to bypass any default filtering that might check
                        # health status, preventing potential infinite recursion during health status checks.
                        async def pass_through_filter(fn):
                            return fn

                        accessible_functions = await group_instance.get_accessible_functions(
                            filter_fn=pass_through_filter)
                        # The short name is the part of the key after the first '.'
                        # (or the whole key when it contains no '.').
                        configured_tools = [(full_name.split('.', 1)[-1], fn)
                                            for full_name, fn in accessible_functions.items()]
                    except Exception as e:
                        logger.exception("Failed to get accessible functions for group %s: %s", group_name, e)

                    # Build alias->original mapping and override configs from overrides
                    alias_to_original: dict[str, str] = {}
                    override_configs: dict[str, Any] = {}
                    try:
                        if config.tool_overrides is not None:
                            for orig_name, override in config.tool_overrides.items():
                                if override.alias is not None:
                                    alias_to_original[override.alias] = orig_name
                                    override_configs[override.alias] = override
                                else:
                                    override_configs[orig_name] = override
                    except Exception as e:
                        # Best-effort: overrides only affect names/descriptions, so continue, but log it.
                        logger.warning("Failed to read tool overrides for group %s: %s", group_name, e)

                    # Create tool info list (always return configured tools; mark availability)
                    tools_info: list[dict[str, Any]] = []
                    for fn_short, wf_fn in configured_tools:
                        orig_name = alias_to_original.get(fn_short, fn_short)
                        available = session_healthy and orig_name in server_tools

                        # Prefer tool override description, then workflow function description,
                        # then server description
                        override = override_configs.get(fn_short)
                        if override is not None and override.description:
                            description = override.description
                        elif wf_fn.description:
                            description = wf_fn.description
                        elif available:
                            description = server_tools[orig_name].description or ""
                        else:
                            description = ""

                        tools_info.append(
                            MCPToolInfo(name=fn_short,
                                        description=description or "",
                                        server=client.server_name,
                                        available=available).model_dump())

                    # Sort tools_info by name to maintain consistent ordering
                    tools_info.sort(key=lambda x: x['name'])
                    available_count = sum(1 for tool in tools_info if tool["available"])

                    mcp_clients_info.append({
                        "function_group": group_name,
                        "server": client.server_name,
                        "transport": config.server.transport,
                        "session_healthy": session_healthy,
                        "tools": tools_info,
                        "total_tools": len(tools_info),
                        "available_tools": available_count
                    })

                except Exception as e:
                    logger.exception("Error processing MCP client %s: %s", group_name, e)
                    mcp_clients_info.append({
                        "function_group": group_name,
                        "server": "unknown",
                        "transport": config.server.transport if config.server else "unknown",
                        "session_healthy": False,
                        "error": str(e),
                        "tools": [],
                        "total_tools": 0,
                        # Same key as the success entry; "workflow_tools" is kept for backward compatibility.
                        "available_tools": 0,
                        "workflow_tools": 0
                    })

            return MCPClientToolListResponse(mcp_clients=mcp_clients_info)

        except Exception as e:
            logger.exception("Error in MCP client tool list endpoint: %s", e)
            raise HTTPException(status_code=500,
                                detail=f"Failed to retrieve MCP client information: {str(e)}") from e

    # Add the route to the FastAPI app
    app.add_api_route(
        path="/mcp/client/tool/list",
        endpoint=get_mcp_client_tool_list,
        methods=["GET"],
        response_model=MCPClientToolListResponse,
        description="Get list of MCP client tools with session health and workflow configuration comparison",
        responses={
            200: {
                "description": "Successfully retrieved MCP client tool information",
                "content": {
                    "application/json": {
                        "example": {
                            "mcp_clients": [{
                                "function_group": "mcp_tools",
                                "server": "streamable-http:http://localhost:9901/mcp",
                                "transport": "streamable-http",
                                "session_healthy": True,
                                "tools": [{
                                    "name": "tool_a",
                                    "description": "Tool A description",
                                    "server": "streamable-http:http://localhost:9901/mcp",
                                    "available": True
                                }],
                                "total_tools": 1,
                                "available_tools": 1
                            }]
                        }
                    }
                }
            },
            500: {
                "description": "Internal Server Error"
            }
        })
1246
+
1247
async def _add_flow(self, state: str, flow_state: FlowState):
    """Record an in-progress OAuth2 flow, keyed by its ``state`` value.

    The write happens under ``_outstanding_flows_lock``, the same lock the OAuth2
    redirect handler takes when it looks flows up.
    """
    async with self._outstanding_flows_lock:
        self._outstanding_flows.update({state: flow_state})
1250
+
1251
async def _remove_flow(self, state: str):
    """Forget the OAuth2 flow registered under ``state``.

    Removing a state that is not (or no longer) registered is a no-op rather than a
    ``KeyError``, so cleanup paths may safely run more than once.
    """
    async with self._outstanding_flows_lock:
        self._outstanding_flows.pop(state, None)
1254
+
1255
+
1256
# Declare the module's public API so Sphinx (and ``from ... import *``) only picks up these names.
__all__ = [
    "FastApiFrontEndPluginWorkerBase",
    "FastApiFrontEndPluginWorker",
    "RouteInfo",
]