nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
Files changed (57)
  1. nat/builder/builder.py +52 -0
  2. nat/builder/component_utils.py +7 -1
  3. nat/builder/context.py +17 -0
  4. nat/builder/framework_enum.py +1 -0
  5. nat/builder/function.py +74 -3
  6. nat/builder/workflow.py +4 -2
  7. nat/builder/workflow_builder.py +129 -0
  8. nat/cli/commands/workflow/workflow_commands.py +3 -2
  9. nat/cli/register_workflow.py +50 -0
  10. nat/cli/type_registry.py +68 -0
  11. nat/data_models/component.py +2 -0
  12. nat/data_models/component_ref.py +11 -0
  13. nat/data_models/config.py +16 -0
  14. nat/data_models/function.py +14 -1
  15. nat/data_models/middleware.py +35 -0
  16. nat/data_models/runtime_enum.py +26 -0
  17. nat/eval/dataset_handler/dataset_filter.py +34 -2
  18. nat/eval/evaluate.py +11 -3
  19. nat/eval/utils/weave_eval.py +17 -3
  20. nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
  21. nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
  22. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
  23. nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
  24. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
  25. nat/llm/aws_bedrock_llm.py +11 -9
  26. nat/llm/azure_openai_llm.py +12 -4
  27. nat/llm/litellm_llm.py +11 -4
  28. nat/llm/nim_llm.py +11 -9
  29. nat/llm/openai_llm.py +12 -9
  30. nat/middleware/__init__.py +35 -0
  31. nat/middleware/cache_middleware.py +256 -0
  32. nat/middleware/function_middleware.py +186 -0
  33. nat/middleware/middleware.py +184 -0
  34. nat/middleware/register.py +35 -0
  35. nat/profiler/decorators/framework_wrapper.py +16 -0
  36. nat/retriever/milvus/register.py +11 -3
  37. nat/retriever/milvus/retriever.py +102 -40
  38. nat/runtime/runner.py +12 -1
  39. nat/runtime/session.py +10 -3
  40. nat/tool/code_execution/code_sandbox.py +4 -7
  41. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  42. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
  43. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  44. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  45. nat/tool/server_tools.py +15 -2
  46. nat/utils/__init__.py +8 -4
  47. nat/utils/io/yaml_tools.py +73 -3
  48. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
  49. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
  50. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
  51. nat/data_models/temperature_mixin.py +0 -44
  52. nat/data_models/top_p_mixin.py +0 -44
  53. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  54. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
  55. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  56. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
  57. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0

nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py CHANGED
@@ -39,6 +39,8 @@ from pydantic import BaseModel
 from pydantic import Field
 from starlette.websockets import WebSocket
 
+from nat.builder.eval_builder import WorkflowEvalBuilder
+from nat.builder.evaluator import EvaluatorInfo
 from nat.builder.function import Function
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.api_server import ChatRequest
@@ -51,11 +53,14 @@ from nat.data_models.object_store import NoSuchKeyError
 from nat.eval.config import EvaluationRunOutput
 from nat.eval.evaluate import EvaluationRun
 from nat.eval.evaluate import EvaluationRunConfig
+from nat.eval.evaluator.evaluator_model import EvalInput
 from nat.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
 from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
 from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
 from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerationStatusResponse
+from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemRequest
+from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
@@ -227,6 +232,54 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         self._outstanding_flows: dict[str, FlowState] = {}
         self._outstanding_flows_lock = asyncio.Lock()
 
+        # Evaluator storage for single-item evaluation
+        self._evaluators: dict[str, EvaluatorInfo] = {}
+        self._eval_builder: WorkflowEvalBuilder | None = None
+
+    async def initialize_evaluators(self, config: Config):
+        """Initialize and store evaluators from config for single-item evaluation."""
+        if not config.eval or not config.eval.evaluators:
+            logger.info("No evaluators configured, skipping evaluator initialization")
+            return
+
+        try:
+            # Build evaluators using WorkflowEvalBuilder (same pattern as nat eval)
+            # Start with registry=None and let populate_builder set everything up
+            self._eval_builder = WorkflowEvalBuilder(general_config=config.general,
+                                                     eval_general_config=config.eval.general,
+                                                     registry=None)
+
+            # Enter the async context and keep it alive
+            await self._eval_builder.__aenter__()
+
+            # Populate builder with config (this sets up LLMs, functions, etc.)
+            # Skip workflow build since we already have it from the main builder
+            await self._eval_builder.populate_builder(config, skip_workflow=True)
+
+            # Now evaluators should be populated by populate_builder
+            for name in config.eval.evaluators.keys():
+                self._evaluators[name] = self._eval_builder.get_evaluator(name)
+                logger.info(f"Initialized evaluator: {name}")
+
+            logger.info(f"Successfully initialized {len(self._evaluators)} evaluators")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize evaluators: {e}")
+            # Don't fail startup, just log the error
+            self._evaluators = {}
+
+    async def cleanup_evaluators(self):
+        """Clean up evaluator resources on shutdown."""
+        if self._eval_builder:
+            try:
+                await self._eval_builder.__aexit__(None, None, None)
+                logger.info("Evaluator builder context cleaned up")
+            except Exception as e:
+                logger.error(f"Error cleaning up evaluator builder: {e}")
+            finally:
+                self._eval_builder = None
+                self._evaluators.clear()
+
     def get_step_adaptor(self) -> StepAdaptor:
 
         return StepAdaptor(self.front_end_config.step_adaptor)
@@ -236,12 +289,20 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         # Do things like setting the base URL and global configuration options
         app.root_path = self.front_end_config.root_path
 
+        # Initialize evaluators for single-item evaluation
+        # TODO: we need config control over this as it's not always needed
+        await self.initialize_evaluators(self._config)
+
+        # Ensure evaluator resources are cleaned up when the app shuts down
+        app.add_event_handler("shutdown", self.cleanup_evaluators)
+
         await self.add_routes(app, builder)
 
     async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
 
         await self.add_default_route(app, SessionManager(await builder.build()))
         await self.add_evaluate_route(app, SessionManager(await builder.build()))
+        await self.add_evaluate_item_route(app, SessionManager(await builder.build()))
         await self.add_static_files_route(app, builder)
         await self.add_authorization_route(app)
         await self.add_mcp_client_tool_list_route(app, builder)
@@ -439,6 +500,69 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         else:
             logger.warning("Dask is not available, evaluation endpoints will not be added.")
 
+    async def add_evaluate_item_route(self, app: FastAPI, session_manager: SessionManager):
+        """Add the single-item evaluation endpoint to the FastAPI app."""
+
+        async def evaluate_single_item(request: EvaluateItemRequest, http_request: Request) -> EvaluateItemResponse:
+            """Handle single-item evaluation requests."""
+
+            async with session_manager.session(http_connection=http_request):
+
+                # Check if evaluator exists
+                if request.evaluator_name not in self._evaluators:
+                    raise HTTPException(status_code=404,
+                                        detail=f"Evaluator '{request.evaluator_name}' not found. "
+                                        f"Available evaluators: {list(self._evaluators.keys())}")
+
+                try:
+                    # Get the evaluator
+                    evaluator = self._evaluators[request.evaluator_name]
+
+                    # Run evaluation on single item
+                    result = await evaluator.evaluate_fn(EvalInput(eval_input_items=[request.item]))
+
+                    # Extract the single output item
+                    if result.eval_output_items:
+                        output_item = result.eval_output_items[0]
+                        return EvaluateItemResponse(success=True, result=output_item, error=None)
+                    else:
+                        return EvaluateItemResponse(success=False, result=None, error="Evaluator returned no results")
+
+                except Exception as e:
+                    logger.exception(f"Error evaluating item with {request.evaluator_name}")
+                    return EvaluateItemResponse(success=False, result=None, error=f"Evaluation failed: {str(e)}")
+
+        # Register the route
+        if self.front_end_config.evaluate_item.path:
+            app.add_api_route(path=self.front_end_config.evaluate_item.path,
+                              endpoint=evaluate_single_item,
+                              methods=[self.front_end_config.evaluate_item.method],
+                              response_model=EvaluateItemResponse,
+                              description=self.front_end_config.evaluate_item.description,
+                              responses={
+                                  404: {
+                                      "description": "Evaluator not found",
+                                      "content": {
+                                          "application/json": {
+                                              "example": {
+                                                  "detail": "Evaluator 'unknown' not found"
+                                              }
+                                          }
+                                      }
+                                  },
+                                  500: {
+                                      "description": "Internal Server Error",
+                                      "content": {
+                                          "application/json": {
+                                              "example": {
+                                                  "detail": "Internal server error occurred"
+                                              }
+                                          }
+                                      }
+                                  }
+                              })
+            logger.info(f"Added evaluate_item route at {self.front_end_config.evaluate_item.path}")
+
     async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
 
         if not self.front_end_config.object_store:
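For context, a minimal client sketch against the new single-item endpoint. The route path and HTTP method come from front_end_config.evaluate_item, whose defaults are not shown in this diff, and the item payload is an EvalInputItem whose schema also lives outside this diff — so the "/evaluate/item" path and the item fields below are illustrative assumptions; only evaluator_name, item, and the success/result/error response fields are taken from the handler above.

import httpx

# Hypothetical path; the real one is front_end_config.evaluate_item.path
resp = httpx.post(
    "http://localhost:8000/evaluate/item",
    json={
        "evaluator_name": "answer_accuracy",  # must match a key under eval.evaluators in the config
        "item": {  # assumed EvalInputItem fields; check fastapi_front_end_config.py for the real schema
            "id": "1",
            "input_obj": "What toolkit is this?",
            "expected_output_obj": "NeMo Agent Toolkit",
            "output_obj": "The NeMo Agent Toolkit",
        },
    },
)
body = resp.json()  # EvaluateItemResponse: success / result / error
print(body.get("success"), body.get("error"))  # an unknown evaluator name returns HTTP 404 instead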
@@ -544,7 +668,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         GenerateStreamResponseType = workflow.streaming_output_schema
         GenerateSingleResponseType = workflow.single_output_schema
 
-        if self._dask_available:
+        # Skip async generation for custom routes (those with function_name)
+        if self._dask_available and not hasattr(endpoint, 'function_name'):
             # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
             # Consider prefixing these with "nat_" to avoid conflicts
 
@@ -562,6 +687,10 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 description="Optional time (in seconds) before the job expires. "
                 "Clamped between 600 (10 min) and 86400 (24h).")
 
+            def validate_model(self):
+                # Override to ensure that the parent class validator is not called
+                return self
+
         # Ensure that the input is in the body. POD types are treated as query parameters
         if (not issubclass(GenerateBodyType, BaseModel)):
             GenerateBodyType = typing.Annotated[GenerateBodyType, Body()]
@@ -760,17 +889,18 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 return AsyncGenerateResponse(job_id=job.job_id, status=job.status)
 
             job_id = self._job_store.ensure_job_id(request.job_id)
-            (_, job) = await self._job_store.submit_job(job_id=job_id,
-                                                        expiry_seconds=request.expiry_seconds,
-                                                        job_fn=run_generation,
-                                                        sync_timeout=request.sync_timeout,
-                                                        job_args=[
-                                                            self._scheduler_address,
-                                                            self._db_url,
-                                                            self._config_file_path,
-                                                            job_id,
-                                                            request.model_dump(mode="json")
-                                                        ])
+            (_, job) = await self._job_store.submit_job(
+                job_id=job_id,
+                expiry_seconds=request.expiry_seconds,
+                job_fn=run_generation,
+                sync_timeout=request.sync_timeout,
+                job_args=[
+                    self._scheduler_address,
+                    self._db_url,
+                    self._config_file_path,
+                    job_id,
+                    request.model_dump(mode="json", exclude=["job_id", "sync_timeout", "expiry_seconds"])
+                ])
 
             if job is not None:
                 response.status_code = 200
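The added exclude keeps the job-control keys out of the payload handed to the Dask worker: job_id, sync_timeout, and expiry_seconds were appended to the input schema as reserved keywords, so previously they leaked into the workflow input. A toy pydantic sketch of the model_dump behavior (the model and its fields are invented stand-ins for the generated request type):

from pydantic import BaseModel


class AsyncRequest(BaseModel):
    # Hypothetical stand-in for the generated async-generation request model
    job_id: str | None = None
    sync_timeout: int = 0
    expiry_seconds: int = 600
    prompt: str = ""


req = AsyncRequest(job_id="job-1", prompt="hello")
# The reserved job-control fields are stripped before the dump reaches run_generation
print(req.model_dump(mode="json", exclude={"job_id", "sync_timeout", "expiry_seconds"}))
# -> {'prompt': 'hello'}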
@@ -916,7 +1046,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             responses={500: response_500},
         )
 
-        if self._dask_available:
+        if self._dask_available and not hasattr(endpoint, 'function_name'):
             app.add_api_route(
                 path=f"{endpoint.path}/async",
                 endpoint=post_async_generation(request_type=AsyncGenerateRequest),
@@ -930,7 +1060,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         else:
             raise ValueError(f"Unsupported method {endpoint.method}")
 
-        if self._dask_available:
+        if self._dask_available and not hasattr(endpoint, 'function_name'):
             app.add_api_route(
                 path=f"{endpoint.path}/async/job/{{job_id}}",
                 endpoint=get_async_job_status,

nat/front_ends/mcp/mcp_front_end_plugin.py CHANGED
@@ -140,6 +140,10 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
         # Mount the MCP server's ASGI app at the configured base_path
         app.mount(self.front_end_config.base_path, mcp.streamable_http_app())
 
+        # Allow plugins to add routes to the wrapper app (e.g., OAuth discovery endpoints)
+        worker = self._get_worker_instance()
+        await worker.add_root_level_routes(app, mcp)
+
         # Configure and start uvicorn server
         config = uvicorn.Config(
             app,

nat/front_ends/mcp/mcp_front_end_plugin_worker.py CHANGED
@@ -17,12 +17,16 @@ import logging
 from abc import ABC
 from abc import abstractmethod
 from collections.abc import Mapping
+from typing import TYPE_CHECKING
 from typing import Any
 
 from mcp.server.fastmcp import FastMCP
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 
+if TYPE_CHECKING:
+    from fastapi import FastAPI
+
 from nat.builder.function import Function
 from nat.builder.function_base import FunctionBase
 from nat.builder.workflow import Workflow
@@ -192,6 +196,28 @@ class MCPFrontEndPluginWorkerBase(ABC):
 
         return functions
 
+    async def add_root_level_routes(self, wrapper_app: "FastAPI", mcp: FastMCP) -> None:
+        """Add routes to the wrapper FastAPI app (optional extension point).
+
+        This method is called when base_path is configured and a wrapper
+        FastAPI app is created to mount the MCP server. Plugins can override
+        this to add routes to the wrapper app at the root level, outside the
+        mounted MCP server path.
+
+        Common use cases:
+        - OAuth discovery endpoints (e.g., /.well-known/oauth-protected-resource)
+        - Health checks at root level
+        - Static file serving
+        - Custom authentication/authorization endpoints
+
+        Default implementation does nothing, making this an optional extension point.
+
+        Args:
+            wrapper_app: The FastAPI wrapper application that mounts the MCP server
+            mcp: The FastMCP server instance (already mounted at base_path)
+        """
+        pass  # Default: no additional root-level routes
+
     def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
         """Set up HTTP debug endpoints for introspecting tools and schemas.

nat/llm/aws_bedrock_llm.py CHANGED
@@ -25,18 +25,10 @@ from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
 from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
-class AWSBedrockModelConfig(LLMBaseConfig,
-                            RetryMixin,
-                            OptimizableMixin,
-                            TemperatureMixin,
-                            TopPMixin,
-                            ThinkingMixin,
-                            name="aws_bedrock"):
+class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="aws_bedrock"):
     """An AWS Bedrock llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -61,6 +53,16 @@ class AWSBedrockModelConfig(LLMBaseConfig,
         default=None, description="Bedrock endpoint to use. Needed if you don't want to default to us-east-1 endpoint.")
     credentials_profile_name: str | None = Field(
         default=None, description="The name of the profile in the ~/.aws/credentials or ~/.aws/config files.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=AWSBedrockModelConfig)

nat/llm/azure_openai_llm.py CHANGED
@@ -22,17 +22,15 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.common import OptionalSecretStr
 from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
 class AzureOpenAIModelConfig(
     LLMBaseConfig,
     RetryMixin,
-    TemperatureMixin,
-    TopPMixin,
     ThinkingMixin,
     name="azure_openai",
 ):
@@ -50,6 +48,16 @@ class AzureOpenAIModelConfig(
         serialization_alias="azure_deployment",
         description="The Azure OpenAI hosted model/deployment name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=AzureOpenAIModelConfig)
nat/llm/litellm_llm.py CHANGED
@@ -26,18 +26,15 @@ from nat.data_models.common import OptionalSecretStr
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
 class LiteLlmModelConfig(
     LLMBaseConfig,
     OptimizableMixin,
     RetryMixin,
-    TemperatureMixin,
-    TopPMixin,
     ThinkingMixin,
     name="litellm",
 ):
@@ -54,6 +51,16 @@ class LiteLlmModelConfig(
         serialization_alias="model",
         description="The LiteLlm hosted model name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=LiteLlmModelConfig)
nat/llm/nim_llm.py CHANGED
@@ -27,18 +27,10 @@ from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
 from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
-class NIMModelConfig(LLMBaseConfig,
-                     RetryMixin,
-                     OptimizableMixin,
-                     TemperatureMixin,
-                     TopPMixin,
-                     ThinkingMixin,
-                     name="nim"):
+class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="nim"):
     """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -51,6 +43,16 @@ class NIMModelConfig(LLMBaseConfig,
     max_tokens: PositiveInt = OptimizableField(default=300,
                                                description="Maximum number of tokens to generate.",
                                                space=SearchSpace(high=2176, low=128, step=512))
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=NIMModelConfig)
nat/llm/openai_llm.py CHANGED
@@ -24,19 +24,12 @@ from nat.data_models.common import OptionalSecretStr
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.optimizable import OptimizableField
 from nat.data_models.optimizable import OptimizableMixin
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
-from nat.data_models.temperature_mixin import TemperatureMixin
 from nat.data_models.thinking_mixin import ThinkingMixin
-from nat.data_models.top_p_mixin import TopPMixin
 
 
-class OpenAIModelConfig(LLMBaseConfig,
-                        RetryMixin,
-                        OptimizableMixin,
-                        TemperatureMixin,
-                        TopPMixin,
-                        ThinkingMixin,
-                        name="openai"):
+class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="openai"):
     """An OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -48,6 +41,16 @@ class OpenAIModelConfig(LLMBaseConfig,
         description="The OpenAI hosted model name.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
     max_retries: int = Field(default=10, description="The max number of retries for the request.")
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        description="Sampling temperature to control randomness in the output.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
+    top_p: float | None = OptimizableField(default=None,
+                                           ge=0.0,
+                                           le=1.0,
+                                           description="Top-p for distribution sampling.",
+                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
 @register_llm_provider(config_type=OpenAIModelConfig)
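With TemperatureMixin and TopPMixin deleted (items 51 and 52 in the file list), each provider config above now declares temperature and top_p directly as OptimizableField entries with a SearchSpace, so both stay user-settable and visible to the optimizer. A quick sketch against the new OpenAI config; the constructor values are illustrative, and model_name is assumed to be the field behind "The OpenAI hosted model name." (its declaration line is not shown in this hunk):

from nat.llm.openai_llm import OpenAIModelConfig

# temperature/top_p now live on the config class itself rather than on shared mixins
cfg = OpenAIModelConfig(model_name="gpt-4o-mini", temperature=0.7, top_p=0.9)
print(cfg.temperature, cfg.top_p)  # 0.7 0.9

# Both default to None, so omitting them defers to the provider's own defaults
print(OpenAIModelConfig(model_name="gpt-4o-mini").temperature)  # None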

nat/middleware/__init__.py ADDED
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Middleware implementations for NeMo Agent Toolkit."""
+
+from nat.middleware.cache_middleware import CacheMiddleware
+from nat.middleware.function_middleware import FunctionMiddleware
+from nat.middleware.function_middleware import FunctionMiddlewareChain
+from nat.middleware.function_middleware import validate_middleware
+from nat.middleware.middleware import CallNext
+from nat.middleware.middleware import CallNextStream
+from nat.middleware.middleware import FunctionMiddlewareContext
+from nat.middleware.middleware import Middleware
+
+__all__ = [
+    "CacheMiddleware",
+    "CallNext",
+    "CallNextStream",
+    "FunctionMiddlewareContext",
+    "Middleware",
+    "FunctionMiddleware",
+    "FunctionMiddlewareChain",
+    "validate_middleware",
+]
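This new package only re-exports names here; the implementations live in cache_middleware.py, function_middleware.py, and middleware.py, which are also new in this release (files 31-33 in the list above) but whose internals are not shown in this hunk. A smoke check of the public surface, assuming the wheel is installed — it verifies only that the re-exports resolve, since the middleware call protocol itself is not visible in this diff:

from nat.middleware import (CacheMiddleware, CallNext, CallNextStream, FunctionMiddleware,
                            FunctionMiddlewareChain, FunctionMiddlewareContext, Middleware,
                            validate_middleware)

# Confirm the names resolve to their defining modules
print(CacheMiddleware.__module__)       # nat.middleware.cache_middleware
print(Middleware.__module__)            # nat.middleware.middleware
print(callable(validate_middleware))    # True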