nvidia-nat 1.3.0a20250822__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +0 -1
  3. nat/agent/react_agent/agent.py +21 -3
  4. nat/agent/react_agent/register.py +1 -1
  5. nat/agent/register.py +0 -1
  6. nat/agent/rewoo_agent/agent.py +0 -1
  7. nat/agent/rewoo_agent/register.py +1 -1
  8. nat/agent/tool_calling_agent/agent.py +0 -1
  9. nat/agent/tool_calling_agent/register.py +1 -1
  10. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  11. nat/authentication/register.py +0 -1
  12. nat/builder/builder.py +1 -1
  13. nat/builder/context.py +9 -1
  14. nat/builder/function_base.py +3 -3
  15. nat/builder/function_info.py +5 -7
  16. nat/builder/workflow_builder.py +0 -1
  17. nat/cli/commands/evaluate.py +1 -1
  18. nat/cli/commands/info/list_components.py +7 -8
  19. nat/cli/commands/info/list_mcp.py +3 -4
  20. nat/cli/commands/registry/search.py +14 -16
  21. nat/cli/commands/start.py +0 -1
  22. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  23. nat/cli/commands/workflow/workflow_commands.py +0 -1
  24. nat/cli/type_registry.py +3 -5
  25. nat/data_models/config.py +1 -1
  26. nat/data_models/evaluate.py +1 -1
  27. nat/data_models/function_dependencies.py +6 -6
  28. nat/data_models/intermediate_step.py +3 -3
  29. nat/data_models/model_gated_field_mixin.py +125 -0
  30. nat/data_models/swe_bench_model.py +1 -1
  31. nat/data_models/temperature_mixin.py +36 -0
  32. nat/data_models/top_p_mixin.py +36 -0
  33. nat/embedder/register.py +0 -1
  34. nat/eval/dataset_handler/dataset_handler.py +5 -6
  35. nat/eval/evaluate.py +7 -8
  36. nat/eval/rag_evaluator/register.py +2 -2
  37. nat/eval/register.py +0 -1
  38. nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
  39. nat/eval/utils/weave_eval.py +3 -3
  40. nat/experimental/test_time_compute/models/strategy_base.py +3 -2
  41. nat/experimental/test_time_compute/register.py +0 -1
  42. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
  43. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
  44. nat/front_ends/fastapi/message_handler.py +13 -14
  45. nat/front_ends/fastapi/message_validator.py +4 -4
  46. nat/front_ends/fastapi/step_adaptor.py +1 -1
  47. nat/front_ends/register.py +0 -1
  48. nat/llm/aws_bedrock_llm.py +3 -3
  49. nat/llm/azure_openai_llm.py +3 -4
  50. nat/llm/nim_llm.py +4 -4
  51. nat/llm/openai_llm.py +4 -4
  52. nat/llm/register.py +0 -1
  53. nat/llm/utils/env_config_value.py +2 -3
  54. nat/object_store/register.py +0 -1
  55. nat/observability/exporter/file_exporter.py +1 -1
  56. nat/observability/register.py +3 -3
  57. nat/profiler/callbacks/langchain_callback_handler.py +1 -1
  58. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  59. nat/profiler/data_frame_row.py +1 -1
  60. nat/profiler/decorators/framework_wrapper.py +1 -4
  61. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  62. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  63. nat/profiler/inference_optimization/data_models.py +3 -3
  64. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  65. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  66. nat/profiler/profile_runner.py +13 -8
  67. nat/registry_handlers/package_utils.py +0 -1
  68. nat/registry_handlers/pypi/pypi_handler.py +20 -23
  69. nat/registry_handlers/register.py +3 -4
  70. nat/registry_handlers/rest/rest_handler.py +8 -9
  71. nat/retriever/register.py +0 -1
  72. nat/runtime/session.py +23 -8
  73. nat/settings/global_settings.py +0 -1
  74. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  75. nat/tool/document_search.py +1 -1
  76. nat/tool/mcp/mcp_tool.py +1 -1
  77. nat/tool/register.py +0 -1
  78. nat/utils/data_models/schema_validator.py +2 -2
  79. nat/utils/exception_handlers/automatic_retries.py +0 -2
  80. nat/utils/exception_handlers/schemas.py +1 -1
  81. nat/utils/reactive/base/observable_base.py +2 -2
  82. nat/utils/reactive/base/observer_base.py +1 -1
  83. nat/utils/reactive/observable.py +2 -2
  84. nat/utils/reactive/observer.py +2 -2
  85. nat/utils/reactive/subscription.py +1 -1
  86. nat/utils/settings/global_settings.py +4 -6
  87. nat/utils/type_utils.py +4 -4
  88. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +1 -1
  89. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +94 -91
  90. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
  91. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
  92. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  93. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
  94. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +0 -0
nat/data_models/temperature_mixin.py ADDED
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TEMPERATURE_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TemperatureMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="temperature",
+        default_if_supported=0.0,
+        unsupported_models=_UNSUPPORTED_TEMPERATURE_MODELS,
+):
+    """
+    Mixin class for temperature configuration.
+    """
+    temperature: float | None = Field(default=None, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
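Note: the body of `ModelGatedFieldMixin` (added in `nat/data_models/model_gated_field_mixin.py`, +125 lines) is not part of this extract, so the gating semantics have to be inferred from the class keywords above. A minimal standalone sketch of the pattern, assuming the mixin fills `default_if_supported` for supported models and rejects explicit values for unsupported ones; `ExampleConfig` and `_gate_temperature` are illustrative, not package code:

```python
import re

from pydantic import BaseModel, Field, model_validator

# Same pattern the mixin is configured with above.
_UNSUPPORTED = (re.compile(r"gpt-?5", re.IGNORECASE), )


class ExampleConfig(BaseModel):
    """Illustrative stand-in for a config that mixes in TemperatureMixin."""
    model_name: str
    temperature: float | None = Field(default=None, ge=0.0, le=1.0)

    @model_validator(mode="after")
    def _gate_temperature(self):
        supported = not any(p.search(self.model_name) for p in _UNSUPPORTED)
        if supported and self.temperature is None:
            self.temperature = 0.0  # mirrors default_if_supported=0.0
        elif not supported and self.temperature is not None:
            # Whether the real mixin raises or silently drops the value is
            # not visible in this diff; this sketch raises.
            raise ValueError(f"temperature is not supported by {self.model_name}")
        return self


print(ExampleConfig(model_name="gpt-4o").temperature)  # 0.0
print(ExampleConfig(model_name="gpt-5").temperature)   # None
```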
nat/data_models/top_p_mixin.py ADDED
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TOP_P_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TopPMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="top_p",
+        default_if_supported=1.0,
+        unsupported_models=_UNSUPPORTED_TOP_P_MODELS,
+):
+    """
+    Mixin class for top-p configuration.
+    """
+    top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Top-p for distribution sampling.")
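`TopPMixin` is the same gate with a different field name and default. Both files share the gpt-5 pattern; assuming the mixin applies it with `re.search` (again, the implementation is not in this extract), deployment-style names that merely contain the substring would also match:

```python
import re

_UNSUPPORTED = re.compile(r"gpt-?5", re.IGNORECASE)

# Case-insensitive, optional hyphen, matched anywhere in the name:
for name in ("gpt-5", "GPT5-mini", "azure/gpt-5-chat", "gpt-4o"):
    print(name, bool(_UNSUPPORTED.search(name)))
# gpt-5 True, GPT5-mini True, azure/gpt-5-chat True, gpt-4o False
```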
nat/embedder/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file
 
nat/eval/dataset_handler/dataset_handler.py CHANGED
@@ -146,13 +146,12 @@ class DatasetHandler:
            # When num_passes is specified, always use concurrency * num_passes
            # This respects the user's intent for exact number of passes
            target_size = self.concurrency * self.num_passes
+        # When num_passes = 0, use the largest multiple of concurrency <= original_size
+        # If original_size < concurrency, we need at least concurrency rows
+        elif original_size >= self.concurrency:
+            target_size = (original_size // self.concurrency) * self.concurrency
         else:
-            # When num_passes = 0, use the largest multiple of concurrency <= original_size
-            # If original_size < concurrency, we need at least concurrency rows
-            if original_size >= self.concurrency:
-                target_size = (original_size // self.concurrency) * self.concurrency
-            else:
-                target_size = self.concurrency
+            target_size = self.concurrency
 
         if target_size == 0:
             raise ValueError("Input dataset too small for even one batch at given concurrency.")
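The refactor flattens a nested if/else into an elif chain; the arithmetic is unchanged. A worked example with illustrative values:

```python
concurrency, num_passes, original_size = 8, 0, 20

if num_passes:
    target_size = concurrency * num_passes
elif original_size >= concurrency:
    # Largest multiple of concurrency that fits: (20 // 8) * 8 == 16
    target_size = (original_size // concurrency) * concurrency
else:
    target_size = concurrency

print(target_size)  # 16
```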
nat/eval/evaluate.py CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
 logger = logging.getLogger(__name__)
 
 
-class EvaluationRun:  # pylint: disable=too-many-public-methods
+class EvaluationRun:
     """
     Instantiated for each evaluation run and used to store data for that single run.
 
@@ -319,7 +319,7 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
         except Exception as e:
             logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
 
-    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):  # pylint: disable=unused-argument # noqa: E501
+    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
         workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
         workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
 
@@ -511,12 +511,11 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
         with self.eval_trace_context.evaluation_context():
             if self.config.endpoint:
                 await self.run_workflow_remote()
-            else:
-                if not self.config.skip_workflow:
-                    if session_manager is None:
-                        session_manager = SessionManager(eval_workflow.build(),
-                                                         max_concurrency=self.eval_config.general.max_concurrency)
-                    await self.run_workflow_local(session_manager)
+            elif not self.config.skip_workflow:
+                if session_manager is None:
+                    session_manager = SessionManager(eval_workflow.build(),
+                                                     max_concurrency=self.eval_config.general.max_concurrency)
+                await self.run_workflow_local(session_manager)
 
             # Evaluate
             evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
nat/eval/rag_evaluator/register.py CHANGED
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return self.metric
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.keys()))  # pylint: disable=no-member
+            return next(iter(self.metric.keys()))
         return ""
 
     @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return RagasMetricConfig()  # Default config when only a metric name is given
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.values()))  # pylint: disable=no-member
+            return next(iter(self.metric.values()))
         return RagasMetricConfig()  # Default config when an invalid type is provided
 
 
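For context, `metric` accepts either a bare metric name or a one-entry mapping of name to config, and these properties normalize both shapes to the first entry. A small self-contained sketch (the metric name and config keys are illustrative):

```python
def metric_name(metric) -> str:
    """Mirrors the property above: first key of a dict, the string itself otherwise."""
    if isinstance(metric, str):
        return metric
    if isinstance(metric, dict) and metric:
        return next(iter(metric.keys()))
    return ""


assert metric_name("AnswerAccuracy") == "AnswerAccuracy"
assert metric_name({"AnswerAccuracy": {"weight": 1.0}}) == "AnswerAccuracy"
assert metric_name({}) == ""
```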
nat/eval/register.py CHANGED
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 # flake8: noqa
-# pylint: disable=unused-import
 
 # Import evaluators which need to be automatically registered here
 from .rag_evaluator.register import register_ragas_evaluator
nat/eval/tunable_rag_evaluator/evaluate.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import logging
 from typing import Callable
 
@@ -23,7 +22,6 @@ from langchain.schema import HumanMessage
 from langchain.schema import SystemMessage
 from langchain_core.language_models import BaseChatModel
 from langchain_core.runnables import RunnableLambda
-from tqdm import tqdm
 
 from nat.eval.evaluator.base_evaluator import BaseEvaluator
 from nat.eval.evaluator.evaluator_model import EvalInputItem
@@ -31,7 +29,6 @@ from nat.eval.evaluator.evaluator_model import EvalOutputItem
 
 logger = logging.getLogger(__name__)
 
-# pylint: disable=line-too-long
 # flake8: noqa: E501
 
 
nat/eval/utils/weave_eval.py CHANGED
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
+class WeaveEvaluationIntegration:
     """
     Class to handle all Weave integration functionality.
     """
@@ -47,8 +47,8 @@ class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
             from weave.flow.eval_imperative import EvaluationLogger
             from weave.flow.eval_imperative import ScoreLogger
             from weave.trace.context import weave_client_context
-            self.evaluation_logger_cls = EvaluationLogger  # pylint: disable=invalid-name
-            self.score_logger_cls = ScoreLogger  # pylint: disable=invalid-name
+            self.evaluation_logger_cls = EvaluationLogger
+            self.score_logger_cls = ScoreLogger
             self.weave_client_context = weave_client_context
             self.available = True
         except ImportError:
nat/experimental/test_time_compute/models/strategy_base.py CHANGED
@@ -17,9 +17,10 @@ from abc import ABC
 from abc import abstractmethod
 
 from nat.builder.builder import Builder
-from nat.experimental.test_time_compute.models.ttc_item import TTCItem
-from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
 from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
+from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
+from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
+from nat.experimental.test_time_compute.models.ttc_item import TTCItem
 
 
 class StrategyBase(ABC):
nat/experimental/test_time_compute/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 
 from .editing import iterative_plan_refinement_editor
nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py CHANGED
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
             except Exception as e:
                 logger.error(f"Error parsing merged output: {e}")
                 raise ValueError("Failed to parse merged output.")
-            else:
-                merged_output = merged_output
 
         logger.info("Merged output: %s", str(merged_output))
 
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py CHANGED
@@ -307,7 +307,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
             """Handle evaluation requests."""
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
@@ -336,7 +336,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an evaluation job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_job(job_id)
                 if not job:
@@ -349,7 +349,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of the last created evaluation job."""
             logger.info("Getting last job status")
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_last_job()
                 if not job:
@@ -361,7 +361,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
             """Get all jobs, optionally filtered by status."""
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 if status is None:
                     logger.info("Getting all jobs")
@@ -522,9 +522,9 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         workflow = session_manager.workflow
 
-        GenerateBodyType = workflow.input_schema  # pylint: disable=invalid-name
-        GenerateStreamResponseType = workflow.streaming_output_schema  # pylint: disable=invalid-name
-        GenerateSingleResponseType = workflow.single_output_schema  # pylint: disable=invalid-name
+        GenerateBodyType = workflow.input_schema
+        GenerateStreamResponseType = workflow.streaming_output_schema
+        GenerateSingleResponseType = workflow.single_output_schema
 
         # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
         # Consider prefixing these with "nat_" to avoid conflicts
@@ -572,7 +572,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             response.headers["Content-Type"] = "application/json"
 
-            async with session_manager.session(request=request,
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +583,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         async def get_stream(request: Request):
 
-            async with session_manager.session(request=request,
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +618,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             response.headers["Content-Type"] = "application/json"
 
-            async with session_manager.session(request=request,
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +632,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         async def post_stream(request: Request, payload: request_type):
 
-            async with session_manager.session(request=request,
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +677,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             # Check if streaming is requested
             stream_requested = getattr(payload, 'stream', False)
 
-            async with session_manager.session(request=request):
+            async with session_manager.session(http_connection=request):
                 if stream_requested:
                     # Return streaming response
                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,42 +688,41 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                              step_adaptor=self.get_step_adaptor(),
                                              result_type=ChatResponseChunk,
                                              output_type=ChatResponseChunk))
-                else:
-                    # Return single response - check if workflow supports non-streaming
-                    try:
+
+                # Return single response - check if workflow supports non-streaming
+                try:
+                    response.headers["Content-Type"] = "application/json"
+                    return await generate_single_response(payload, session_manager, result_type=ChatResponse)
+                except ValueError as e:
+                    if "Cannot get a single output value for streaming workflows" in str(e):
+                        # Workflow only supports streaming, but client requested non-streaming
+                        # Fall back to streaming and collect the result
+                        chunks = []
+                        async for chunk_str in generate_streaming_response_as_str(
+                                payload,
+                                session_manager=session_manager,
+                                streaming=True,
+                                step_adaptor=self.get_step_adaptor(),
+                                result_type=ChatResponseChunk,
+                                output_type=ChatResponseChunk):
+                            if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
+                                chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
+                                if chunk_data:
+                                    try:
+                                        chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
+                                        if (chunk_json.choices and len(chunk_json.choices) > 0
+                                                and chunk_json.choices[0].delta
+                                                and chunk_json.choices[0].delta.content is not None):
+                                            chunks.append(chunk_json.choices[0].delta.content)
+                                    except Exception:
+                                        continue
+
+                        # Create a single response from collected chunks
+                        content = "".join(chunks)
+                        single_response = ChatResponse.from_string(content)
                         response.headers["Content-Type"] = "application/json"
-                        return await generate_single_response(payload, session_manager, result_type=ChatResponse)
-                    except ValueError as e:
-                        if "Cannot get a single output value for streaming workflows" in str(e):
-                            # Workflow only supports streaming, but client requested non-streaming
-                            # Fall back to streaming and collect the result
-                            chunks = []
-                            async for chunk_str in generate_streaming_response_as_str(
-                                    payload,
-                                    session_manager=session_manager,
-                                    streaming=True,
-                                    step_adaptor=self.get_step_adaptor(),
-                                    result_type=ChatResponseChunk,
-                                    output_type=ChatResponseChunk):
-                                if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
-                                    chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
-                                    if chunk_data:
-                                        try:
-                                            chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
-                                            if (chunk_json.choices and len(chunk_json.choices) > 0
-                                                    and chunk_json.choices[0].delta
-                                                    and chunk_json.choices[0].delta.content is not None):
-                                                chunks.append(chunk_json.choices[0].delta.content)
-                                        except Exception:
-                                            continue
-
-                            # Create a single response from collected chunks
-                            content = "".join(chunks)
-                            single_response = ChatResponse.from_string(content)
-                            response.headers["Content-Type"] = "application/json"
-                            return single_response
-                        else:
-                            raise
+                        return single_response
+                    raise
 
         return post_openai_api_compatible
 
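The fallback path above collects an SSE stream into one response body. A self-contained sketch of the collection loop, where the payloads are simplified stand-ins for the real `ChatResponseChunk` JSON:

```python
import json

stream = (
    'data: {"delta": "Hel"}',
    'data: {"delta": "lo"}',
    'data: [DONE]',
)

chunks: list[str] = []
for chunk_str in stream:
    if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
        chunk_data = chunk_str[6:].strip()  # drop the "data: " prefix
        if chunk_data:
            delta = json.loads(chunk_data).get("delta")
            if delta is not None:
                chunks.append(delta)

print("".join(chunks))  # Hello
```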
@@ -758,7 +757,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                    http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
             """Handle async generation requests."""
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
@@ -804,7 +803,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an async job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_job(job_id)
                 if not job:
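Every handler above changes the same keyword: `SessionManager.session(request=...)` becomes `session(http_connection=...)`, matching the `nat/runtime/session.py` change (+23 -8) in this release; presumably the broader name reflects that WebSocket handlers pass the socket through the same parameter, as in `message_handler.py` below. A hypothetical call site after the rename, with `session_manager` standing in for the worker's SessionManager:

```python
from fastapi import Request


async def example_handler(http_request: Request, session_manager) -> None:
    # Before this release: session_manager.session(request=http_request)
    # After: the keyword is http_connection; behavior is otherwise unchanged.
    async with session_manager.session(http_connection=http_request):
        ...
```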
nat/front_ends/fastapi/message_handler.py CHANGED
@@ -86,7 +86,7 @@ class WebSocketMessageHandler:
 
     async def __aexit__(self, exc_type, exc_value, traceback) -> None:
 
-        # TODO: Handle the exit  # pylint: disable=fixme
+        # TODO: Handle the exit
         pass
 
     async def run(self) -> None:
@@ -105,12 +105,10 @@ class WebSocketMessageHandler:
                 if (isinstance(validated_message, WebSocketUserMessage)):
                     await self.process_workflow_request(validated_message)
 
-                elif isinstance(
-                        validated_message,
-                        (  # noqa: E131
-                            WebSocketSystemResponseTokenMessage,
-                            WebSocketSystemIntermediateStepMessage,
-                            WebSocketSystemInteractionMessage)):
+                elif isinstance(validated_message,
+                                (WebSocketSystemResponseTokenMessage,
+                                 WebSocketSystemIntermediateStepMessage,
+                                 WebSocketSystemInteractionMessage)):
                     # These messages are already handled by self.create_websocket_message(data_model=value, …)
                     # No further processing is needed here.
                     pass
@@ -119,11 +117,9 @@ class WebSocketMessageHandler:
                     user_content = await self.process_user_message_content(validated_message)
                     self._user_interaction_response.set_result(user_content)
             except (asyncio.CancelledError, WebSocketDisconnect):
-                # TODO: Handle the disconnect  # pylint: disable=fixme
+                # TODO: Handle the disconnect
                 break
 
-        return None
-
     async def process_user_message_content(
             self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
         """
@@ -162,12 +158,13 @@ class WebSocketMessageHandler:
 
         if isinstance(content, TextContent) and (self._running_workflow_task is None):
 
-            def _done_callback(task: asyncio.Task):  # pylint: disable=unused-argument
+            def _done_callback(task: asyncio.Task):
                 self._running_workflow_task = None
 
             self._running_workflow_task = asyncio.create_task(
-                self._run_workflow(content.text,
-                                   self._conversation_id,
+                self._run_workflow(payload=content.text,
+                                   user_message_id=self._message_parent_id,
+                                   conversation_id=self._conversation_id,
                                    result_type=self._schema_output_mapping[self._workflow_schema_type],
                                    output_type=self._schema_output_mapping[
                                        self._workflow_schema_type])).add_done_callback(_done_callback)
@@ -290,14 +287,16 @@ class WebSocketMessageHandler:
 
     async def _run_workflow(self,
                             payload: typing.Any,
+                            user_message_id: str | None = None,
                             conversation_id: str | None = None,
                             result_type: type | None = None,
                             output_type: type | None = None) -> None:
 
         try:
             async with self._session_manager.session(
+                    user_message_id=user_message_id,
                     conversation_id=conversation_id,
-                    request=self._socket,
+                    http_connection=self._socket,
                     user_input_callback=self.human_interaction_callback,
                     user_authentication_callback=(self._flow_handler.authenticate
                                                   if self._flow_handler else None)) as session:
nat/front_ends/fastapi/message_validator.py CHANGED
@@ -232,7 +232,7 @@ class MessageValidator:
         """
         return data_model.parent_id or "root"
 
-    async def create_system_response_token_message(  # pylint: disable=R0917:too-many-positional-arguments
+    async def create_system_response_token_message(
             self,
             message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
                                   WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
@@ -272,7 +272,7 @@ class MessageValidator:
             logger.error("Error creating system response token message: %s", str(e), exc_info=True)
             return None
 
-    async def create_system_intermediate_step_message(  # pylint: disable=R0917:too-many-positional-arguments
+    async def create_system_intermediate_step_message(
             self,
             message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
                 WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
@@ -311,7 +311,7 @@ class MessageValidator:
             logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
             return None
 
-    async def create_system_interaction_message(  # pylint: disable=R0917:too-many-positional-arguments
+    async def create_system_interaction_message(
             self,
             *,
             message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
@@ -323,7 +323,7 @@ class MessageValidator:
             content: HumanPrompt,
             status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
             timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
-    ) -> WebSocketSystemInteractionMessage | None:  # noqa: E125 continuation line with same indent as next logical line
+    ) -> WebSocketSystemInteractionMessage | None:
         """
         Creates a system interaction message with default values.
 
nat/front_ends/fastapi/step_adaptor.py CHANGED
@@ -289,7 +289,7 @@ class StepAdaptor:
 
         return event
 
-    def process(self, step: IntermediateStep) -> ResponseSerializable | None:  # pylint: disable=R1710
+    def process(self, step: IntermediateStep) -> ResponseSerializable | None:
 
         # Track the chunk
         self._history.append(step)
nat/front_ends/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file
 
nat/llm/aws_bedrock_llm.py CHANGED
@@ -22,9 +22,10 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
 
 
-class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
+class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, name="aws_bedrock"):
     """An AWS Bedrock llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=())
@@ -33,7 +34,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The model name for the hosted AWS Bedrock.")
-    temperature: float = Field(default=0.0, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
     max_tokens: int | None = Field(default=1024,
                                    gt=0,
                                    description="Maximum number of tokens to generate."
@@ -52,6 +52,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
 
 
 @register_llm_provider(config_type=AWSBedrockModelConfig)
-async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, builder: Builder):
+async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, _builder: Builder):
 
     yield LLMProviderInfo(config=llm_config, description="A AWS Bedrock model for use with an LLM client.")
nat/llm/azure_openai_llm.py CHANGED
@@ -22,9 +22,11 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.top_p_mixin import TopPMixin
 
 
-class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
+class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="azure_openai"):
     """An Azure OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -38,10 +40,7 @@ class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
     azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
                                   serialization_alias="azure_deployment",
                                   description="The Azure OpenAI hosted model/deployment name.")
-    temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
-    top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
-    max_retries: int = Field(default=10, description="The max number of retries for the request.")
 
 
 @register_llm_provider(config_type=AzureOpenAIModelConfig)
nat/llm/nim_llm.py CHANGED
@@ -23,9 +23,11 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.top_p_mixin import TopPMixin
 
 
-class NIMModelConfig(LLMBaseConfig, RetryMixin, name="nim"):
+class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="nim"):
     """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=())
@@ -35,12 +37,10 @@ class NIMModelConfig(LLMBaseConfig, RetryMixin, name="nim"):
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The model name for the hosted NIM.")
-    temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
-    top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
    max_tokens: PositiveInt = Field(default=300, description="Maximum number of tokens to generate.")
 
 
 @register_llm_provider(config_type=NIMModelConfig)
-async def nim_model(llm_config: NIMModelConfig, builder: Builder):
+async def nim_model(llm_config: NIMModelConfig, _builder: Builder):
 
     yield LLMProviderInfo(config=llm_config, description="A NIM model for use with an LLM client.")
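With the mixins applied, `temperature` and `top_p` leave the concrete provider configs and become gated optionals shared across AWS Bedrock, Azure OpenAI, and NIM. A hedged construction sketch (requires `nvidia-nat` installed; the resolved defaults depend on the new `ModelGatedFieldMixin`, whose body is not in this extract, and the model name is illustrative):

```python
from nat.llm.nim_llm import NIMModelConfig

config = NIMModelConfig(model_name="meta/llama-3.1-8b-instruct")

# temperature/top_p are now inherited from TemperatureMixin/TopPMixin rather
# than declared inline; for a supported model they presumably resolve to the
# mixins' default_if_supported values (0.0 and 1.0).
print(config.temperature, config.top_p)
```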