nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/cli/register_workflow.py CHANGED
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
27
27
  from nat.cli.type_registry import FrontEndBuildCallableT
28
28
  from nat.cli.type_registry import FrontEndRegisteredCallableT
29
29
  from nat.cli.type_registry import FunctionBuildCallableT
30
+ from nat.cli.type_registry import FunctionGroupBuildCallableT
31
+ from nat.cli.type_registry import FunctionGroupRegisteredCallableT
30
32
  from nat.cli.type_registry import FunctionRegisteredCallableT
31
33
  from nat.cli.type_registry import LLMClientBuildCallableT
32
34
  from nat.cli.type_registry import LLMClientRegisteredCallableT
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
60
62
  from nat.data_models.evaluator import EvaluatorBaseConfigT
61
63
  from nat.data_models.front_end import FrontEndConfigT
62
64
  from nat.data_models.function import FunctionConfigT
65
+ from nat.data_models.function import FunctionGroupConfigT
63
66
  from nat.data_models.llm import LLMBaseConfigT
64
67
  from nat.data_models.memory import MemoryBaseConfigT
65
68
  from nat.data_models.object_store import ObjectStoreBaseConfigT
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
155
158
 
156
159
  context_manager_fn = asynccontextmanager(fn)
157
160
 
158
- if framework_wrappers is None:
159
- framework_wrappers_list: list[str] = []
160
- else:
161
- framework_wrappers_list = list(framework_wrappers)
161
+ framework_wrappers_list = list(framework_wrappers or [])
162
162
 
163
163
  discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
164
164
  component_type=ComponentEnum.FUNCTION)
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
177
177
  return register_function_inner
178
178
 
179
179
 
180
def register_function_group(config_type: type[FunctionGroupConfigT],
                            framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
    """
    Build a decorator that registers a function-group build function with the global type registry.

    A function group is a collection of functions that share one configuration object and its resources.

    Args:
        config_type: The function group configuration class handled by the decorated build function.
        framework_wrappers: Optional framework names (or ``LLMFrameworkEnum`` members) used for automatic
            profiler hooking. ``None`` (or an empty sequence) registers no wrappers.

    Returns:
        A decorator. It wraps the build function with ``asynccontextmanager``, records a
        ``RegisteredFunctionGroupInfo`` for ``config_type`` in ``GlobalTypeRegistry`` and returns the
        wrapped build function.
    """

    def register_function_group_inner(
            fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
    ) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
        # Imported when the decorator is applied, not when this module is imported.
        from .type_registry import GlobalTypeRegistry
        from .type_registry import RegisteredFunctionGroupInfo

        wrapped_build_fn = asynccontextmanager(fn)

        # Normalise to a fresh list; a missing or empty value becomes [].
        wrapper_names = list(framework_wrappers) if framework_wrappers else []

        metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
                                                      component_type=ComponentEnum.FUNCTION_GROUP)

        registry = GlobalTypeRegistry.get()
        registry.register_function_group(
            RegisteredFunctionGroupInfo(
                full_type=config_type.full_type,
                config_type=config_type,
                build_fn=wrapped_build_fn,
                framework_wrappers=wrapper_names,
                discovery_metadata=metadata,
            ))

        return wrapped_build_fn

    return register_function_group_inner
212
+
213
+
180
214
  def register_llm_provider(config_type: type[LLMBaseConfigT]):
181
215
 
182
216
  def register_llm_provider_inner(
nat/cli/type_registry.py CHANGED
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
37
37
  from nat.builder.evaluator import EvaluatorInfo
38
38
  from nat.builder.front_end import FrontEndBase
39
39
  from nat.builder.function import Function
40
+ from nat.builder.function import FunctionGroup
40
41
  from nat.builder.function_base import FunctionBase
41
42
  from nat.builder.function_info import FunctionInfo
42
43
  from nat.builder.llm import LLMProviderInfo
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
55
56
  from nat.data_models.front_end import FrontEndConfigT
56
57
  from nat.data_models.function import FunctionBaseConfig
57
58
  from nat.data_models.function import FunctionConfigT
59
+ from nat.data_models.function import FunctionGroupBaseConfig
60
+ from nat.data_models.function import FunctionGroupConfigT
58
61
  from nat.data_models.llm import LLMBaseConfig
59
62
  from nat.data_models.llm import LLMBaseConfigT
60
63
  from nat.data_models.logging import LoggingBaseConfig
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
85
88
  EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
86
89
  FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
87
90
  FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
91
+ FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
88
92
  TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
89
93
  LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
90
94
  LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
106
110
  FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
107
111
  FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
108
112
  AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
113
+ FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
109
114
  TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
110
115
  LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
111
116
  LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
178
183
  framework_wrappers: list[str] = Field(default_factory=list)
179
184
 
180
185
 
186
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
    """
    Registry record describing one function group.

    A function group is a set of functions that are built from a single shared configuration object and
    share its resources.

    Attributes:
        build_fn: The registered build callable. It takes the group config and a ``Builder`` and yields a
            ``FunctionGroup`` as an async context manager.
        framework_wrappers: Names of the frameworks registered for this group's profiler hooking.
    """

    # Omitted from the model's repr.
    build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
    # Each instance gets its own empty list by default.
    framework_wrappers: list[str] = Field(default_factory=list)
194
+
195
+
181
196
  class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
182
197
  """
183
198
  Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
@@ -298,7 +313,7 @@ class RegisteredPackage(BaseModel):
298
313
  discovery_metadata: DiscoveryMetadata
299
314
 
300
315
 
301
- class TypeRegistry: # pylint: disable=too-many-public-methods
316
+ class TypeRegistry:
302
317
 
303
318
  def __init__(self) -> None:
304
319
  # Telemetry Exporters
@@ -313,6 +328,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
313
328
  # Functions
314
329
  self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
315
330
 
331
+ # Function Groups
332
+ self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
333
+
316
334
  # LLMs
317
335
  self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
318
336
  self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
@@ -478,6 +496,50 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
478
496
 
479
497
  return list(self._registered_functions.values())
480
498
 
499
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
    """Add a function group registration to this type registry.

    Args:
        registration: The registration record; its ``config_type`` is used as the lookup key.

    Raises:
        ValueError: If a function group with the same config type is already registered.
    """
    config_type = registration.config_type

    if config_type in self._registered_function_groups:
        raise ValueError(f"A function group with the same config type `{config_type}` "
                         "has already been registered.")

    self._registered_function_groups[config_type] = registration

    # Hook invoked after every successful registration.
    self._registration_changed()
516
+
517
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
    """Look up the registration stored for a function group config class.

    Args:
        config_type: The function group configuration class to look up.

    Returns:
        RegisteredFunctionGroupInfo: The registration record stored for ``config_type``.

    Raises:
        KeyError: If no function group is registered for ``config_type``. The message lists the registered
            config classes, and the original ``KeyError`` is chained as the cause.
    """
    groups = self._registered_function_groups
    try:
        return groups[config_type]
    except KeyError as err:
        known_configs = set(groups.keys())
        raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
534
+
535
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
    """Return every registered function group.

    Returns:
        list[RegisteredInfo[FunctionGroupBaseConfig]]: A new list of the stored registration records, in the
        order they were registered.
    """
    return [*self._registered_function_groups.values()]
542
+
481
543
  def register_llm_provider(self, info: RegisteredLLMProviderInfo):
482
544
 
483
545
  if (info.config_type in self._registered_llm_provider_infos):
@@ -779,7 +841,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
779
841
 
780
842
  self._registration_changed()
781
843
 
782
- def get_infos_by_type(self, component_type: ComponentEnum) -> dict: # pylint: disable=R0911
844
+ def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
783
845
 
784
846
  if component_type == ComponentEnum.FRONT_END:
785
847
  return self._registered_front_end_infos
@@ -790,6 +852,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
790
852
  if component_type == ComponentEnum.FUNCTION:
791
853
  return self._registered_functions
792
854
 
855
+ if component_type == ComponentEnum.FUNCTION_GROUP:
856
+ return self._registered_function_groups
857
+
793
858
  if component_type == ComponentEnum.TOOL_WRAPPER:
794
859
  return self._registered_tool_wrappers
795
860
 
@@ -849,12 +914,14 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
849
914
 
850
915
  raise ValueError(f"Supplied an unsupported component type {component_type}")
851
916
 
852
- def get_registered_types_by_component_type( # pylint: disable=R0911
853
- self, component_type: ComponentEnum) -> list[str]:
917
+ def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
854
918
 
855
919
  if component_type == ComponentEnum.FUNCTION:
856
920
  return [i.static_type() for i in self._registered_functions]
857
921
 
922
+ if component_type == ComponentEnum.FUNCTION_GROUP:
923
+ return [i.static_type() for i in self._registered_function_groups]
924
+
858
925
  if component_type == ComponentEnum.TOOL_WRAPPER:
859
926
  return list(self._registered_tool_wrappers)
860
927
 
@@ -925,8 +992,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
925
992
  if (short_names[key.local_name] == 1):
926
993
  type_list.append((key.local_name, key.config_type))
927
994
 
928
- # pylint: disable=consider-alternative-union-syntax
929
- return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
995
+ return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
930
996
 
931
997
  def compute_annotation(self, cls: type[TypedBaseModelT]):
932
998
 
@@ -945,6 +1011,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
945
1011
  if issubclass(cls, FunctionBaseConfig):
946
1012
  return self._do_compute_annotation(cls, self.get_registered_functions())
947
1013
 
1014
+ if issubclass(cls, FunctionGroupBaseConfig):
1015
+ return self._do_compute_annotation(cls, self.get_registered_function_groups())
1016
+
948
1017
  if issubclass(cls, LLMBaseConfig):
949
1018
  return self._do_compute_annotation(cls, self.get_registered_llm_providers())
950
1019
 
nat/control_flow/__init__.py ADDED — file without changes (empty file)
nat/control_flow/register.py ADDED
@@ -0,0 +1,20 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # flake8: noqa
17
+
18
+ # Import any control flows which need to be automatically registered here
19
+ from . import sequential_executor
20
+ from .router_agent import register
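nat/control_flow/router_agent/__init__.py ADDED — file without changes (empty file)
nat/control_flow/router_agent/agent.py ADDED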
@@ -0,0 +1,329 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+
19
+ from langchain_core.callbacks.base import AsyncCallbackHandler
20
+ from langchain_core.language_models import BaseChatModel
21
+ from langchain_core.messages.base import BaseMessage
22
+ from langchain_core.messages.human import HumanMessage
23
+ from langchain_core.prompts.chat import ChatPromptTemplate
24
+ from langchain_core.tools import BaseTool
25
+ from langgraph.graph import StateGraph
26
+ from pydantic import BaseModel
27
+ from pydantic import Field
28
+
29
+ from nat.agent.base import AGENT_CALL_LOG_MESSAGE
30
+ from nat.agent.base import AGENT_LOG_PREFIX
31
+ from nat.agent.base import BaseAgent
32
+
33
+ if typing.TYPE_CHECKING:
34
+ from nat.control_flow.router_agent.register import RouterAgentWorkflowConfig
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
class RouterAgentGraphState(BaseModel):
    """Mutable state passed between the Router Agent's graph nodes.

    Attributes:
        messages: Conversation history. The router node appends each LLM reply here and the
            branch node appends the selected branch's output.
        forward_message: Message whose content is sent to the router LLM as the request and is
            then handed unchanged to the chosen branch. Defaults to an empty human message.
        chosen_branch: Name of the branch picked by the router; the empty string means
            "no branch chosen yet".
    """
    messages: list[BaseMessage] = Field(default_factory=list)
    forward_message: BaseMessage = Field(default_factory=lambda: HumanMessage(content=""))
    chosen_branch: str = ""
53
+
54
+
55
class RouterAgentGraph(BaseAgent):
    """Configurable Router Agent for routing requests to different branches.

    A Router Agent analyzes incoming requests and routes them to one of the
    configured branches based on the content and context. It makes a single
    routing decision and executes only the selected branch before returning.

    This agent is useful for creating multi-path workflows where different
    types of requests need to be handled by specialized sub-agents or tools.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        branches: list[BaseTool],
        prompt: ChatPromptTemplate,
        max_router_retries: int = 3,
        callbacks: list[AsyncCallbackHandler] | None = None,
        detailed_logs: bool = False,
        log_response_max_chars: int = 1000,
    ):
        """Initialize the Router Agent.

        Args:
            llm: The language model to use for routing decisions.
            branches: List of tools/branches that the agent can route to.
            prompt: The chat prompt template for the routing agent. Its ``{branches}`` and
                ``{branch_names}`` variables are filled in here from ``branches``.
            max_router_retries: Maximum number of LLM attempts made while trying to obtain a
                reply that names a configured branch.
            callbacks: Optional list of async callback handlers.
            detailed_logs: Whether to enable detailed logging.
            log_response_max_chars: Maximum characters to log in responses.
        """
        super().__init__(llm=llm,
                         tools=branches,
                         callbacks=callbacks,
                         detailed_logs=detailed_logs,
                         log_response_max_chars=log_response_max_chars)

        self._branches = branches
        self._branches_dict = {branch.name: branch for branch in branches}
        branch_names = ",".join(branch.name for branch in branches)
        branch_names_and_descriptions = "\n".join(f"{branch.name}: {branch.description}" for branch in branches)

        # Bind the branch catalogue once so each routing call only supplies request + history.
        prompt = prompt.partial(branches=branch_names_and_descriptions, branch_names=branch_names)
        self.agent = prompt | self.llm

        self.max_router_retries = max_router_retries

    def _get_branch(self, branch_name: str) -> BaseTool | None:
        """Return the configured branch with exactly this name, or ``None`` if there is none."""
        return self._branches_dict.get(branch_name)

    def _match_branch_name(self, response_text: str) -> str | None:
        """Map the router LLM's free-text reply to a configured branch name.

        A case-insensitive exact match of the whole reply wins outright. Surrounding whitespace,
        quotes, backticks and periods are ignored for this check. Otherwise the configured names
        that occur inside the reply are considered. The one that starts earliest is chosen; on a
        tie in position the longest name wins, then configuration order. This stops a name that is
        a prefix of another (``calculator`` vs ``calculator_advanced``) from shadowing it.

        Args:
            response_text: Text content of the router LLM's reply.

        Returns:
            The configured name of the selected branch, or ``None`` if no configured name occurs
            in the reply.
        """
        lowered = response_text.lower()
        normalized = lowered.strip().strip("\"'`.").strip()

        # Empty names would "match" every reply, so they never take part in selection.
        named_branches = [branch for branch in self._branches if branch.name]

        for branch in named_branches:
            if branch.name.lower() == normalized:
                return branch.name

        candidates = [(lowered.find(branch.name.lower()), -len(branch.name), index, branch.name)
                      for index, branch in enumerate(named_branches) if branch.name.lower() in lowered]
        if not candidates:
            return None
        return min(candidates)[3]

    async def agent_node(self, state: RouterAgentGraphState):
        """Execute the agent node to select a branch for routing.

        Queries the configured LLM with the forwarded request and the chat history, then maps the
        reply to a configured branch name. If ``state.chosen_branch`` is already set when the node
        runs, that choice is kept and the LLM is not called.

        Args:
            state: The current state of the router agent graph.

        Returns:
            RouterAgentGraphState: Updated state with the chosen branch. Every router reply,
            including ones that named no branch, is appended to ``state.messages``.

        Raises:
            RuntimeError: If the agent fails to choose a branch after max retries.
        """
        logger.debug("%s Starting the Router Agent Node", AGENT_LOG_PREFIX)

        if state.chosen_branch:
            # Previously the retry loop kept calling the LLM without ever returning in this case.
            logger.debug("%s Router Agent keeping pre-selected branch: %s", AGENT_LOG_PREFIX, state.chosen_branch)
            return state

        chat_history = self._get_chat_history(state.messages)
        request = state.forward_message.content
        for attempt in range(1, self.max_router_retries + 1):
            try:
                agent_response = await self._call_llm(self.agent, {"request": request, "chat_history": chat_history})
            except Exception as ex:
                logger.error("%s Router Agent failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
                raise

            if self.detailed_logs:
                logger.info(AGENT_CALL_LOG_MESSAGE, request, agent_response)

            state.messages += [agent_response]

            chosen_branch = self._match_branch_name(str(agent_response.content))
            if chosen_branch is not None:
                state.chosen_branch = chosen_branch
                if self.detailed_logs:
                    logger.debug("%s Router Agent has chosen branch: %s", AGENT_LOG_PREFIX, chosen_branch)
                return state

            # The agent failed to choose a branch
            if attempt == self.max_router_retries:
                logger.error("%s Router Agent has empty chosen branch", AGENT_LOG_PREFIX)
                raise RuntimeError("Router Agent failed to choose a branch")
            logger.warning("%s Router Agent failed to choose a branch, retrying %d out of %d",
                           AGENT_LOG_PREFIX,
                           attempt,
                           self.max_router_retries)

        # Reached only when max_router_retries < 1; branch_node will then reject the empty choice.
        return state

    async def branch_node(self, state: RouterAgentGraphState):
        """Execute the selected branch with the forwarded message.

        This method calls the tool/branch that was selected by the agent node
        and appends its response to the conversation history.

        Args:
            state: The current state containing the chosen branch and message.

        Returns:
            RouterAgentGraphState: Updated state with the branch response.

        Raises:
            RuntimeError: If no branch was chosen or branch execution fails.
            ValueError: If the requested tool is not found in the configuration.
        """
        logger.debug("%s Starting Router Agent Tool Node", AGENT_LOG_PREFIX)
        try:
            if state.chosen_branch == "":
                logger.error("%s Router Agent has empty chosen branch", AGENT_LOG_PREFIX)
                raise RuntimeError("Router Agent failed to choose a branch")
            requested_branch = self._get_branch(state.chosen_branch)
            if not requested_branch:
                logger.error("%s Router Agent wants to call tool %s but it is not in the config file",
                             AGENT_LOG_PREFIX,
                             state.chosen_branch)
                raise ValueError("Tool not found in config file")

            branch_input = state.forward_message.content
            branch_response = await self._call_tool(requested_branch, branch_input)
            state.messages += [branch_response]
            if self.detailed_logs:
                self._log_tool_response(requested_branch.name, branch_input, branch_response.content)

            return state

        except Exception as ex:
            logger.error("%s Router Agent throws exception during branch node execution: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def _build_graph(self, state_schema):
        """Compile a two-node graph (``agent`` then ``branch``) over ``state_schema`` into ``self.graph``."""
        logger.debug("%s Building and compiling the Router Agent Graph", AGENT_LOG_PREFIX)

        graph = StateGraph(state_schema)
        graph.add_node("agent", self.agent_node)
        graph.add_node("branch", self.branch_node)
        graph.add_edge("agent", "branch")
        graph.set_entry_point("agent")

        self.graph = graph.compile()
        logger.debug("%s Router Agent Graph built and compiled successfully", AGENT_LOG_PREFIX)

        return self.graph

    async def build_graph(self):
        """Build and compile the router agent execution graph.

        Creates a state graph with agent and branch nodes, configures the
        execution flow, and compiles the graph for execution.

        Returns:
            The compiled execution graph.

        Raises:
            Exception: If graph building or compilation fails.
        """
        try:
            await self._build_graph(state_schema=RouterAgentGraphState)
            return self.graph
        except Exception as ex:
            logger.error("%s Router Agent failed to build graph: %s", AGENT_LOG_PREFIX, ex)
            raise

    @staticmethod
    def validate_system_prompt(system_prompt: str) -> bool:
        """Validate that the system prompt contains required variables.

        Checks that the system prompt includes the ``{branches}`` and ``{branch_names}``
        placeholders, which ``__init__`` fills in via ``prompt.partial``.

        Args:
            system_prompt: The system prompt string to validate.

        Returns:
            True if the prompt is valid, False otherwise (the problems are logged).
        """
        required_prompt_variables = {
            "{branches}": "The system prompt must contain {branches} so the agent knows about configured branches.",
            "{branch_names}": "The system prompt must contain {branch_names} so the agent knows branch names."
        }
        errors = [
            error_message for variable_name, error_message in required_prompt_variables.items()
            if variable_name not in system_prompt
        ]
        if errors:
            error_text = "\n".join(errors)
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            return False
        return True

    @staticmethod
    def validate_user_prompt(user_prompt: str) -> bool:
        """Validate that the user prompt contains required variables.

        Checks that the user prompt is non-empty and includes the ``{chat_history}`` and
        ``{request}`` placeholders supplied by ``agent_node`` on every LLM call.

        Args:
            user_prompt: The user prompt string to validate.

        Returns:
            True if the prompt is valid, False otherwise (the problems are logged).
        """
        if not user_prompt:
            errors = ["The user prompt cannot be empty."]
        else:
            required_prompt_variables = {
                "{chat_history}":
                    "The user prompt must contain {chat_history} so the agent knows about the conversation history.",
                "{request}":
                    "The user prompt must contain {request} so the agent sees the current request.",
            }
            errors = [
                error_message for variable_name, error_message in required_prompt_variables.items()
                if variable_name not in user_prompt
            ]
        if errors:
            error_text = "\n".join(errors)
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            return False
        return True
289
+
290
+
291
def create_router_agent_prompt(config: "RouterAgentWorkflowConfig") -> ChatPromptTemplate:
    """Assemble the Router Agent chat prompt from workflow configuration.

    ``config.system_prompt`` and ``config.user_prompt`` override the packaged defaults when
    they are set to a non-empty value. Both resulting prompts are checked for their
    mandatory template variables before the template is built.

    Args:
        config: Router agent workflow configuration holding the optional prompt overrides.

    Returns:
        A two-message (system, user) ChatPromptTemplate for the router agent.

    Raises:
        ValueError: If the chosen system prompt or user prompt fails validation.
    """
    from nat.control_flow.router_agent.prompt import SYSTEM_PROMPT
    from nat.control_flow.router_agent.prompt import USER_PROMPT

    # Empty / unset overrides fall back to the defaults shipped with the router agent.
    system_prompt = config.system_prompt or SYSTEM_PROMPT
    user_prompt = config.user_prompt or USER_PROMPT

    system_ok = RouterAgentGraph.validate_system_prompt(system_prompt)
    if not system_ok:
        logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid system_prompt")

    user_ok = RouterAgentGraph.validate_user_prompt(user_prompt)
    if not user_ok:
        logger.error("%s Invalid user_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid user_prompt")

    messages = [("system", system_prompt), ("user", user_prompt)]
    return ChatPromptTemplate(messages)
@@ -0,0 +1,48 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
# Default system prompt for the Router Agent.
# Template variables:
#   {branches}     - one "name: description" line per configured branch.
#   {branch_names} - comma-separated list of branch names.
# RouterAgentGraph binds both with ``prompt.partial(...)`` when it is constructed, and its
# validate_system_prompt rejects any custom system prompt lacking either placeholder.
# The example branch names below (calculator_tool, weather_service, email_tool) are
# illustrative only; they are not tied to any configured branch.
SYSTEM_PROMPT = """
You are a Router Agent responsible for analyzing incoming requests and routing them to the most appropriate branch.

Available branches:
{branches}

CRITICAL INSTRUCTIONS:
- Analyze the user's request carefully
- Select exactly ONE branch that best handles the request from: [{branch_names}]
- Respond with ONLY the exact branch name, nothing else
- Be decisive - choose the single best match, if the request could fit multiple branches,
choose the most specific/specialized one
- If no branch perfectly fits, choose the closest match

Your response MUST contain ONLY the branch name. Do not include any explanations, reasoning, or additional text.

Examples:
User: "How do I calculate 15 + 25?"
Response: calculator_tool

User: "What's the weather like today?"
Response: weather_service

User: "Send an email to John"
Response: email_tool"""

# Default user (human) prompt for the Router Agent.
# Template variables {chat_history} and {request} are supplied on every routing call by
# RouterAgentGraph.agent_node; validate_user_prompt requires both in custom user prompts.
USER_PROMPT = """
Previous conversation history:
{chat_history}

To respond to the request: {request}, which branch should be chosen?

Respond with only the branch name."""