nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/builder/context.py CHANGED
@@ -19,6 +19,7 @@ from collections.abc import Awaitable
19
19
  from collections.abc import Callable
20
20
  from contextlib import contextmanager
21
21
  from contextvars import ContextVar
22
+ from functools import cached_property
22
23
 
23
24
  from nat.builder.intermediate_step_manager import IntermediateStepManager
24
25
  from nat.builder.user_interaction_manager import UserInteractionManager
@@ -40,12 +41,12 @@ from nat.utils.reactive.subject import Subject
40
41
  class Singleton(type):
41
42
 
42
43
  def __init__(cls, name, bases, dict):
43
- super(Singleton, cls).__init__(name, bases, dict)
44
+ super().__init__(name, bases, dict)
44
45
  cls.instance = None
45
46
 
46
47
  def __call__(cls, *args, **kw):
47
48
  if cls.instance is None:
48
- cls.instance = super(Singleton, cls).__call__(*args, **kw)
49
+ cls.instance = super().__call__(*args, **kw)
49
50
  return cls.instance
50
51
 
51
52
 
@@ -67,14 +68,14 @@ class ContextState(metaclass=Singleton):
67
68
  def __init__(self):
68
69
  self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
69
70
  self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
71
+ self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
72
+ self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
70
73
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
71
74
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
72
- self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
73
- self.event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=Subject())
74
- self.active_function: ContextVar[InvocationNode] = ContextVar("active_function",
75
- default=InvocationNode(function_id="root",
76
- function_name="root"))
77
- self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
75
+ self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
76
+ self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
77
+ self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
78
+ self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
78
79
 
79
80
  # Default is a lambda no-op which returns NoneType
80
81
  self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
@@ -85,6 +86,30 @@ class ContextState(metaclass=Singleton):
85
86
  Awaitable[AuthenticatedContext]]
86
87
  | None] = ContextVar("user_auth_callback", default=None)
87
88
 
89
+ @property
90
+ def metadata(self) -> ContextVar[RequestAttributes]:
91
+ if self._metadata.get() is None:
92
+ self._metadata.set(RequestAttributes())
93
+ return typing.cast(ContextVar[RequestAttributes], self._metadata)
94
+
95
+ @property
96
+ def active_function(self) -> ContextVar[InvocationNode]:
97
+ if self._active_function.get() is None:
98
+ self._active_function.set(InvocationNode(function_id="root", function_name="root"))
99
+ return typing.cast(ContextVar[InvocationNode], self._active_function)
100
+
101
+ @property
102
+ def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
103
+ if self._event_stream.get() is None:
104
+ self._event_stream.set(Subject())
105
+ return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
106
+
107
+ @property
108
+ def active_span_id_stack(self) -> ContextVar[list[str]]:
109
+ if self._active_span_id_stack.get() is None:
110
+ self._active_span_id_stack.set(["root"])
111
+ return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
112
+
88
113
  @staticmethod
89
114
  def get() -> "ContextState":
90
115
  return ContextState()
@@ -98,14 +123,14 @@ class Context:
98
123
  @property
99
124
  def input_message(self):
100
125
  """
101
- Retrieves the input message from the context state.
126
+ Retrieves the input message from the context state.
102
127
 
103
- The input_message property is used to access the message stored in the
104
- context state. This property returns the message as it is currently
105
- maintained in the context.
128
+ The input_message property is used to access the message stored in the
129
+ context state. This property returns the message as it is currently
130
+ maintained in the context.
106
131
 
107
- Returns:
108
- str: The input message retrieved from the context state.
132
+ Returns:
133
+ str: The input message retrieved from the context state.
109
134
  """
110
135
  return self._context_state.input_message.get()
111
136
 
@@ -143,7 +168,7 @@ class Context:
143
168
  """
144
169
  return UserInteractionManager(self._context_state)
145
170
 
146
- @property
171
+ @cached_property
147
172
  def intermediate_step_manager(self) -> IntermediateStepManager:
148
173
  """
