nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -16,11 +16,15 @@
|
|
|
16
16
|
import logging
|
|
17
17
|
from abc import ABC
|
|
18
18
|
from abc import abstractmethod
|
|
19
|
+
from collections.abc import Mapping
|
|
20
|
+
from typing import Any
|
|
19
21
|
|
|
20
22
|
from mcp.server.fastmcp import FastMCP
|
|
23
|
+
from starlette.exceptions import HTTPException
|
|
21
24
|
from starlette.requests import Request
|
|
22
25
|
|
|
23
26
|
from nat.builder.function import Function
|
|
27
|
+
from nat.builder.function_base import FunctionBase
|
|
24
28
|
from nat.builder.workflow import Workflow
|
|
25
29
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
26
30
|
from nat.data_models.config import Config
|
|
@@ -82,7 +86,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
82
86
|
"""
|
|
83
87
|
pass
|
|
84
88
|
|
|
85
|
-
def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
|
|
89
|
+
async def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
|
|
86
90
|
"""Get all functions from the workflow.
|
|
87
91
|
|
|
88
92
|
Args:
|
|
@@ -94,13 +98,114 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
94
98
|
functions: dict[str, Function] = {}
|
|
95
99
|
|
|
96
100
|
# Extract all functions from the workflow
|
|
97
|
-
|
|
98
|
-
|
|
101
|
+
functions.update(workflow.functions)
|
|
102
|
+
for function_group in workflow.function_groups.values():
|
|
103
|
+
functions.update(await function_group.get_accessible_functions())
|
|
99
104
|
|
|
100
|
-
|
|
105
|
+
if workflow.config.workflow.workflow_alias:
|
|
106
|
+
functions[workflow.config.workflow.workflow_alias] = workflow
|
|
107
|
+
else:
|
|
108
|
+
functions[workflow.config.workflow.type] = workflow
|
|
101
109
|
|
|
102
110
|
return functions
|
|
103
111
|
|
|
112
|
+
def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
|
|
113
|
+
"""Set up HTTP debug endpoints for introspecting tools and schemas.
|
|
114
|
+
|
|
115
|
+
Exposes:
|
|
116
|
+
- GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
|
|
117
|
+
selects a subset and returns details for those tools.
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
@mcp.custom_route("/debug/tools/list", methods=["GET"])
|
|
121
|
+
async def list_tools(request: Request):
|
|
122
|
+
"""HTTP list tools endpoint."""
|
|
123
|
+
|
|
124
|
+
from starlette.responses import JSONResponse
|
|
125
|
+
|
|
126
|
+
from nat.front_ends.mcp.tool_converter import get_function_description
|
|
127
|
+
|
|
128
|
+
# Query params
|
|
129
|
+
# Support repeated names and comma-separated lists
|
|
130
|
+
names_param_list = set(request.query_params.getlist("name"))
|
|
131
|
+
names: list[str] = []
|
|
132
|
+
for raw in names_param_list:
|
|
133
|
+
# if p.strip() is empty, it won't be included in the list!
|
|
134
|
+
parts = [p.strip() for p in raw.split(",") if p.strip()]
|
|
135
|
+
names.extend(parts)
|
|
136
|
+
detail_raw = request.query_params.get("detail")
|
|
137
|
+
|
|
138
|
+
def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
|
|
139
|
+
if detail_param is None:
|
|
140
|
+
if has_names:
|
|
141
|
+
return True
|
|
142
|
+
return False
|
|
143
|
+
v = detail_param.strip().lower()
|
|
144
|
+
if v in ("0", "false", "no", "off"):
|
|
145
|
+
return False
|
|
146
|
+
if v in ("1", "true", "yes", "on"):
|
|
147
|
+
return True
|
|
148
|
+
# For invalid values, default based on whether names are present
|
|
149
|
+
return has_names
|
|
150
|
+
|
|
151
|
+
# Helper function to build the input schema info
|
|
152
|
+
def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
|
|
153
|
+
schema = getattr(fn, "input_schema", None)
|
|
154
|
+
if schema is None:
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
# check if schema is a ChatRequest
|
|
158
|
+
schema_name = getattr(schema, "__name__", "")
|
|
159
|
+
schema_qualname = getattr(schema, "__qualname__", "")
|
|
160
|
+
if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
|
|
161
|
+
# Simplified interface used by MCP wrapper for ChatRequest
|
|
162
|
+
return {
|
|
163
|
+
"type": "object",
|
|
164
|
+
"properties": {
|
|
165
|
+
"query": {
|
|
166
|
+
"type": "string", "description": "User query string"
|
|
167
|
+
}
|
|
168
|
+
},
|
|
169
|
+
"required": ["query"],
|
|
170
|
+
"title": "ChatRequestQuery",
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
# Pydantic models provide model_json_schema
|
|
174
|
+
if schema is not None and hasattr(schema, "model_json_schema"):
|
|
175
|
+
return schema.model_json_schema()
|
|
176
|
+
|
|
177
|
+
return None
|
|
178
|
+
|
|
179
|
+
def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
|
|
180
|
+
include_schemas: bool = False) -> dict[str, Any]:
|
|
181
|
+
tools = []
|
|
182
|
+
for name, fn in functions_to_include.items():
|
|
183
|
+
list_entry: dict[str, Any] = {
|
|
184
|
+
"name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
|
|
185
|
+
}
|
|
186
|
+
if include_schemas:
|
|
187
|
+
list_entry["schema"] = _build_schema_info(fn)
|
|
188
|
+
tools.append(list_entry)
|
|
189
|
+
|
|
190
|
+
return {
|
|
191
|
+
"count": len(tools),
|
|
192
|
+
"tools": tools,
|
|
193
|
+
"server_name": mcp.name,
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
if names:
|
|
197
|
+
# Return selected tools
|
|
198
|
+
try:
|
|
199
|
+
functions_to_include = {n: functions[n] for n in names}
|
|
200
|
+
except KeyError as e:
|
|
201
|
+
raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
|
|
202
|
+
else:
|
|
203
|
+
functions_to_include = functions
|
|
204
|
+
|
|
205
|
+
# Default for listing all: detail defaults to False unless explicitly set true
|
|
206
|
+
return JSONResponse(
|
|
207
|
+
_build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
|
|
208
|
+
|
|
104
209
|
|
|
105
210
|
class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
106
211
|
"""Default MCP front end plugin worker implementation."""
|
|
@@ -118,10 +223,10 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
118
223
|
self._setup_health_endpoint(mcp)
|
|
119
224
|
|
|
120
225
|
# Build the workflow and register all functions with MCP
|
|
121
|
-
workflow = builder.build()
|
|
226
|
+
workflow = await builder.build()
|
|
122
227
|
|
|
123
228
|
# Get all functions from the workflow
|
|
124
|
-
functions = self._get_all_functions(workflow)
|
|
229
|
+
functions = await self._get_all_functions(workflow)
|
|
125
230
|
|
|
126
231
|
# Filter functions based on tool_names if provided
|
|
127
232
|
if self.front_end_config.tool_names:
|
|
@@ -134,10 +239,13 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
134
239
|
logger.debug("Skipping function %s as it's not in tool_names", function_name)
|
|
135
240
|
functions = filtered_functions
|
|
136
241
|
|
|
137
|
-
# Register each function with MCP
|
|
242
|
+
# Register each function with MCP, passing workflow context for observability
|
|
138
243
|
for function_name, function in functions.items():
|
|
139
|
-
register_function_with_mcp(mcp, function_name, function)
|
|
244
|
+
register_function_with_mcp(mcp, function_name, function, workflow)
|
|
140
245
|
|
|
141
246
|
# Add a simple fallback function if no functions were found
|
|
142
247
|
if not functions:
|
|
143
248
|
raise RuntimeError("No functions found in workflow. Please check your configuration.")
|
|
249
|
+
|
|
250
|
+
# After registration, expose debug endpoints for tool/schema inspection
|
|
251
|
+
self._setup_debug_endpoints(mcp, functions)
|
|
@@ -17,13 +17,17 @@ import json
|
|
|
17
17
|
import logging
|
|
18
18
|
from inspect import Parameter
|
|
19
19
|
from inspect import Signature
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
20
21
|
|
|
21
22
|
from mcp.server.fastmcp import FastMCP
|
|
22
23
|
from pydantic import BaseModel
|
|
23
24
|
|
|
25
|
+
from nat.builder.context import ContextState
|
|
24
26
|
from nat.builder.function import Function
|
|
25
27
|
from nat.builder.function_base import FunctionBase
|
|
26
|
-
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from nat.builder.workflow import Workflow
|
|
27
31
|
|
|
28
32
|
logger = logging.getLogger(__name__)
|
|
29
33
|
|
|
@@ -33,14 +37,16 @@ def create_function_wrapper(
|
|
|
33
37
|
function: FunctionBase,
|
|
34
38
|
schema: type[BaseModel],
|
|
35
39
|
is_workflow: bool = False,
|
|
40
|
+
workflow: 'Workflow | None' = None,
|
|
36
41
|
):
|
|
37
42
|
"""Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
|
|
38
43
|
|
|
39
44
|
Args:
|
|
40
|
-
function_name: The name of the function/tool
|
|
41
|
-
function: The NAT Function object
|
|
42
|
-
schema: The input schema of the function
|
|
43
|
-
is_workflow: Whether the function is a Workflow
|
|
45
|
+
function_name (str): The name of the function/tool
|
|
46
|
+
function (FunctionBase): The NAT Function object
|
|
47
|
+
schema (type[BaseModel]): The input schema of the function
|
|
48
|
+
is_workflow (bool): Whether the function is a Workflow
|
|
49
|
+
workflow (Workflow | None): The parent workflow for observability context
|
|
44
50
|
|
|
45
51
|
Returns:
|
|
46
52
|
A wrapper function suitable for registration with MCP
|
|
@@ -101,6 +107,19 @@ def create_function_wrapper(
|
|
|
101
107
|
await ctx.report_progress(0, 100)
|
|
102
108
|
|
|
103
109
|
try:
|
|
110
|
+
# Helper function to wrap function calls with observability
|
|
111
|
+
async def call_with_observability(func_call):
|
|
112
|
+
# Use workflow's observability context (workflow should always be available)
|
|
113
|
+
if not workflow:
|
|
114
|
+
logger.error("Missing workflow context for function %s - observability will not be available",
|
|
115
|
+
function_name)
|
|
116
|
+
raise RuntimeError("Workflow context is required for observability")
|
|
117
|
+
|
|
118
|
+
logger.debug("Starting observability context for function %s", function_name)
|
|
119
|
+
context_state = ContextState.get()
|
|
120
|
+
async with workflow.exporter_manager.start(context_state=context_state):
|
|
121
|
+
return await func_call()
|
|
122
|
+
|
|
104
123
|
# Special handling for ChatRequest
|
|
105
124
|
if is_chat_request:
|
|
106
125
|
from nat.data_models.api_server import ChatRequest
|
|
@@ -118,7 +137,7 @@ def create_function_wrapper(
|
|
|
118
137
|
result = await runner.result(to_type=str)
|
|
119
138
|
else:
|
|
120
139
|
# Regular functions use ainvoke
|
|
121
|
-
result = await function.ainvoke(chat_request, to_type=str)
|
|
140
|
+
result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
|
|
122
141
|
else:
|
|
123
142
|
# Regular handling
|
|
124
143
|
# Handle complex input schema - if we extracted fields from a nested schema,
|
|
@@ -129,7 +148,7 @@ def create_function_wrapper(
|
|
|
129
148
|
field_type = schema.model_fields[field_name].annotation
|
|
130
149
|
|
|
131
150
|
# If it's a pydantic model, we need to create an instance
|
|
132
|
-
if hasattr(field_type, "model_validate"):
|
|
151
|
+
if field_type and hasattr(field_type, "model_validate"):
|
|
133
152
|
# Create the nested object
|
|
134
153
|
nested_obj = field_type.model_validate(kwargs)
|
|
135
154
|
# Call with the nested object
|
|
@@ -147,7 +166,7 @@ def create_function_wrapper(
|
|
|
147
166
|
result = await runner.result(to_type=str)
|
|
148
167
|
else:
|
|
149
168
|
# Regular function call
|
|
150
|
-
result = await function.acall_invoke(**kwargs)
|
|
169
|
+
result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
|
|
151
170
|
|
|
152
171
|
# Report completion
|
|
153
172
|
if ctx:
|
|
@@ -156,7 +175,7 @@ def create_function_wrapper(
|
|
|
156
175
|
# Handle different result types for proper formatting
|
|
157
176
|
if isinstance(result, str):
|
|
158
177
|
return result
|
|
159
|
-
if isinstance(result,
|
|
178
|
+
if isinstance(result, dict | list):
|
|
160
179
|
return json.dumps(result, default=str)
|
|
161
180
|
return str(result)
|
|
162
181
|
except Exception as e:
|
|
@@ -170,7 +189,7 @@ def create_function_wrapper(
|
|
|
170
189
|
wrapper = create_wrapper()
|
|
171
190
|
|
|
172
191
|
# Set the signature on the wrapper function (WITHOUT ctx)
|
|
173
|
-
wrapper.__signature__ = sig
|
|
192
|
+
wrapper.__signature__ = sig # type: ignore
|
|
174
193
|
wrapper.__name__ = function_name
|
|
175
194
|
|
|
176
195
|
# Return the wrapper with proper signature
|
|
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
183
202
|
|
|
184
203
|
The description is determined using the following precedence:
|
|
185
204
|
1. If the function is a Workflow and has a 'description' attribute, use it.
|
|
186
|
-
2. If the Workflow's config has a '
|
|
187
|
-
3. If the Workflow's config has a '
|
|
205
|
+
2. If the Workflow's config has a 'description', use it.
|
|
206
|
+
3. If the Workflow's config has a 'topic', use it.
|
|
188
207
|
4. If the function is a regular Function, use its 'description' attribute.
|
|
189
208
|
|
|
190
209
|
Args:
|
|
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
195
214
|
"""
|
|
196
215
|
function_description = ""
|
|
197
216
|
|
|
217
|
+
# Import here to avoid circular imports
|
|
218
|
+
from nat.builder.workflow import Workflow
|
|
219
|
+
|
|
198
220
|
if isinstance(function, Workflow):
|
|
199
221
|
config = function.config
|
|
200
222
|
|
|
@@ -207,6 +229,9 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
207
229
|
# Try to get anything that might be a description
|
|
208
230
|
elif hasattr(config, "topic") and config.topic:
|
|
209
231
|
function_description = config.topic
|
|
232
|
+
# Try to get description from the workflow config
|
|
233
|
+
elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
|
|
234
|
+
function_description = config.workflow.description
|
|
210
235
|
|
|
211
236
|
elif isinstance(function, Function):
|
|
212
237
|
function_description = function.description
|
|
@@ -214,13 +239,17 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
214
239
|
return function_description
|
|
215
240
|
|
|
216
241
|
|
|
217
|
-
def register_function_with_mcp(mcp: FastMCP,
|
|
242
|
+
def register_function_with_mcp(mcp: FastMCP,
|
|
243
|
+
function_name: str,
|
|
244
|
+
function: FunctionBase,
|
|
245
|
+
workflow: 'Workflow | None' = None) -> None:
|
|
218
246
|
"""Register a NAT Function as an MCP tool.
|
|
219
247
|
|
|
220
248
|
Args:
|
|
221
249
|
mcp: The FastMCP instance
|
|
222
250
|
function_name: The name to register the function under
|
|
223
251
|
function: The NAT Function to register
|
|
252
|
+
workflow: The parent workflow for observability context (if available)
|
|
224
253
|
"""
|
|
225
254
|
logger.info("Registering function %s with MCP", function_name)
|
|
226
255
|
|
|
@@ -229,6 +258,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
|
|
|
229
258
|
logger.info("Function %s has input schema: %s", function_name, input_schema)
|
|
230
259
|
|
|
231
260
|
# Check if we're dealing with a Workflow
|
|
261
|
+
from nat.builder.workflow import Workflow
|
|
232
262
|
is_workflow = isinstance(function, Workflow)
|
|
233
263
|
if is_workflow:
|
|
234
264
|
logger.info("Function %s is a Workflow", function_name)
|
|
@@ -237,5 +267,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
|
|
|
237
267
|
function_description = get_function_description(function)
|
|
238
268
|
|
|
239
269
|
# Create and register the wrapper function with MCP
|
|
240
|
-
wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
|
|
270
|
+
wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
|
|
241
271
|
mcp.tool(name=function_name, description=function_description)(wrapper_func)
|
nat/front_ends/register.py
CHANGED
|
@@ -35,6 +35,8 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
|
|
|
35
35
|
|
|
36
36
|
async def run(self):
|
|
37
37
|
|
|
38
|
+
await self.pre_run()
|
|
39
|
+
|
|
38
40
|
# Must yield the workflow function otherwise it cleans up
|
|
39
41
|
async with WorkflowBuilder.from_config(config=self.full_config) as builder:
|
|
40
42
|
|
|
@@ -45,7 +47,7 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
|
|
|
45
47
|
|
|
46
48
|
click.echo(stream.getvalue())
|
|
47
49
|
|
|
48
|
-
workflow = builder.build()
|
|
50
|
+
workflow = await builder.build()
|
|
49
51
|
session_manager = SessionManager(workflow)
|
|
50
52
|
await self.run_workflow(session_manager)
|
|
51
53
|
|
nat/llm/aws_bedrock_llm.py
CHANGED
|
@@ -21,27 +21,39 @@ from nat.builder.builder import Builder
|
|
|
21
21
|
from nat.builder.llm import LLMProviderInfo
|
|
22
22
|
from nat.cli.register_workflow import register_llm_provider
|
|
23
23
|
from nat.data_models.llm import LLMBaseConfig
|
|
24
|
+
from nat.data_models.optimizable import OptimizableField
|
|
25
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
26
|
+
from nat.data_models.optimizable import SearchSpace
|
|
24
27
|
from nat.data_models.retry_mixin import RetryMixin
|
|
28
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
29
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
30
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
25
31
|
|
|
26
32
|
|
|
27
|
-
class AWSBedrockModelConfig(LLMBaseConfig,
|
|
33
|
+
class AWSBedrockModelConfig(LLMBaseConfig,
|
|
34
|
+
RetryMixin,
|
|
35
|
+
OptimizableMixin,
|
|
36
|
+
TemperatureMixin,
|
|
37
|
+
TopPMixin,
|
|
38
|
+
ThinkingMixin,
|
|
39
|
+
name="aws_bedrock"):
|
|
28
40
|
"""An AWS Bedrock llm provider to be used with an LLM client."""
|
|
29
41
|
|
|
30
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
42
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
31
43
|
|
|
32
44
|
# Completion parameters
|
|
33
45
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
34
46
|
serialization_alias="model",
|
|
35
47
|
description="The model name for the hosted AWS Bedrock.")
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
48
|
+
max_tokens: int = OptimizableField(default=300,
|
|
49
|
+
description="Maximum number of tokens to generate.",
|
|
50
|
+
space=SearchSpace(high=2176, low=128, step=512))
|
|
51
|
+
context_size: int | None = Field(
|
|
52
|
+
default=1024,
|
|
53
|
+
gt=0,
|
|
54
|
+
description="The maximum number of tokens available for input. This is only required for LlamaIndex. "
|
|
55
|
+
"This field is ignored for LangChain/LangGraph.",
|
|
56
|
+
)
|
|
45
57
|
|
|
46
58
|
# Client parameters
|
|
47
59
|
region_name: str | None = Field(default="None", description="AWS region to use.")
|
|
@@ -52,6 +64,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
|
|
|
52
64
|
|
|
53
65
|
|
|
54
66
|
@register_llm_provider(config_type=AWSBedrockModelConfig)
|
|
55
|
-
async def aws_bedrock_model(llm_config: AWSBedrockModelConfig,
|
|
67
|
+
async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, _builder: Builder):
|
|
56
68
|
|
|
57
69
|
yield LLMProviderInfo(config=llm_config, description="A AWS Bedrock model for use with an LLM client.")
|
nat/llm/azure_openai_llm.py
CHANGED
|
@@ -22,9 +22,19 @@ from nat.builder.llm import LLMProviderInfo
|
|
|
22
22
|
from nat.cli.register_workflow import register_llm_provider
|
|
23
23
|
from nat.data_models.llm import LLMBaseConfig
|
|
24
24
|
from nat.data_models.retry_mixin import RetryMixin
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
25
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
26
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
27
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AzureOpenAIModelConfig(
|
|
31
|
+
LLMBaseConfig,
|
|
32
|
+
RetryMixin,
|
|
33
|
+
TemperatureMixin,
|
|
34
|
+
TopPMixin,
|
|
35
|
+
ThinkingMixin,
|
|
36
|
+
name="azure_openai",
|
|
37
|
+
):
|
|
28
38
|
"""An Azure OpenAI LLM provider to be used with an LLM client."""
|
|
29
39
|
|
|
30
40
|
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
@@ -38,10 +48,7 @@ class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
|
|
|
38
48
|
azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
|
|
39
49
|
serialization_alias="azure_deployment",
|
|
40
50
|
description="The Azure OpenAI hosted model/deployment name.")
|
|
41
|
-
temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
|
|
42
|
-
top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
|
|
43
51
|
seed: int | None = Field(default=None, description="Random seed to set for generation.")
|
|
44
|
-
max_retries: int = Field(default=10, description="The max number of retries for the request.")
|
|
45
52
|
|
|
46
53
|
|
|
47
54
|
@register_llm_provider(config_type=AzureOpenAIModelConfig)
|
nat/llm/litellm_llm.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import AsyncIterator
|
|
17
|
+
|
|
18
|
+
from pydantic import AliasChoices
|
|
19
|
+
from pydantic import ConfigDict
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
|
|
22
|
+
from nat.builder.builder import Builder
|
|
23
|
+
from nat.builder.llm import LLMProviderInfo
|
|
24
|
+
from nat.cli.register_workflow import register_llm_provider
|
|
25
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
26
|
+
from nat.data_models.retry_mixin import RetryMixin
|
|
27
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
28
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
29
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class LiteLlmModelConfig(
|
|
33
|
+
LLMBaseConfig,
|
|
34
|
+
RetryMixin,
|
|
35
|
+
TemperatureMixin,
|
|
36
|
+
TopPMixin,
|
|
37
|
+
ThinkingMixin,
|
|
38
|
+
name="litellm",
|
|
39
|
+
):
|
|
40
|
+
"""A LiteLlm provider to be used with an LLM client."""
|
|
41
|
+
|
|
42
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
43
|
+
|
|
44
|
+
api_key: str | None = Field(default=None, description="API key to interact with hosted model.")
|
|
45
|
+
base_url: str | None = Field(default=None,
|
|
46
|
+
description="Base url to the hosted model.",
|
|
47
|
+
validation_alias=AliasChoices("base_url", "api_base"),
|
|
48
|
+
serialization_alias="api_base")
|
|
49
|
+
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
50
|
+
serialization_alias="model",
|
|
51
|
+
description="The LiteLlm hosted model name.")
|
|
52
|
+
seed: int | None = Field(default=None, description="Random seed to set for generation.")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@register_llm_provider(config_type=LiteLlmModelConfig)
|
|
56
|
+
async def litellm_model(
|
|
57
|
+
config: LiteLlmModelConfig,
|
|
58
|
+
_builder: Builder,
|
|
59
|
+
) -> AsyncIterator[LLMProviderInfo]:
|
|
60
|
+
"""Litellm model provider.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
config (LiteLlmModelConfig): The LiteLlm model configuration.
|
|
64
|
+
_builder (Builder): The NAT builder instance.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
AsyncIterator[LLMProviderInfo]: An async iterator that yields an LLMProviderInfo object.
|
|
68
|
+
"""
|
|
69
|
+
yield LLMProviderInfo(config=config, description="A LiteLlm model for use with an LLM client.")
|
nat/llm/nim_llm.py
CHANGED
|
@@ -22,25 +22,37 @@ from nat.builder.builder import Builder
|
|
|
22
22
|
from nat.builder.llm import LLMProviderInfo
|
|
23
23
|
from nat.cli.register_workflow import register_llm_provider
|
|
24
24
|
from nat.data_models.llm import LLMBaseConfig
|
|
25
|
+
from nat.data_models.optimizable import OptimizableField
|
|
26
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
27
|
+
from nat.data_models.optimizable import SearchSpace
|
|
25
28
|
from nat.data_models.retry_mixin import RetryMixin
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
30
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
31
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class NIMModelConfig(LLMBaseConfig,
|
|
35
|
+
RetryMixin,
|
|
36
|
+
OptimizableMixin,
|
|
37
|
+
TemperatureMixin,
|
|
38
|
+
TopPMixin,
|
|
39
|
+
ThinkingMixin,
|
|
40
|
+
name="nim"):
|
|
29
41
|
"""An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
|
|
30
42
|
|
|
31
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
43
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
32
44
|
|
|
33
45
|
api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
|
|
34
46
|
base_url: str | None = Field(default=None, description="Base url to the hosted NIM.")
|
|
35
47
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
36
48
|
serialization_alias="model",
|
|
37
49
|
description="The model name for the hosted NIM.")
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
50
|
+
max_tokens: PositiveInt = OptimizableField(default=300,
|
|
51
|
+
description="Maximum number of tokens to generate.",
|
|
52
|
+
space=SearchSpace(high=2176, low=128, step=512))
|
|
41
53
|
|
|
42
54
|
|
|
43
55
|
@register_llm_provider(config_type=NIMModelConfig)
|
|
44
|
-
async def nim_model(llm_config: NIMModelConfig,
|
|
56
|
+
async def nim_model(llm_config: NIMModelConfig, _builder: Builder):
|
|
45
57
|
|
|
46
58
|
yield LLMProviderInfo(config=llm_config, description="A NIM model for use with an LLM client.")
|
nat/llm/openai_llm.py
CHANGED
|
@@ -21,10 +21,20 @@ from nat.builder.builder import Builder
|
|
|
21
21
|
from nat.builder.llm import LLMProviderInfo
|
|
22
22
|
from nat.cli.register_workflow import register_llm_provider
|
|
23
23
|
from nat.data_models.llm import LLMBaseConfig
|
|
24
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
24
25
|
from nat.data_models.retry_mixin import RetryMixin
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
26
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
27
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
28
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OpenAIModelConfig(LLMBaseConfig,
|
|
32
|
+
RetryMixin,
|
|
33
|
+
OptimizableMixin,
|
|
34
|
+
TemperatureMixin,
|
|
35
|
+
TopPMixin,
|
|
36
|
+
ThinkingMixin,
|
|
37
|
+
name="openai"):
|
|
28
38
|
"""An OpenAI LLM provider to be used with an LLM client."""
|
|
29
39
|
|
|
30
40
|
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
@@ -34,13 +44,11 @@ class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
|
|
|
34
44
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
35
45
|
serialization_alias="model",
|
|
36
46
|
description="The OpenAI hosted model name.")
|
|
37
|
-
temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
|
|
38
|
-
top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
|
|
39
47
|
seed: int | None = Field(default=None, description="Random seed to set for generation.")
|
|
40
48
|
max_retries: int = Field(default=10, description="The max number of retries for the request.")
|
|
41
49
|
|
|
42
50
|
|
|
43
51
|
@register_llm_provider(config_type=OpenAIModelConfig)
|
|
44
|
-
async def openai_llm(config: OpenAIModelConfig,
|
|
52
|
+
async def openai_llm(config: OpenAIModelConfig, _builder: Builder):
|
|
45
53
|
|
|
46
54
|
yield LLMProviderInfo(config=config, description="An OpenAI model for use with an LLM client.")
|
nat/llm/register.py
CHANGED
|
@@ -13,12 +13,15 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=unused-import
|
|
17
16
|
# flake8: noqa
|
|
18
17
|
# isort:skip_file
|
|
18
|
+
"""Register LLM providers via import side effects.
|
|
19
19
|
|
|
20
|
+
This module is imported by the NeMo Agent Toolkit runtime to ensure providers are registered and discoverable.
|
|
21
|
+
"""
|
|
20
22
|
# Import any providers which need to be automatically registered here
|
|
21
23
|
from . import aws_bedrock_llm
|
|
22
24
|
from . import azure_openai_llm
|
|
25
|
+
from . import litellm_llm
|
|
23
26
|
from . import nim_llm
|
|
24
27
|
from . import openai_llm
|
|
@@ -72,9 +72,8 @@ class EnvConfigValue(ABC):
|
|
|
72
72
|
f"{message} Try passing a value to the constructor, or setting the `{self.__class__._ENV_KEY}` "
|
|
73
73
|
"environment variable.")
|
|
74
74
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
raise ValueError("value must not be none")
|
|
75
|
+
elif not self.__class__._ALLOW_NONE and value is None:
|
|
76
|
+
raise ValueError("value must not be none")
|
|
78
77
|
|
|
79
78
|
assert isinstance(value, str) or value is None
|
|
80
79
|
|