nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/builder/builder.py +52 -0
- nat/builder/component_utils.py +7 -1
- nat/builder/context.py +17 -0
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +74 -3
- nat/builder/workflow.py +4 -2
- nat/builder/workflow_builder.py +129 -0
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/register_workflow.py +50 -0
- nat/cli/type_registry.py +68 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +16 -0
- nat/data_models/function.py +14 -1
- nat/data_models/middleware.py +35 -0
- nat/data_models/runtime_enum.py +26 -0
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +11 -3
- nat/eval/utils/weave_eval.py +17 -3
- nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
- nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
- nat/llm/aws_bedrock_llm.py +11 -9
- nat/llm/azure_openai_llm.py +12 -4
- nat/llm/litellm_llm.py +11 -4
- nat/llm/nim_llm.py +11 -9
- nat/llm/openai_llm.py +12 -9
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache_middleware.py +256 -0
- nat/middleware/function_middleware.py +186 -0
- nat/middleware/middleware.py +184 -0
- nat/middleware/register.py +35 -0
- nat/profiler/decorators/framework_wrapper.py +16 -0
- nat/retriever/milvus/register.py +11 -3
- nat/retriever/milvus/retriever.py +102 -40
- nat/runtime/runner.py +12 -1
- nat/runtime/session.py +10 -3
- nat/tool/code_execution/code_sandbox.py +4 -7
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +8 -4
- nat/utils/io/yaml_tools.py +73 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
- nat/data_models/temperature_mixin.py +0 -44
- nat/data_models/top_p_mixin.py +0 -44
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py
CHANGED

@@ -39,6 +39,8 @@ from pydantic import BaseModel
 from pydantic import Field
 from starlette.websockets import WebSocket
 
+from nat.builder.eval_builder import WorkflowEvalBuilder
+from nat.builder.evaluator import EvaluatorInfo
 from nat.builder.function import Function
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.api_server import ChatRequest
@@ -51,11 +53,14 @@ from nat.data_models.object_store import NoSuchKeyError
 from nat.eval.config import EvaluationRunOutput
 from nat.eval.evaluate import EvaluationRun
 from nat.eval.evaluate import EvaluationRunConfig
+from nat.eval.evaluator.evaluator_model import EvalInput
 from nat.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
 from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
 from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
 from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerationStatusResponse
+from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemRequest
+from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
@@ -227,6 +232,54 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         self._outstanding_flows: dict[str, FlowState] = {}
         self._outstanding_flows_lock = asyncio.Lock()
 
+        # Evaluator storage for single-item evaluation
+        self._evaluators: dict[str, EvaluatorInfo] = {}
+        self._eval_builder: WorkflowEvalBuilder | None = None
+
+    async def initialize_evaluators(self, config: Config):
+        """Initialize and store evaluators from config for single-item evaluation."""
+        if not config.eval or not config.eval.evaluators:
+            logger.info("No evaluators configured, skipping evaluator initialization")
+            return
+
+        try:
+            # Build evaluators using WorkflowEvalBuilder (same pattern as nat eval)
+            # Start with registry=None and let populate_builder set everything up
+            self._eval_builder = WorkflowEvalBuilder(general_config=config.general,
+                                                     eval_general_config=config.eval.general,
+                                                     registry=None)
+
+            # Enter the async context and keep it alive
+            await self._eval_builder.__aenter__()
+
+            # Populate builder with config (this sets up LLMs, functions, etc.)
+            # Skip workflow build since we already have it from the main builder
+            await self._eval_builder.populate_builder(config, skip_workflow=True)
+
+            # Now evaluators should be populated by populate_builder
+            for name in config.eval.evaluators.keys():
+                self._evaluators[name] = self._eval_builder.get_evaluator(name)
+                logger.info(f"Initialized evaluator: {name}")
+
+            logger.info(f"Successfully initialized {len(self._evaluators)} evaluators")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize evaluators: {e}")
+            # Don't fail startup, just log the error
+            self._evaluators = {}
+
+    async def cleanup_evaluators(self):
+        """Clean up evaluator resources on shutdown."""
+        if self._eval_builder:
+            try:
+                await self._eval_builder.__aexit__(None, None, None)
+                logger.info("Evaluator builder context cleaned up")
+            except Exception as e:
+                logger.error(f"Error cleaning up evaluator builder: {e}")
+            finally:
+                self._eval_builder = None
+                self._evaluators.clear()
+
     def get_step_adaptor(self) -> StepAdaptor:
 
         return StepAdaptor(self.front_end_config.step_adaptor)
@@ -236,12 +289,20 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         # Do things like setting the base URL and global configuration options
         app.root_path = self.front_end_config.root_path
 
+        # Initialize evaluators for single-item evaluation
+        # TODO: we need config control over this as it's not always needed
+        await self.initialize_evaluators(self._config)
+
+        # Ensure evaluator resources are cleaned up when the app shuts down
+        app.add_event_handler("shutdown", self.cleanup_evaluators)
+
         await self.add_routes(app, builder)
 
     async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
 
         await self.add_default_route(app, SessionManager(await builder.build()))
         await self.add_evaluate_route(app, SessionManager(await builder.build()))
+        await self.add_evaluate_item_route(app, SessionManager(await builder.build()))
         await self.add_static_files_route(app, builder)
         await self.add_authorization_route(app)
         await self.add_mcp_client_tool_list_route(app, builder)
@@ -439,6 +500,69 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         else:
             logger.warning("Dask is not available, evaluation endpoints will not be added.")
 
+    async def add_evaluate_item_route(self, app: FastAPI, session_manager: SessionManager):
+        """Add the single-item evaluation endpoint to the FastAPI app."""
+
+        async def evaluate_single_item(request: EvaluateItemRequest, http_request: Request) -> EvaluateItemResponse:
+            """Handle single-item evaluation requests."""
+
+            async with session_manager.session(http_connection=http_request):
+
+                # Check if evaluator exists
+                if request.evaluator_name not in self._evaluators:
+                    raise HTTPException(status_code=404,
+                                        detail=f"Evaluator '{request.evaluator_name}' not found. "
+                                        f"Available evaluators: {list(self._evaluators.keys())}")
+
+                try:
+                    # Get the evaluator
+                    evaluator = self._evaluators[request.evaluator_name]
+
+                    # Run evaluation on single item
+                    result = await evaluator.evaluate_fn(EvalInput(eval_input_items=[request.item]))
+
+                    # Extract the single output item
+                    if result.eval_output_items:
+                        output_item = result.eval_output_items[0]
+                        return EvaluateItemResponse(success=True, result=output_item, error=None)
+                    else:
+                        return EvaluateItemResponse(success=False, result=None, error="Evaluator returned no results")
+
+                except Exception as e:
+                    logger.exception(f"Error evaluating item with {request.evaluator_name}")
+                    return EvaluateItemResponse(success=False, result=None, error=f"Evaluation failed: {str(e)}")
+
+        # Register the route
+        if self.front_end_config.evaluate_item.path:
+            app.add_api_route(path=self.front_end_config.evaluate_item.path,
+                              endpoint=evaluate_single_item,
+                              methods=[self.front_end_config.evaluate_item.method],
+                              response_model=EvaluateItemResponse,
+                              description=self.front_end_config.evaluate_item.description,
+                              responses={
+                                  404: {
+                                      "description": "Evaluator not found",
+                                      "content": {
+                                          "application/json": {
+                                              "example": {
+                                                  "detail": "Evaluator 'unknown' not found"
+                                              }
+                                          }
+                                      }
+                                  },
+                                  500: {
+                                      "description": "Internal Server Error",
+                                      "content": {
+                                          "application/json": {
+                                              "example": {
+                                                  "detail": "Internal server error occurred"
+                                              }
+                                          }
+                                      }
+                                  }
+                              })
+            logger.info(f"Added evaluate_item route at {self.front_end_config.evaluate_item.path}")
+
     async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
 
         if not self.front_end_config.object_store:
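A minimal sketch of how a client might call the new single-item evaluation endpoint added above. The host, port, and route path are assumptions for illustration (the real path and method come from front_end_config.evaluate_item in fastapi_front_end_config.py), and the item payload is a placeholder that must match the EvalInputItem schema; only the evaluator_name/item request fields and the success/result/error response fields are confirmed by this diff.

import requests  # third-party HTTP client, used purely for illustration

# Assumed endpoint URL; the actual default path is defined by the front end config.
url = "http://localhost:8000/evaluate/item"

payload = {
    # Must match a key under eval.evaluators in the workflow config.
    "evaluator_name": "accuracy",
    # Must deserialize into an EvalInputItem; the fields below are placeholders.
    "item": {
        "id": "example-1",
        "input_obj": "What is 2 + 2?",
        "output_obj": "4",
    },
}

resp = requests.post(url, json=payload, timeout=30)
body = resp.json()

if resp.ok and body.get("success"):
    print("evaluation result:", body["result"])
else:
    print("evaluation failed:", body.get("error") or body.get("detail"))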
@@ -544,7 +668,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         GenerateStreamResponseType = workflow.streaming_output_schema
         GenerateSingleResponseType = workflow.single_output_schema
 
-        if self._dask_available:
+        # Skip async generation for custom routes (those with function_name)
+        if self._dask_available and not hasattr(endpoint, 'function_name'):
             # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
             # Consider prefixing these with "nat_" to avoid conflicts
 
@@ -562,6 +687,10 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                     description="Optional time (in seconds) before the job expires. "
                     "Clamped between 600 (10 min) and 86400 (24h).")
 
+                def validate_model(self):
+                    # Override to ensure that the parent class validator is not called
+                    return self
+
             # Ensure that the input is in the body. POD types are treated as query parameters
             if (not issubclass(GenerateBodyType, BaseModel)):
                 GenerateBodyType = typing.Annotated[GenerateBodyType, Body()]
@@ -760,17 +889,18 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 return AsyncGenerateResponse(job_id=job.job_id, status=job.status)
 
             job_id = self._job_store.ensure_job_id(request.job_id)
-            (_, job) = await self._job_store.submit_job(
-                …  (the remaining removed argument lines were not rendered in the source diff)
+            (_, job) = await self._job_store.submit_job(
+                job_id=job_id,
+                expiry_seconds=request.expiry_seconds,
+                job_fn=run_generation,
+                sync_timeout=request.sync_timeout,
+                job_args=[
+                    self._scheduler_address,
+                    self._db_url,
+                    self._config_file_path,
+                    job_id,
+                    request.model_dump(mode="json", exclude=["job_id", "sync_timeout", "expiry_seconds"])
+                ])
 
             if job is not None:
                 response.status_code = 200
@@ -916,7 +1046,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                     responses={500: response_500},
                 )
 
-                if self._dask_available:
+                if self._dask_available and not hasattr(endpoint, 'function_name'):
                     app.add_api_route(
                         path=f"{endpoint.path}/async",
                         endpoint=post_async_generation(request_type=AsyncGenerateRequest),
@@ -930,7 +1060,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             else:
                 raise ValueError(f"Unsupported method {endpoint.method}")
 
-            if self._dask_available:
+            if self._dask_available and not hasattr(endpoint, 'function_name'):
                 app.add_api_route(
                     path=f"{endpoint.path}/async/job/{{job_id}}",
                     endpoint=get_async_job_status,
nat/front_ends/mcp/mcp_front_end_plugin.py
CHANGED

@@ -140,6 +140,10 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
         # Mount the MCP server's ASGI app at the configured base_path
         app.mount(self.front_end_config.base_path, mcp.streamable_http_app())
 
+        # Allow plugins to add routes to the wrapper app (e.g., OAuth discovery endpoints)
+        worker = self._get_worker_instance()
+        await worker.add_root_level_routes(app, mcp)
+
         # Configure and start uvicorn server
         config = uvicorn.Config(
             app,
nat/front_ends/mcp/mcp_front_end_plugin_worker.py
CHANGED

@@ -17,12 +17,16 @@ import logging
 from abc import ABC
 from abc import abstractmethod
 from collections.abc import Mapping
+from typing import TYPE_CHECKING
 from typing import Any
 
 from mcp.server.fastmcp import FastMCP
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 
+if TYPE_CHECKING:
+    from fastapi import FastAPI
+
 from nat.builder.function import Function
 from nat.builder.function_base import FunctionBase
 from nat.builder.workflow import Workflow
@@ -192,6 +196,28 @@ class MCPFrontEndPluginWorkerBase(ABC):
 
         return functions
 
+    async def add_root_level_routes(self, wrapper_app: "FastAPI", mcp: FastMCP) -> None:
+        """Add routes to the wrapper FastAPI app (optional extension point).
+
+        This method is called when base_path is configured and a wrapper
+        FastAPI app is created to mount the MCP server. Plugins can override
+        this to add routes to the wrapper app at the root level, outside the
+        mounted MCP server path.
+
+        Common use cases:
+        - OAuth discovery endpoints (e.g., /.well-known/oauth-protected-resource)
+        - Health checks at root level
+        - Static file serving
+        - Custom authentication/authorization endpoints
+
+        Default implementation does nothing, making this an optional extension point.
+
+        Args:
+            wrapper_app: The FastAPI wrapper application that mounts the MCP server
+            mcp: The FastMCP server instance (already mounted at base_path)
+        """
+        pass  # Default: no additional root-level routes
+
     def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
         """Set up HTTP debug endpoints for introspecting tools and schemas.
 
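A sketch of how a downstream MCP worker might use the new extension point, assuming a subclass of MCPFrontEndPluginWorkerBase from the module shown above; the discovery route and its payload are illustrative only, and the worker's other required (abstract) methods are omitted here.

from fastapi import FastAPI
from mcp.server.fastmcp import FastMCP

# Import path taken from the file shown above; a real subclass would also need
# to implement the base class's abstract methods, which are omitted here.
from nat.front_ends.mcp.mcp_front_end_plugin_worker import MCPFrontEndPluginWorkerBase


class OAuthAwareWorker(MCPFrontEndPluginWorkerBase):
    """Illustrative worker that publishes an OAuth discovery document at the root level."""

    async def add_root_level_routes(self, wrapper_app: FastAPI, mcp: FastMCP) -> None:
        # Served outside the mounted MCP base_path, as described in the docstring above.
        @wrapper_app.get("/.well-known/oauth-protected-resource")
        async def oauth_protected_resource() -> dict:
            return {"resource": "nat-mcp-server"}  # placeholder metadata payload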
nat/llm/aws_bedrock_llm.py
CHANGED

@@ -25,18 +25,10 @@ from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
 from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
-class AWSBedrockModelConfig(LLMBaseConfig,
-                            RetryMixin,
-                            OptimizableMixin,
-                            TemperatureMixin,
-                            TopPMixin,
-                            ThinkingMixin,
-                            name="aws_bedrock"):
+class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="aws_bedrock"):
     """An AWS Bedrock llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -61,6 +53,16 @@ class AWSBedrockModelConfig(LLMBaseConfig,
         default=None, description="Bedrock endpoint to use. Needed if you don't want to default to us-east-1 endpoint.")
     credentials_profile_name: str | None = Field(
         default=None, description="The name of the profile in the ~/.aws/credentials or ~/.aws/config files.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=AWSBedrockModelConfig)
nat/llm/azure_openai_llm.py
CHANGED

@@ -22,17 +22,15 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.common import OptionalSecretStr
 from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
 class AzureOpenAIModelConfig(
     LLMBaseConfig,
     RetryMixin,
-    TemperatureMixin,
-    TopPMixin,
     ThinkingMixin,
     name="azure_openai",
 ):
@@ -50,6 +48,16 @@ class AzureOpenAIModelConfig(
         serialization_alias="azure_deployment",
         description="The Azure OpenAI hosted model/deployment name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=AzureOpenAIModelConfig)
nat/llm/litellm_llm.py
CHANGED

@@ -26,18 +26,15 @@ from nat.data_models.common import OptionalSecretStr
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
 class LiteLlmModelConfig(
     LLMBaseConfig,
     OptimizableMixin,
     RetryMixin,
-    TemperatureMixin,
-    TopPMixin,
     ThinkingMixin,
     name="litellm",
 ):
@@ -54,6 +51,16 @@ class LiteLlmModelConfig(
         serialization_alias="model",
         description="The LiteLlm hosted model name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=LiteLlmModelConfig)
nat/llm/nim_llm.py
CHANGED

@@ -27,18 +27,10 @@ from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
 from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
-class NIMModelConfig(LLMBaseConfig,
-                     RetryMixin,
-                     OptimizableMixin,
-                     TemperatureMixin,
-                     TopPMixin,
-                     ThinkingMixin,
-                     name="nim"):
+class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="nim"):
     """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -51,6 +43,16 @@ class NIMModelConfig(LLMBaseConfig,
     max_tokens: PositiveInt = OptimizableField(default=300,
                                                description="Maximum number of tokens to generate.",
                                                space=SearchSpace(high=2176, low=128, step=512))
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=NIMModelConfig)
nat/llm/openai_llm.py
CHANGED

@@ -24,19 +24,12 @@ from nat.data_models.common import OptionalSecretStr
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
-class OpenAIModelConfig(LLMBaseConfig,
-                        RetryMixin,
-                        OptimizableMixin,
-                        TemperatureMixin,
-                        TopPMixin,
-                        ThinkingMixin,
-                        name="openai"):
+class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="openai"):
     """An OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -48,6 +41,16 @@ class OpenAIModelConfig(LLMBaseConfig,
         description="The OpenAI hosted model name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
     max_retries: int = Field(default=10, description="The max number of retries for the request.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=OpenAIModelConfig)
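Across all five providers above, temperature and top_p are now declared directly on each config as OptimizableField entries (with SearchSpace bounds) instead of being inherited from the removed TemperatureMixin/TopPMixin, so they are set like any other config field. A minimal sketch follows, assuming model_name is the field referenced by the surrounding hunk and that no other fields are required to construct the model; the model string is purely illustrative.

from nat.llm.openai_llm import OpenAIModelConfig

# temperature and top_p are ordinary (optimizable) config fields in this release;
# both default to None when not supplied.
cfg = OpenAIModelConfig(model_name="gpt-4o-mini", temperature=0.2, top_p=0.9)

print(cfg.temperature, cfg.top_p)  # 0.2 0.9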
nat/middleware/__init__.py
ADDED

@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Middleware implementations for NeMo Agent Toolkit."""
+
+from nat.middleware.cache_middleware import CacheMiddleware
+from nat.middleware.function_middleware import FunctionMiddleware
+from nat.middleware.function_middleware import FunctionMiddlewareChain
+from nat.middleware.function_middleware import validate_middleware
+from nat.middleware.middleware import CallNext
+from nat.middleware.middleware import CallNextStream
+from nat.middleware.middleware import FunctionMiddlewareContext
+from nat.middleware.middleware import Middleware
+
+__all__ = [
+    "CacheMiddleware",
+    "CallNext",
+    "CallNextStream",
+    "FunctionMiddlewareContext",
+    "Middleware",
+    "FunctionMiddleware",
+    "FunctionMiddlewareChain",
+    "validate_middleware",
+]
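The new package re-exports its public surface from nat.middleware. The sketch below only demonstrates the import surface named in __all__; the concrete class and callable interfaces live in the other modules added in this release (cache_middleware.py, function_middleware.py, middleware.py) and are not shown in this hunk, so no usage beyond the imports is assumed.

# All eight names are re-exported by the new nat/middleware/__init__.py shown above.
from nat.middleware import (
    CacheMiddleware,
    CallNext,
    CallNextStream,
    FunctionMiddleware,
    FunctionMiddlewareChain,
    FunctionMiddlewareContext,
    Middleware,
    validate_middleware,
)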