nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/llm/litellm_llm.py ADDED
@@ -0,0 +1,80 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections.abc import AsyncIterator
+
+ from pydantic import AliasChoices
+ from pydantic import ConfigDict
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.builder.llm import LLMProviderInfo
+ from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.common import OptionalSecretStr
+ from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
+ from nat.data_models.optimizable import OptimizableMixin
+ from nat.data_models.optimizable import SearchSpace
+ from nat.data_models.retry_mixin import RetryMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
+
+
+ class LiteLlmModelConfig(
+         LLMBaseConfig,
+         OptimizableMixin,
+         RetryMixin,
+         ThinkingMixin,
+         name="litellm",
+ ):
+     """A LiteLlm provider to be used with an LLM client."""
+
+     model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+     api_key: OptionalSecretStr = Field(default=None, description="API key to interact with hosted model.")
+     base_url: str | None = Field(default=None,
+                                  description="Base url to the hosted model.",
+                                  validation_alias=AliasChoices("base_url", "api_base"),
+                                  serialization_alias="api_base")
+     model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                        serialization_alias="model",
+                                        description="The LiteLlm hosted model name.")
+     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+     temperature: float | None = OptimizableField(
+         default=None,
+         ge=0.0,
+         description="Sampling temperature to control randomness in the output.",
+         space=SearchSpace(high=0.9, low=0.1, step=0.2))
+     top_p: float | None = OptimizableField(default=None,
+                                            ge=0.0,
+                                            le=1.0,
+                                            description="Top-p for distribution sampling.",
+                                            space=SearchSpace(high=1.0, low=0.5, step=0.1))
+
+
+ @register_llm_provider(config_type=LiteLlmModelConfig)
+ async def litellm_model(
+     config: LiteLlmModelConfig,
+     _builder: Builder,
+ ) -> AsyncIterator[LLMProviderInfo]:
+     """Litellm model provider.
+
+     Args:
+         config (LiteLlmModelConfig): The LiteLlm model configuration.
+         _builder (Builder): The NAT builder instance.
+
+     Returns:
+         AsyncIterator[LLMProviderInfo]: An async iterator that yields an LLMProviderInfo object.
+     """
+     yield LLMProviderInfo(config=config, description="A LiteLlm model for use with an LLM client.")
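Note on the new provider: model/model_name and api_base/base_url are interchangeable on input and serialize back out under LiteLLM's own parameter names. A minimal sketch of that round-trip (standard Pydantic alias behavior; the model and URL values are placeholders):

    # Hypothetical values; requires nvidia-nat 1.4.0a or later.
    from nat.llm.litellm_llm import LiteLlmModelConfig

    # "model" and "api_base" are accepted on input via AliasChoices...
    cfg = LiteLlmModelConfig(model="gpt-4o-mini", api_base="http://localhost:4000")
    assert cfg.model_name == "gpt-4o-mini"

    # ...and dump back out under LiteLLM's parameter names via the serialization aliases.
    dumped = cfg.model_dump(by_alias=True)
    assert dumped["model"] == "gpt-4o-mini"
    assert dumped["api_base"] == "http://localhost:4000"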
nat/llm/nim_llm.py CHANGED
@@ -21,24 +21,38 @@ from pydantic import PositiveInt
  from nat.builder.builder import Builder
  from nat.builder.llm import LLMProviderInfo
  from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.common import OptionalSecretStr
  from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
+ from nat.data_models.optimizable import OptimizableMixin
+ from nat.data_models.optimizable import SearchSpace
  from nat.data_models.retry_mixin import RetryMixin
- from nat.data_models.temperature_mixin import TemperatureMixin
  from nat.data_models.thinking_mixin import ThinkingMixin
- from nat.data_models.top_p_mixin import TopPMixin
 
 
- class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="nim"):
+ class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="nim"):
      """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
-     model_config = ConfigDict(protected_namespaces=())
+     model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
-     api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
+     api_key: OptionalSecretStr = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
      base_url: str | None = Field(default=None, description="Base url to the hosted NIM.")
-     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
-                             serialization_alias="model",
-                             description="The model name for the hosted NIM.")
-     max_tokens: PositiveInt = Field(default=300, description="Maximum number of tokens to generate.")
+     model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                        serialization_alias="model",
+                                        description="The model name for the hosted NIM.")
+     max_tokens: PositiveInt = OptimizableField(default=300,
+                                                description="Maximum number of tokens to generate.",
+                                                space=SearchSpace(high=2176, low=128, step=512))
+     temperature: float | None = OptimizableField(
+         default=None,
+         ge=0.0,
+         description="Sampling temperature to control randomness in the output.",
+         space=SearchSpace(high=0.9, low=0.1, step=0.2))
+     top_p: float | None = OptimizableField(default=None,
+                                            ge=0.0,
+                                            le=1.0,
+                                            description="Top-p for distribution sampling.",
+                                            space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
  @register_llm_provider(config_type=NIMModelConfig)
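The TemperatureMixin/TopPMixin fields are now declared inline as OptimizableField entries carrying a SearchSpace, which the new parameter optimizer (nat/profiler/parameter_optimization, added in this release) can sweep. A standalone sketch of the grid those low/high/step bounds imply (the enumeration helper below is illustrative, not the toolkit's actual sweep code):

    # Illustrative only: one plausible reading of SearchSpace(low, high, step)
    # as an inclusive grid of candidate values.
    def grid(low: float, high: float, step: float) -> list[float]:
        values = []
        v = low
        while v <= high + 1e-9:  # small tolerance for float accumulation
            values.append(round(v, 10))
            v += step
        return values

    # max_tokens space from NIMModelConfig above: SearchSpace(high=2176, low=128, step=512)
    assert grid(128, 2176, 512) == [128, 640, 1152, 1664, 2176]
    # temperature space: SearchSpace(high=0.9, low=0.1, step=0.2)
    assert grid(0.1, 0.9, 0.2) == [0.1, 0.3, 0.5, 0.7, 0.9]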
nat/llm/openai_llm.py CHANGED
@@ -20,25 +20,37 @@ from pydantic import Field
  from nat.builder.builder import Builder
  from nat.builder.llm import LLMProviderInfo
  from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.common import OptionalSecretStr
  from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
+ from nat.data_models.optimizable import OptimizableMixin
+ from nat.data_models.optimizable import SearchSpace
  from nat.data_models.retry_mixin import RetryMixin
- from nat.data_models.temperature_mixin import TemperatureMixin
  from nat.data_models.thinking_mixin import ThinkingMixin
- from nat.data_models.top_p_mixin import TopPMixin
 
 
- class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="openai"):
+ class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="openai"):
      """An OpenAI LLM provider to be used with an LLM client."""
 
      model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
-     api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
+     api_key: OptionalSecretStr = Field(default=None, description="OpenAI API key to interact with hosted model.")
      base_url: str | None = Field(default=None, description="Base url to the hosted model.")
-     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
-                             serialization_alias="model",
-                             description="The OpenAI hosted model name.")
+     model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                        serialization_alias="model",
+                                        description="The OpenAI hosted model name.")
      seed: int | None = Field(default=None, description="Random seed to set for generation.")
      max_retries: int = Field(default=10, description="The max number of retries for the request.")
+     temperature: float | None = OptimizableField(
+         default=None,
+         ge=0.0,
+         description="Sampling temperature to control randomness in the output.",
+         space=SearchSpace(high=0.9, low=0.1, step=0.2))
+     top_p: float | None = OptimizableField(default=None,
+                                            ge=0.0,
+                                            le=1.0,
+                                            description="Top-p for distribution sampling.",
+                                            space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
  @register_llm_provider(config_type=OpenAIModelConfig)
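Across all three providers, api_key moves from str | None to OptionalSecretStr (defined in the new nat/data_models/common.py, not shown in this diff). Assuming OptionalSecretStr wraps Pydantic's SecretStr, the practical effect is that keys no longer leak into reprs, logs, or plain dumps:

    # Sketch of the motivation, using pydantic.SecretStr directly as a stand-in.
    from pydantic import BaseModel, SecretStr

    class Demo(BaseModel):
        api_key: SecretStr | None = None

    demo = Demo(api_key="sk-not-a-real-key")        # hypothetical key value
    print(demo)                                     # api_key=SecretStr('**********') -- masked
    print(demo.api_key.get_secret_value())          # explicit opt-in to read the raw value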
nat/llm/register.py CHANGED
@@ -15,9 +15,13 @@
 
  # flake8: noqa
  # isort:skip_file
+ """Register LLM providers via import side effects.
 
+ This module is imported by the NeMo Agent Toolkit runtime to ensure providers are registered and discoverable.
+ """
  # Import any providers which need to be automatically registered here
  from . import aws_bedrock_llm
  from . import azure_openai_llm
+ from . import litellm_llm
  from . import nim_llm
  from . import openai_llm
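The new module docstring is accurate: registration happens purely as an import side effect of the decorated provider functions. A generic sketch of the shape of that pattern (illustrative; the toolkit's real registry lives in nat/cli/register_workflow.py and nat/cli/type_registry.py):

    # Illustrative registry, not the toolkit's actual internals.
    _PROVIDERS: dict[str, type] = {}

    def register_llm_provider(config_type: type):
        def decorator(fn):
            _PROVIDERS[config_type.__name__] = config_type  # side effect at import time
            return fn
        return decorator

    # Importing a module whose functions carry @register_llm_provider populates
    # _PROVIDERS without any explicit call, which is why register.py only needs
    # `from . import litellm_llm` to make the new provider discoverable.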
nat/llm/utils/thinking.py CHANGED
@@ -19,10 +19,10 @@ import logging
  import types
  from abc import abstractmethod
  from collections.abc import AsyncGenerator
+ from collections.abc import Callable
  from collections.abc import Iterable
  from dataclasses import dataclass
  from typing import Any
- from typing import Callable
  from typing import TypeVar
 
  ModelType = TypeVar("ModelType")
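Context for the import swap: PEP 585 deprecated typing.Callable in favor of collections.abc.Callable as of Python 3.9. The abc version subscripts the same way and doubles as a runtime check:

    from collections.abc import Callable

    def apply(fn: Callable[[int], int], x: int) -> int:
        return fn(x)

    assert apply(lambda n: n + 1, 1) == 2
    assert isinstance(print, Callable)  # the abc class also works with isinstance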
nat/observability/exporter/base_exporter.py CHANGED
@@ -372,7 +372,7 @@ class BaseExporter(Exporter):
          try:
              # Wait for all tasks to complete with a timeout
              await asyncio.wait_for(asyncio.gather(*self._tasks, return_exceptions=True), timeout=timeout)
-         except asyncio.TimeoutError:
+         except TimeoutError:
              logger.warning("%s: Some tasks did not complete within %s seconds", self.name, timeout)
          except Exception as e:
              logger.exception("%s: Error while waiting for tasks: %s", self.name, e)
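The same mechanical change (asyncio.TimeoutError to the builtin TimeoutError) appears again in exporter_manager.py below. It is safe because the two names have been the same class since Python 3.11:

    import asyncio

    assert asyncio.TimeoutError is TimeoutError  # True on Python >= 3.11

    async def slow() -> None:
        await asyncio.sleep(10)

    async def main() -> None:
        try:
            await asyncio.wait_for(slow(), timeout=0.01)
        except TimeoutError:  # catches exactly what asyncio.TimeoutError would
            print("timed out")

    asyncio.run(main())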
nat/observability/exporter/processing_exporter.py CHANGED
@@ -53,6 +53,8 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
      - Configurable None filtering: processors returning None can drop items from pipeline
      - Automatic type validation before export
      """
+     # All ProcessingExporter instances automatically use this for signature checking
+     _signature_method = '_process_pipeline'
 
      def __init__(self, context_state: ContextState | None = None, drop_nones: bool = True):
          """Initialize the processing exporter.
@@ -294,8 +296,6 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
              self._check_processor_compatibility(predecessor,
                                                  processor,
                                                  "predecessor",
-                                                 predecessor.output_class,
-                                                 processor.input_class,
                                                  str(predecessor.output_type),
                                                  str(processor.input_type))
@@ -304,8 +304,6 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
              self._check_processor_compatibility(processor,
                                                  successor,
                                                  "successor",
-                                                 processor.output_class,
-                                                 successor.input_class,
                                                  str(processor.output_type),
                                                  str(successor.input_type))
@@ -313,34 +311,22 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
                                          source_processor: Processor,
                                          target_processor: Processor,
                                          relationship: str,
-                                         source_class: type,
-                                         target_class: type,
                                          source_type: str,
                                          target_type: str) -> None:
-         """Check type compatibility between two processors.
+         """Check type compatibility between two processors using Pydantic validation.
 
          Args:
              source_processor (Processor): The processor providing output
              target_processor (Processor): The processor receiving input
              relationship (str): Description of relationship ("predecessor" or "successor")
-             source_class (type): The output class of source processor
-             target_class (type): The input class of target processor
              source_type (str): String representation of source type
              target_type (str): String representation of target type
          """
-         try:
-             if not issubclass(source_class, target_class):
-                 raise ValueError(f"Processor {target_processor.__class__.__name__} input type {target_type} "
-                                  f"is not compatible with {relationship} {source_processor.__class__.__name__} "
-                                  f"output type {source_type}")
-         except TypeError:
-             logger.warning(
-                 "Cannot use issubclass() for type compatibility check between "
-                 "%s (%s) and %s (%s). Skipping compatibility check.",
-                 source_processor.__class__.__name__,
-                 source_type,
-                 target_processor.__class__.__name__,
-                 target_type)
+         # Use Pydantic-based type compatibility checking
+         if not source_processor.is_output_compatible_with(target_processor.input_type):
+             raise ValueError(f"Processor {target_processor.__class__.__name__} input type {target_type} "
+                              f"is not compatible with {relationship} {source_processor.__class__.__name__} "
+                              f"output type {source_type}")
 
      async def _pre_start(self) -> None:
 
@@ -350,36 +336,21 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
          last_processor = self._processors[-1]
 
          # validate that the first processor's input type is compatible with the exporter's input type
-         try:
-             if not issubclass(self.input_class, first_processor.input_class):
-                 raise ValueError(f"Processor {first_processor.__class__.__name__} input type "
-                                  f"{first_processor.input_type} is not compatible with the "
-                                  f"{self.input_type} input type")
-         except TypeError as e:
-             # Handle cases where classes are generic types that can't be used with issubclass
-             logger.warning(
-                 "Cannot validate type compatibility between %s (%s) "
-                 "and exporter (%s): %s. Skipping validation.",
-                 first_processor.__class__.__name__,
-                 first_processor.input_type,
-                 self.input_type,
-                 e)
-
+         if not first_processor.is_compatible_with_input(self.input_type):
+             logger.error("First processor %s input=%s incompatible with exporter input=%s",
+                          first_processor.__class__.__name__,
+                          first_processor.input_type,
+                          self.input_type)
+             raise ValueError("First processor incompatible with exporter input")
          # Validate that the last processor's output type is compatible with the exporter's output type
-         try:
-             if not DecomposedType.is_type_compatible(last_processor.output_type, self.output_type):
-                 raise ValueError(f"Processor {last_processor.__class__.__name__} output type "
-                                  f"{last_processor.output_type} is not compatible with the "
-                                  f"{self.output_type} output type")
-         except TypeError as e:
-             # Handle cases where classes are generic types that can't be used with issubclass
-             logger.warning(
-                 "Cannot validate type compatibility between %s (%s) "
-                 "and exporter (%s): %s. Skipping validation.",
-                 last_processor.__class__.__name__,
-                 last_processor.output_type,
-                 self.output_type,
-                 e)
+         # Use DecomposedType.is_type_compatible for the final export stage to allow batch compatibility
+         # This enables BatchingProcessor[T] -> Exporter[T] patterns where the exporter handles both T and list[T]
+         if not DecomposedType.is_type_compatible(last_processor.output_type, self.output_type):
+             logger.error("Last processor %s output=%s incompatible with exporter output=%s",
+                          last_processor.__class__.__name__,
+                          last_processor.output_type,
+                          self.output_type)
+             raise ValueError("Last processor incompatible with exporter output")
 
          # Lock the pipeline to prevent further modifications
          self._pipeline_locked = True
@@ -432,12 +403,15 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
                      await self.export_processed(processed_item)
                  else:
                      logger.debug("Skipping export of empty batch")
-             elif isinstance(processed_item, self.output_class):
+             elif self.validate_output_type(processed_item):
                  await self.export_processed(processed_item)
              else:
                  if raise_on_invalid:
-                     raise ValueError(f"Processed item {processed_item} is not a valid output type. "
-                                      f"Expected {self.output_class} or list[{self.output_class}]")
+                     logger.error("Invalid processed item type for export: %s (expected %s or list[%s])",
+                                  type(processed_item),
+                                  self.output_type,
+                                  self.output_type)
+                     raise ValueError("Invalid processed item type for export")
                  logger.warning("Processed item %s is not a valid output type for export", processed_item)
 
      async def _continue_pipeline_after(self, source_processor: Processor, item: Any) -> None:
@@ -512,7 +486,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
          event (IntermediateStep): The event to be exported.
          """
          # Convert IntermediateStep to PipelineInputT and create export task
-         if isinstance(event, self.input_class):
+         if self.validate_input_type(event):
              input_item: PipelineInputT = event  # type: ignore
              coro = self._export_with_processing(input_item)
              self._create_export_task(coro)
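The deleted try/except TypeError blocks existed because issubclass() rejects parameterized generics outright; the replacement helpers (is_output_compatible_with, is_compatible_with_input, validate_input_type, validate_output_type, implemented in type_introspection_mixin.py, also reworked in this release) sidestep that. A self-contained sketch of the failure mode and a Pydantic-style check (the helper below is an illustrative stand-in, not the mixin's actual implementation):

    from typing import Any

    from pydantic import TypeAdapter
    from pydantic import ValidationError

    # issubclass() cannot inspect parameterized generics:
    try:
        issubclass(list[int], list)
    except TypeError:
        print("issubclass() refuses list[int]")

    # A TypeAdapter-based check handles them:
    def is_instance_of(value: object, tp: Any) -> bool:
        try:
            TypeAdapter(tp).validate_python(value, strict=True)
            return True
        except ValidationError:
            return False

    assert is_instance_of([1, 2], list[int])
    assert not is_instance_of(["a"], list[int])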
nat/observability/exporter/span_exporter.py CHANGED
@@ -126,6 +126,7 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
 
          parent_span = None
          span_ctx = None
+         workflow_trace_id = self._context_state.workflow_trace_id.get()
 
          # Look up the parent span to establish hierarchy
          # event.parent_id is the UUID of the last START step with a different UUID from current step
@@ -141,6 +142,9 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
              parent_span = parent_span.model_copy() if isinstance(parent_span, Span) else None
              if parent_span and parent_span.context:
                  span_ctx = SpanContext(trace_id=parent_span.context.trace_id)
+         # No parent: adopt workflow trace id if available to keep all spans in the same trace
+         if span_ctx is None and workflow_trace_id:
+             span_ctx = SpanContext(trace_id=workflow_trace_id)
 
          # Extract start/end times from the step
          # By convention, `span_event_timestamp` is the time we started, `event_timestamp` is the time we ended.
@@ -154,28 +158,52 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
          else:
              sub_span_name = f"{event.payload.event_type}"
 
+         # Prefer parent/context trace id for attribute, else workflow trace id
+         _attr_trace_id = None
+         if span_ctx is not None:
+             _attr_trace_id = span_ctx.trace_id
+         elif parent_span and parent_span.context:
+             _attr_trace_id = parent_span.context.trace_id
+         elif workflow_trace_id:
+             _attr_trace_id = workflow_trace_id
+
+         attributes = {
+             f"{self._span_prefix}.event_type":
+                 event.payload.event_type.value,
+             f"{self._span_prefix}.function.id":
+                 event.function_ancestry.function_id if event.function_ancestry else "unknown",
+             f"{self._span_prefix}.function.name":
+                 event.function_ancestry.function_name if event.function_ancestry else "unknown",
+             f"{self._span_prefix}.subspan.name":
+                 event.payload.name or "",
+             f"{self._span_prefix}.event_timestamp":
+                 event.event_timestamp,
+             f"{self._span_prefix}.framework":
+                 event.payload.framework.value if event.payload.framework else "unknown",
+             f"{self._span_prefix}.conversation.id":
+                 self._context_state.conversation_id.get() or "unknown",
+             f"{self._span_prefix}.workflow.run_id":
+                 self._context_state.workflow_run_id.get() or "unknown",
+             f"{self._span_prefix}.workflow.trace_id": (f"{_attr_trace_id:032x}" if _attr_trace_id else "unknown"),
+         }
+
          sub_span = Span(name=sub_span_name,
                          parent=parent_span,
                          context=span_ctx,
-                         attributes={
-                             f"{self._span_prefix}.event_type":
-                                 event.payload.event_type.value,
-                             f"{self._span_prefix}.function.id":
-                                 event.function_ancestry.function_id if event.function_ancestry else "unknown",
-                             f"{self._span_prefix}.function.name":
-                                 event.function_ancestry.function_name if event.function_ancestry else "unknown",
-                             f"{self._span_prefix}.subspan.name":
-                                 event.payload.name or "",
-                             f"{self._span_prefix}.event_timestamp":
-                                 event.event_timestamp,
-                             f"{self._span_prefix}.framework":
-                                 event.payload.framework.value if event.payload.framework else "unknown",
-                         },
+                         attributes=attributes,
                          start_time=start_ns)
 
          span_kind = event_type_to_span_kind(event.event_type)
          sub_span.set_attribute(f"{self._span_prefix}.span.kind", span_kind.value)
 
+         # Enable session grouping by setting session.id from conversation_id
+         try:
+             conversation_id = self._context_state.conversation_id.get()
+             if conversation_id:
+                 sub_span.set_attribute("session.id", conversation_id)
+         except (AttributeError, LookupError):
+             pass
+
          if event.payload.data and event.payload.data.input:
              match = re.search(r"Human:\s*Question:\s*(.*)", str(event.payload.data.input))
              if match:
@@ -252,7 +280,7 @@ class SpanExporter(ProcessingExporter[InputSpanT, OutputSpanT], SerializeMixin):
 
          end_metadata = event.payload.metadata or {}
 
-         if not isinstance(end_metadata, (dict, TraceMetadata)):
+         if not isinstance(end_metadata, dict | TraceMetadata):
              logger.warning("Invalid metadata type for step %s", event.UUID)
              return
 
nat/observability/exporter_manager.py CHANGED
@@ -184,7 +184,7 @@ class ExporterManager:
          try:
              await asyncio.wait_for(asyncio.gather(*cleanup_tasks, return_exceptions=True),
                                     timeout=self._shutdown_timeout)
-         except asyncio.TimeoutError:
+         except TimeoutError:
              logger.warning("Some isolated exporters did not clean up within timeout")
 
          self._active_isolated_exporters.clear()
@@ -301,7 +301,7 @@ class ExporterManager:
          try:
              task.cancel()
              await asyncio.wait_for(task, timeout=self._shutdown_timeout)
-         except asyncio.TimeoutError:
+         except TimeoutError:
              logger.warning("Exporter '%s' task did not shut down in time and may be stuck.", name)
              stuck_tasks.append(name)
          except asyncio.CancelledError:
nat/observability/mixin/redaction_config_mixin.py CHANGED
@@ -1,4 +1,4 @@
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  # SPDX-License-Identifier: Apache-2.0
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,9 +25,10 @@ class RedactionConfigMixin(BaseModel):
      """
      redaction_enabled: bool = Field(default=False, description="Whether to enable redaction processing.")
      redaction_value: str = Field(default="[REDACTED]", description="Value to replace redacted attributes with.")
-     redaction_attributes: list[str] = Field(default_factory=lambda: ["input.value", "output.value", "metadata"],
-                                             description="Span attributes to redact when redaction is triggered.")
+     redaction_attributes: list[str] = Field(default_factory=lambda: ["input.value", "output.value", "nat.metadata"],
+                                             description="Attributes to redact when redaction is triggered.")
      force_redaction: bool = Field(default=False, description="Always redact regardless of other conditions.")
+     redaction_tag: str | None = Field(default=None, description="Tag to add to spans when redaction is triggered.")
 
 
  class HeaderRedactionConfigMixin(RedactionConfigMixin):
@@ -38,4 +39,4 @@ class HeaderRedactionConfigMixin(RedactionConfigMixin):
 
      Note: The callback function must be provided directly to the processor at runtime.
      """
-     redaction_header: str = Field(default="x-redaction-key", description="Header to check for redaction decisions.")
+     redaction_headers: list[str] = Field(default_factory=list, description="Headers to check for redaction decisions.")
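Breaking change worth noting: redaction_header (single string with a default) becomes redaction_headers (list, empty by default), so existing configs need updating. A sketch of the new shape (header values are illustrative):

    from nat.observability.mixin.redaction_config_mixin import HeaderRedactionConfigMixin

    cfg = HeaderRedactionConfigMixin(
        redaction_enabled=True,
        redaction_headers=["x-redaction-key", "x-tenant-id"],  # was: redaction_header="x-redaction-key"
        redaction_tag="redacted",
    )
    assert "x-tenant-id" in cfg.redaction_headers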
nat/observability/mixin/tagging_config_mixin.py CHANGED
@@ -1,4 +1,4 @@
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  # SPDX-License-Identifier: Apache-2.0
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ import sys
+ from collections.abc import Mapping
  from enum import Enum
  from typing import Generic
  from typing import TypeVar
@@ -20,7 +22,17 @@ from typing import TypeVar
  from pydantic import BaseModel
  from pydantic import Field
 
- TagValueT = TypeVar("TagValueT")
+ if sys.version_info >= (3, 12):
+     from typing import TypedDict
+ else:
+     from typing_extensions import TypedDict
+
+ TagMappingT = TypeVar("TagMappingT", bound=Mapping)
+
+
+ class BaseTaggingConfigMixin(BaseModel, Generic[TagMappingT]):
+     """Base mixin for tagging spans."""
+     tags: TagMappingT | None = Field(default=None, description="Tags to add to the span.")
 
 
  class PrivacyLevel(str, Enum):
@@ -31,20 +43,20 @@ class PrivacyLevel(str, Enum):
      HIGH = "high"
 
 
- class TaggingConfigMixin(BaseModel, Generic[TagValueT]):
-     """Generic mixin for tagging spans with typed values.
+ PrivacyTagSchema = TypedDict(
+     "PrivacyTagSchema",
+     {
+         "privacy.level": PrivacyLevel,
+     },
+     total=True,
+ )
 
-     This mixin provides a flexible tagging system where both the tag key
-     and value type can be customized for different use cases.
-     """
-     tag_key: str | None = Field(default=None, description="Key to use when tagging traces.")
-     tag_value: TagValueT | None = Field(default=None, description="Value to tag the traces with.")
 
+ class PrivacyTaggingConfigMixin(BaseTaggingConfigMixin[PrivacyTagSchema]):
+     """Mixin for privacy level tagging on spans."""
+     pass
 
- class PrivacyTaggingConfigMixin(TaggingConfigMixin[PrivacyLevel]):
-     """Mixin for privacy level tagging on spans.
 
-     Specializes TaggingConfigMixin to work with PrivacyLevel enum values,
-     providing a typed interface for privacy-related span tagging.
-     """
+ class CustomTaggingConfigMixin(BaseTaggingConfigMixin[dict[str, str]]):
+     """Mixin for string key-value tagging on spans."""
      pass
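PrivacyTagSchema uses the functional TypedDict form because "privacy.level" is not a valid Python identifier and so cannot be declared with class syntax. At runtime the tags value is an ordinary dict; assuming Pydantic validates the TypedDict-typed generic field as usual, configs are populated like this:

    from nat.observability.mixin.tagging_config_mixin import PrivacyLevel
    from nat.observability.mixin.tagging_config_mixin import PrivacyTaggingConfigMixin

    cfg = PrivacyTaggingConfigMixin(tags={"privacy.level": PrivacyLevel.HIGH})
    assert cfg.tags is not None and cfg.tags["privacy.level"] is PrivacyLevel.HIGH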