149
174
  Retrieves the intermediate step manager instance from the current context state.
@@ -174,6 +199,20 @@ class Context:
174
199
  """
175
200
  return self._context_state.user_message_id.get()
176
201
 
202
+ @property
203
+ def workflow_run_id(self) -> str | None:
204
+ """
205
+ Returns a stable identifier for the current workflow/agent invocation (UUID string).
206
+ """
207
+ return self._context_state.workflow_run_id.get()
208
+
209
+ @property
210
+ def workflow_trace_id(self) -> int | None:
211
+ """
212
+ Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
213
+ """
214
+ return self._context_state.workflow_trace_id.get()
215
+
177
216
  @contextmanager
178
217
  def push_active_function(self,
179
218
  function_name: str,
nat/builder/eval_builder.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import dataclasses
17
18
  import logging
18
19
  from contextlib import asynccontextmanager
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
90
91
  return self.eval_general_config.output_dir
91
92
 
92
93
  @override
93
- def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
94
- tools = []
94
+ async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
95
95
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
96
- for fn_name in self._functions:
97
- fn = self.get_function(fn_name)
96
+
97
+ async def get_tool(fn_name: str):
98
+ fn = await self.get_function(fn_name)
98
99
  try:
99
- tools.append(tool_wrapper_reg.build_fn(fn_name, fn, self))
100
+ return tool_wrapper_reg.build_fn(fn_name, fn, self)
100
101
  except Exception:
101
102
  logger.exception("Error fetching tool `%s`", fn_name)
103
+ return None
102
104
 
103
- return tools
105
+ tasks = [get_tool(fn_name) for fn_name in self._functions]
106
+ tools = await asyncio.gather(*tasks, return_exceptions=False)
107
+ return [tool for tool in tools if tool is not None]
104
108
 
105
109
  def _log_build_failure_evaluator(self,
106
110
  failing_evaluator_name: str,
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
127
131
  remaining_components,
128
132
  original_error)
129
133
 
130
- async def populate_builder(self, config: Config):
134
+ @override
135
+ async def populate_builder(self, config: Config, skip_workflow: bool = False):
131
136
  # Skip setting workflow if workflow config is EmptyFunctionConfig
132
- skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
137
+ skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
133
138
 
134
- await super().populate_builder(config, skip_workflow)
139
+ await super().populate_builder(config, skip_workflow=skip_workflow)
135
140
 
136
141
  # Initialize progress tracking for evaluators
137
142
  completed_evaluators = []
nat/builder/framework_enum.py CHANGED
@@ -22,3 +22,4 @@ class LLMFrameworkEnum(str, Enum):
22
22
  CREWAI = "crewai"
23
23
  SEMANTIC_KERNEL = "semantic_kernel"
24
24
  AGNO = "agno"
25
+ ADK = "adk"
nat/builder/front_end.py CHANGED
@@ -37,7 +37,7 @@ class FrontEndBase(typing.Generic[FrontEndConfigT], ABC):
37
37
 
38
38
  super().__init__()
39
39
 
40
- self._full_config: "Config" = full_config
40
+ self._full_config: Config = full_config
41
41
  self._front_end_config: FrontEndConfigT = typing.cast(FrontEndConfigT, full_config.general.front_end)
42
42
 
43
43
  @property
nat/builder/function.py CHANGED
@@ -14,12 +14,14 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ import re
17
18
  import typing
18
19
  from abc import ABC
19
20
  from abc import abstractmethod
20
21
  from collections.abc import AsyncGenerator
21
22
  from collections.abc import Awaitable
22
23
  from collections.abc import Callable
24
+ from collections.abc import Sequence
23
25
 
24
26
  from pydantic import BaseModel
25
27
 
@@ -29,7 +31,9 @@ from nat.builder.function_base import InputT
29
31
  from nat.builder.function_base import SingleOutputT
30
32
  from nat.builder.function_base import StreamingOutputT
31
33
  from nat.builder.function_info import FunctionInfo
34
+ from nat.data_models.function import EmptyFunctionConfig
32
35
  from nat.data_models.function import FunctionBaseConfig
36
+ from nat.data_models.function import FunctionGroupBaseConfig
33
37
 
34
38
  _InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
35
39
  _StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
@@ -342,3 +346,369 @@ class LambdaFunction(Function[InputT, StreamingOutputT, SingleOutputT]):
342
346
  pass
343
347
 
344
348
  return FunctionImpl(config=config, info=info, instance_name=instance_name)
349
+
350
+
351
+ class FunctionGroup:
352
+ """
353
+ A group of functions that can be used together, sharing the same configuration, context, and resources.
354
+ """
355
+
356
+ def __init__(self,
357
+ *,
358
+ config: FunctionGroupBaseConfig,
359
+ instance_name: str | None = None,
360
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None):
361
+ """
362
+ Creates a new function group.
363
+
364
+ Parameters
365
+ ----------
366
+ config : FunctionGroupBaseConfig
367
+ The configuration for the function group.
368
+ instance_name : str | None, optional
369
+ The name of the function group. If not provided, the type of the function group will be used.
370
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
371
+ A callback function to additionally filter the functions in the function group dynamically when
372
+ the functions are accessed via any accessor method.
373
+ """
374
+ self._config = config
375
+ self._instance_name = instance_name or config.type
376
+ self._functions: dict[str, Function] = dict()
377
+ self._filter_fn = filter_fn
378
+ self._per_function_filter_fn: dict[str, Callable[[str], Awaitable[bool]]] = dict()
379
+
380
+ def add_function(self,
381
+ name: str,
382
+ fn: Callable,
383
+ *,
384
+ input_schema: type[BaseModel] | None = None,
385
+ description: str | None = None,
386
+ converters: list[Callable] | None = None,
387
+ filter_fn: Callable[[str], Awaitable[bool]] | None = None):
388
+ """
389
+ Adds a function to the function group.
390
+
391
+ Parameters
392
+ ----------
393
+ name : str
394
+ The name of the function.
395
+ fn : Callable
396
+ The function to add to the function group.
397
+ input_schema : type[BaseModel] | None, optional
398
+ The input schema for the function.
399
+ description : str | None, optional
400
+ The description of the function.
401
+ converters : list[Callable] | None, optional
402
+ The converters to use for the function.
403
+ filter_fn : Callable[[str], Awaitable[bool]] | None, optional
404
+ A callback to determine if the function should be included in the function group. The
405
+ callback will be called with the function name. The callback is invoked dynamically when
406
+ the functions are accessed via any accessor method such as `get_accessible_functions`,
407
+ `get_included_functions`, `get_excluded_functions`, `get_all_functions`.
408
+
409
+ Raises
410
+ ------
411
+ ValueError
412
+ When the function name is empty or blank.
413
+ When the function name contains invalid characters.
414
+ When the function already exists in the function group.
415
+ """
416
+ if not name.strip():
417
+ raise ValueError("Function name cannot be empty or blank")
418
+ if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
419
+ raise ValueError(
420
+ f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
421
+ if name in self._functions:
422
+ raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
423
+
424
+ info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
425
+ full_name = self._get_fn_name(name)
426
+ lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
427
+ self._functions[name] = lambda_fn
428
+ if filter_fn:
429
+ self._per_function_filter_fn[name] = filter_fn
430
+
431
+ def get_config(self) -> FunctionGroupBaseConfig:
432
+ """
433
+ Returns the configuration for the function group.
434
+
435
+ Returns
436
+ -------
437
+ FunctionGroupBaseConfig
438
+ The configuration for the function group.
439
+ """
440
+ return self._config
441
+
442
+ def _get_fn_name(self, name: str) -> str:
443
+ return f"{self._instance_name}.{name}"
444
+
445
+ async def _fn_should_be_included(self, name: str) -> bool:
446
+ if name not in self._per_function_filter_fn:
447
+ return True
448
+ return await self._per_function_filter_fn[name](name)
449
+
450
+ async def _get_all_but_excluded_functions(
451
+ self,
452
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
453
+ ) -> dict[str, Function]:
454
+ """
455
+ Returns a dictionary of all functions in the function group except the excluded functions.
456
+ """
457
+ missing = set(self._config.exclude) - set(self._functions.keys())
458
+ if missing:
459
+ raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
460
+
461
+ if filter_fn is None:
462
+ if self._filter_fn is None:
463
+
464
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
465
+ return x
466
+
467
+ filter_fn = identity_filter
468
+ else:
469
+ filter_fn = self._filter_fn
470
+
471
+ excluded = set(self._config.exclude)
472
+ included = set(await filter_fn(list(self._functions.keys())))
473
+
474
+ result = {}
475
+ for name in self._functions:
476
+ if name in excluded:
477
+ continue
478
+ if not await self._fn_should_be_included(name):
479
+ continue
480
+ if name not in included:
481
+ continue
482
+ result[self._get_fn_name(name)] = self._functions[name]
483
+
484
+ return result
485
+
486
+ async def get_accessible_functions(
487
+ self,
488
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
489
+ ) -> dict[str, Function]:
490
+ """
491
+ Returns a dictionary of all accessible functions in the function group.
492
+
493
+ First, the functions are filtered by the function group's configuration.
494
+ If the function group is configured to:
495
+ - include some functions, this will return only the included functions.
496
+ - not include or exclude any function, this will return all functions in the group.
497
+ - exclude some functions, this will return all functions in the group except the excluded functions.
498
+
499
+ Then, the functions are filtered by filter function and per-function filter functions.
500
+
501
+ Parameters
502
+ ----------
503
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
504
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
505
+ then fall back to the function group's filter function. If no filter function is set for the function group
506
+ all functions will be returned.
507
+
508
+ Returns
509
+ -------
510
+ dict[str, Function]
511
+ A dictionary of all accessible functions in the function group.
512
+
513
+ Raises
514
+ ------
515
+ ValueError
516
+ When the function group is configured to include functions that are not found in the group.
517
+ """
518
+ if self._config.include:
519
+ return await self.get_included_functions(filter_fn=filter_fn)
520
+ if self._config.exclude:
521
+ return await self._get_all_but_excluded_functions(filter_fn=filter_fn)
522
+ return await self.get_all_functions(filter_fn=filter_fn)
523
+
524
+ async def get_excluded_functions(
525
+ self,
526
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
527
+ ) -> dict[str, Function]:
528
+ """
529
+ Returns a dictionary of all functions in the function group which are configured to be excluded or filtered
530
+ out by a filter function or per-function filter function.
531
+
532
+ Parameters
533
+ ----------
534
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
535
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
536
+ then fall back to the function group's filter function. If no filter function is set for the function group
537
+ then only functions excluded by the configuration or by a per-function filter function are returned.
538
+
539
+ Returns
540
+ -------
541
+ dict[str, Function]
542
+ A dictionary of all excluded functions in the function group.
543
+
544
+ Raises
545
+ ------
546
+ ValueError
547
+ When the function group is configured to exclude functions that are not found in the group.
548
+ """
549
+ missing = set(self._config.exclude) - set(self._functions.keys())
550
+ if missing:
551
+ raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
552
+
553
+ if filter_fn is None:
554
+ if self._filter_fn is None:
555
+
556
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
557
+ return x
558
+
559
+ filter_fn = identity_filter
560
+ else:
561
+ filter_fn = self._filter_fn
562
+
563
+ excluded = set(self._config.exclude)
564
+ included = set(await filter_fn(list(self._functions.keys())))
565
+
566
+ result = {}
567
+ for name in self._functions:
568
+ is_excluded = False
569
+ if name in excluded:
570
+ is_excluded = True
571
+ elif not await self._fn_should_be_included(name):
572
+ is_excluded = True
573
+ elif name not in included:
574
+ is_excluded = True
575
+
576
+ if is_excluded:
577
+ result[self._get_fn_name(name)] = self._functions[name]
578
+
579
+ return result
580
+
581
+ async def get_included_functions(
582
+ self,
583
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
584
+ ) -> dict[str, Function]:
585
+ """
586
+ Returns a dictionary of all functions in the function group which are:
587
+ - configured to be included and added to the global function registry
588
+ - not configured to be excluded.
589
+ - not filtered out by a filter function.
590
+
591
+ Parameters
592
+ ----------
593
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
594
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
595
+ then fall back to the function group's filter function. If no filter function is set for the function group
596
+ all functions will be returned.
597
+
598
+ Returns
599
+ -------
600
+ dict[str, Function]
601
+ A dictionary of all included functions in the function group.
602
+
603
+ Raises
604
+ ------
605
+ ValueError
606
+ When the function group is configured to include functions that are not found in the group.
607
+ """
608
+ missing = set(self._config.include) - set(self._functions.keys())
609
+ if missing:
610
+ raise ValueError(f"Unknown included functions: {sorted(missing)}")
611
+
612
+ if filter_fn is None:
613
+ if self._filter_fn is None:
614
+
615
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
616
+ return x
617
+
618
+ filter_fn = identity_filter
619
+ else:
620
+ filter_fn = self._filter_fn
621
+
622
+ included = set(await filter_fn(list(self._config.include)))
623
+ result = {}
624
+ for name in included:
625
+ if await self._fn_should_be_included(name):
626
+ result[self._get_fn_name(name)] = self._functions[name]
627
+ return result
628
+
629
+ async def get_all_functions(
630
+ self,
631
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
632
+ ) -> dict[str, Function]:
633
+ """
634
+ Returns a dictionary of all functions in the function group, regardless of whether they are included or excluded.
635
+
636
+ If a filter function has been set, the returned functions will additionally be filtered by the callback.
637
+
638
+ Parameters
639
+ ----------
640
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
641
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
642
+ then fall back to the function group's filter function. If no filter function is set for the function group
643
+ all functions will be returned.
644
+
645
+ Returns
646
+ -------
647
+ dict[str, Function]
648
+ A dictionary of all functions in the function group.
649
+ """
650
+ if filter_fn is None:
651
+ if self._filter_fn is None:
652
+
653
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
654
+ return x
655
+
656
+ filter_fn = identity_filter
657
+ else:
658
+ filter_fn = self._filter_fn
659
+
660
+ included = set(await filter_fn(list(self._functions.keys())))
661
+ result = {}
662
+ for name in included:
663
+ if await self._fn_should_be_included(name):
664
+ result[self._get_fn_name(name)] = self._functions[name]
665
+ return result
666
+
667
+ def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]]):
668
+ """
669
+ Sets the filter function for the function group.
670
+
671
+ Parameters
672
+ ----------
673
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]]
674
+ The filter function to set for the function group.
675
+ """
676
+ self._filter_fn = filter_fn
677
+
678
+ def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], Awaitable[bool]]):
679
+ """
680
+ Sets a per-function filter function for a function within the function group.
681
+
682
+ Parameters
683
+ ----------
684
+ name : str
685
+ The name of the function.
686
+ filter_fn : Callable[[str], Awaitable[bool]]
687
+ The per-function filter function to set for the function group.
688
+
689
+ Raises
690
+ ------
691
+ ValueError
692
+ When the function is not found in the function group.
693
+ """
694
+ if name not in self._functions:
695
+ raise ValueError(f"Function {name} not found in function group {self._instance_name}")
696
+ self._per_function_filter_fn[name] = filter_fn
697
+
698
+ def set_instance_name(self, instance_name: str):
699
+ """
700
+ Sets the instance name for the function group.
701
+
702
+ Parameters
703
+ ----------
704
+ instance_name : str
705
+ The instance name to set for the function group.
706
+ """
707
+ self._instance_name = instance_name
708
+
709
+ @property
710
+ def instance_name(self) -> str:
711
+ """
712
+ Returns the instance name for the function group.
713
+ """
714
+ return self._instance_name
@@ -233,7 +233,7 @@ class FunctionDescriptor:
233
233
 
234
234
  is_input_typed = all([a != sig.empty for a in annotations])
235
235
 
236
- input_type = tuple[*annotations] if is_input_typed else None # noqa: syntax-error
236
+ input_type = tuple[*annotations] if is_input_typed else None
237
237
 
238
238
  # Get the base type here removing all annotations and async generators
239
239
  output_annotation_decomp = DecomposedType(sig.return_annotation).get_base_type()
@@ -16,6 +16,8 @@
16
16
  import dataclasses
17
17
  import logging
18
18
  import typing
19
+ import weakref
20
+ from typing import ClassVar
19
21
 
20
22
  from nat.data_models.intermediate_step import IntermediateStep
21
23
  from nat.data_models.intermediate_step import IntermediateStepPayload
@@ -46,11 +48,19 @@ class IntermediateStepManager:
46
48
  Manages updates to the NAT Event Stream for intermediate steps
47
49
  """
48
50
 
51
+ # Class-level tracking for debugging and monitoring
52
+ _instance_count: ClassVar[int] = 0
53
+ _active_instances: ClassVar[set[weakref.ref]] = set()
54
+
49
55
  def __init__(self, context_state: "ContextState"): # noqa: F821
50
56
  self._context_state = context_state
51
57
 
52
58
  self._outstanding_start_steps: dict[str, OpenStep] = {}
53
59
 
60
+ # Track instance creation
61
+ IntermediateStepManager._instance_count += 1
62
+ IntermediateStepManager._active_instances.add(weakref.ref(self, self._cleanup_instance_tracking))
63
+
54
64
  def push_intermediate_step(self, payload: IntermediateStepPayload) -> None:
55
65
  """
56
66
  Pushes an intermediate step to the NAT Event Stream
@@ -91,7 +101,10 @@ class IntermediateStepManager:
91
101
  open_step = self._outstanding_start_steps.pop(payload.UUID, None)
92
102
 
93
103
  if (open_step is None):
94
- logger.warning("Step id %s not found in outstanding start steps", payload.UUID)
104
+ logger.warning(
105
+ "Step id %s not found in outstanding start steps. "
106
+ "This may occur if the step was started in a different context or already completed.",
107
+ payload.UUID)
95
108
  return
96
109
 
97
110
  parent_step_id = open_step.step_parent_id
@@ -147,7 +160,8 @@ class IntermediateStepManager:
147
160
  if (open_step is None):
148
161
  logger.warning(
149
162
  "Created a chunk for step %s, but no matching start step was found. "
150
- "Chunks must be created with the same ID as the start step.",
163
+ "Chunks must be created with the same ID as the start step. "
164
+ "This may occur if the step was started in a different context.",
151
165
  payload.UUID)
152
166
  return
153
167
 
@@ -172,3 +186,25 @@ class IntermediateStepManager:
172
186
  """
173
187
 
174
188
  return self._context_state.event_stream.get().subscribe(on_next, on_error, on_complete)
189
+
190
+ @classmethod
191
+ def _cleanup_instance_tracking(cls, ref: weakref.ref) -> None:
192
+ """Cleanup callback for weakref when instance is garbage collected."""
193
+ cls._active_instances.discard(ref)
194
+
195
+ @classmethod
196
+ def get_active_instance_count(cls) -> int:
197
+ """Get the number of active IntermediateStepManager instances.
198
+
199
+ Returns:
200
+ int: Number of active instances (cleaned up automatically via weakref)
201
+ """
202
+ return len(cls._active_instances)
203
+
204
+ def get_outstanding_step_count(self) -> int:
205
+ """Get the number of outstanding (started but not ended) steps.
206
+
207
+ Returns:
208
+ int: Number of steps that have been started but not yet ended
209
+ """
210
+ return len(self._outstanding_start_steps)