nvidia-nat 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250829__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. nat/agent/base.py +12 -7
  2. nat/agent/dual_node.py +7 -2
  3. nat/agent/react_agent/agent.py +15 -14
  4. nat/agent/react_agent/register.py +5 -1
  5. nat/agent/rewoo_agent/agent.py +23 -32
  6. nat/agent/rewoo_agent/register.py +8 -4
  7. nat/agent/tool_calling_agent/agent.py +15 -20
  8. nat/agent/tool_calling_agent/register.py +6 -2
  9. nat/builder/context.py +7 -2
  10. nat/builder/eval_builder.py +2 -2
  11. nat/builder/function.py +8 -8
  12. nat/builder/workflow_builder.py +21 -24
  13. nat/cli/cli_utils/config_override.py +1 -1
  14. nat/cli/commands/info/list_channels.py +1 -1
  15. nat/cli/commands/object_store/__init__.py +14 -0
  16. nat/cli/commands/object_store/object_store.py +227 -0
  17. nat/cli/commands/registry/publish.py +2 -2
  18. nat/cli/commands/registry/pull.py +2 -2
  19. nat/cli/commands/registry/remove.py +2 -2
  20. nat/cli/commands/registry/search.py +1 -1
  21. nat/cli/commands/start.py +1 -1
  22. nat/cli/commands/uninstall.py +1 -1
  23. nat/cli/commands/workflow/workflow_commands.py +4 -4
  24. nat/cli/entrypoint.py +3 -1
  25. nat/data_models/discovery_metadata.py +4 -4
  26. nat/data_models/gated_field_mixin.py +12 -14
  27. nat/data_models/temperature_mixin.py +1 -1
  28. nat/data_models/thinking_mixin.py +68 -0
  29. nat/data_models/top_p_mixin.py +1 -1
  30. nat/eval/evaluate.py +6 -6
  31. nat/eval/intermediate_step_adapter.py +1 -1
  32. nat/eval/rag_evaluator/evaluate.py +2 -2
  33. nat/eval/rag_evaluator/register.py +1 -1
  34. nat/eval/remote_workflow.py +3 -3
  35. nat/eval/swe_bench_evaluator/evaluate.py +5 -5
  36. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  37. nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
  38. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
  39. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  40. nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
  41. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
  42. nat/front_ends/fastapi/message_handler.py +2 -2
  43. nat/front_ends/fastapi/message_validator.py +8 -10
  44. nat/front_ends/fastapi/response_helpers.py +4 -4
  45. nat/front_ends/fastapi/step_adaptor.py +1 -1
  46. nat/llm/aws_bedrock_llm.py +10 -9
  47. nat/llm/azure_openai_llm.py +9 -1
  48. nat/llm/nim_llm.py +2 -1
  49. nat/llm/openai_llm.py +2 -1
  50. nat/llm/utils/thinking.py +215 -0
  51. nat/observability/exporter/base_exporter.py +1 -1
  52. nat/observability/exporter/processing_exporter.py +8 -9
  53. nat/observability/exporter_manager.py +5 -5
  54. nat/observability/mixin/file_mixin.py +7 -7
  55. nat/observability/processor/batching_processor.py +4 -6
  56. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  57. nat/observability/processor/processor_factory.py +70 -0
  58. nat/profiler/calc/calc_runner.py +3 -4
  59. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  60. nat/profiler/callbacks/langchain_callback_handler.py +5 -5
  61. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  62. nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
  63. nat/profiler/decorators/function_tracking.py +125 -0
  64. nat/profiler/profile_runner.py +1 -1
  65. nat/profiler/utils.py +1 -1
  66. nat/registry_handlers/local/local_handler.py +2 -2
  67. nat/registry_handlers/package_utils.py +1 -1
  68. nat/registry_handlers/pypi/pypi_handler.py +3 -3
  69. nat/registry_handlers/rest/rest_handler.py +4 -4
  70. nat/retriever/milvus/retriever.py +1 -1
  71. nat/retriever/nemo_retriever/retriever.py +1 -1
  72. nat/runtime/loader.py +1 -1
  73. nat/runtime/runner.py +2 -2
  74. nat/settings/global_settings.py +1 -1
  75. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  76. nat/tool/nvidia_rag.py +1 -1
  77. nat/tool/retriever.py +3 -2
  78. nat/utils/io/yaml_tools.py +1 -1
  79. nat/utils/reactive/observer.py +2 -2
  80. nat/utils/settings/global_settings.py +2 -2
  81. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/METADATA +3 -1
  82. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/RECORD +87 -81
  83. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/WHEEL +0 -0
  84. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/entry_points.txt +0 -0
  85. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  86. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/licenses/LICENSE.md +0 -0
  87. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/top_level.txt +0 -0
nat/data_models/thinking_mixin.py ADDED
@@ -0,0 +1,68 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+
+ from pydantic import BaseModel
+ from pydantic import Field
+
+ from nat.data_models.gated_field_mixin import GatedFieldMixin
+
+ # The system prompt format for thinking is different for these, so we need to distinguish them here with two separate
+ # regex patterns
+ _NVIDIA_NEMOTRON_REGEX = re.compile(r"^nvidia/nvidia.*nemotron", re.IGNORECASE)
+ _LLAMA_NEMOTRON_REGEX = re.compile(r"^nvidia/llama.*nemotron", re.IGNORECASE)
+ _MODEL_KEYS = ("model_name", "model", "azure_deployment")
+
+
+ class ThinkingMixin(
+         BaseModel,
+         GatedFieldMixin,
+         field_name="thinking",
+         default_if_supported=None,
+         keys=_MODEL_KEYS,
+         supported=(_NVIDIA_NEMOTRON_REGEX, _LLAMA_NEMOTRON_REGEX),
+ ):
+     """
+     Mixin class for thinking configuration. Only supported on Nemotron models.
+
+     Attributes:
+         thinking: Whether to enable thinking. Defaults to None when supported on the model.
+     """
+     thinking: bool | None = Field(
+         default=None,
+         description="Whether to enable thinking. Defaults to None when supported on the model.",
+         exclude=True,
+     )
+
+     @property
+     def thinking_system_prompt(self) -> str | None:
+         """
+         Returns the system prompt to use for thinking.
+         For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
+         For Llama Nemotron, returns "detailed thinking on" if enabled, else "detailed thinking off".
+         If thinking is not supported on the model, returns None.
+
+         Returns:
+             str | None: The system prompt to use for thinking.
+         """
+         if self.thinking is None:
+             return None
+         for key in _MODEL_KEYS:
+             if hasattr(self, key):
+                 if _NVIDIA_NEMOTRON_REGEX.match(getattr(self, key)):
+                     return "/think" if self.thinking else "/no_think"
+                 elif _LLAMA_NEMOTRON_REGEX.match(getattr(self, key)):
+                     return f"detailed thinking {'on' if self.thinking else 'off'}"
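A minimal sketch of how the new thinking_system_prompt property resolves; the model names are illustrative, and NIMModelConfig gains this mixin later in this same diff:

    from nat.llm.nim_llm import NIMModelConfig

    # Llama Nemotron models use the "detailed thinking on/off" prompt style:
    cfg = NIMModelConfig(model_name="nvidia/llama-3.1-nemotron-70b-instruct", thinking=True)
    assert cfg.thinking_system_prompt == "detailed thinking on"

    # NVIDIA Nemotron models use the "/think" and "/no_think" variants:
    cfg = NIMModelConfig(model_name="nvidia/nvidia-nemotron-nano-9b-v2", thinking=False)
    assert cfg.thinking_system_prompt == "/no_think"

    # Left unset (None), the property returns None and no prompt is injected:
    assert NIMModelConfig(model_name="nvidia/llama-3.1-nemotron-70b-instruct").thinking_system_prompt is None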
nat/data_models/top_p_mixin.py CHANGED
@@ -23,7 +23,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
 
  class TopPMixin(
          BaseModel,
-         GatedFieldMixin[float],
+         GatedFieldMixin,
          field_name="top_p",
          default_if_supported=1.0,
          keys=("model_name", "model", "azure_deployment"),
nat/eval/evaluate.py CHANGED
@@ -168,17 +168,17 @@ class EvaluationRun:
      intermediate_future = None
 
      try:
-
          # Start usage stats and intermediate steps collection in parallel
          intermediate_future = pull_intermediate()
          runner_result = runner.result()
          base_output = await runner_result
          intermediate_steps = await intermediate_future
      except NotImplementedError as e:
+         logger.error("Failed to run the workflow: %s", e)
          # raise original error
-         raise e
+         raise
      except Exception as e:
-         logger.exception("Failed to run the workflow: %s", e, exc_info=True)
+         logger.exception("Failed to run the workflow: %s", e)
          # stop processing if a workflow error occurs
          self.workflow_interrupted = True
 
@@ -317,7 +317,7 @@
          logger.info("Deleting old job directory: %s", dir_to_delete)
          shutil.rmtree(dir_to_delete)
      except Exception as e:
-         logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
+         logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e)
 
  def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
      workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
@@ -367,7 +367,7 @@
 
          await self.weave_eval.alog_score(eval_output, evaluator_name)
      except Exception as e:
-         logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e, exc_info=True)
+         logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e)
 
  async def run_evaluators(self, evaluators: dict[str, Any]):
      """Run all configured evaluators asynchronously."""
@@ -380,7 +380,7 @@
      try:
          await asyncio.gather(*tasks)
      except Exception as e:
-         logger.exception("An error occurred while running evaluators: %s", e, exc_info=True)
+         logger.error("An error occurred while running evaluators: %s", e)
          raise
      finally:
          # Finish prediction loggers in Weave
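Most of the logging hunks in this release follow two small Python idioms. First, logger.exception already logs at ERROR level and attaches the active traceback, so the old exc_info=True argument was redundant. Second, a bare raise re-raises the in-flight exception with its original traceback, whereas raise e appends the re-raise site as an extra frame. A minimal standalone illustration (not code from this package):

    import logging

    logger = logging.getLogger(__name__)

    try:
        1 / 0
    except ZeroDivisionError as e:
        # logger.exception is equivalent to logger.error(..., exc_info=True):
        # the traceback of the exception being handled is captured automatically.
        logger.exception("Failed to run the workflow: %s", e)
        # Bare raise preserves the original traceback; `raise e` would add
        # this line as an extra frame.
        raise

Note that logger.exception is only meaningful while an exception is being handled; called outside a handler it records "NoneType: None" in place of a traceback.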
nat/eval/intermediate_step_adapter.py CHANGED
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
          try:
              validated_steps.append(IntermediateStep.model_validate(step_data))
          except Exception as e:
-             logger.exception("Validation failed for step: %r, Error: %s", step_data, e, exc_info=True)
+             logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
          return validated_steps
 
      def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
nat/eval/rag_evaluator/evaluate.py CHANGED
@@ -102,7 +102,7 @@ class RAGEvaluator:
      """Converts the ragas EvaluationResult to nat EvalOutput"""
 
      if not results_dataset:
-         logger.error("Ragas evaluation failed with no results")
+         logger.error("Ragas evaluation failed with no results", exc_info=True)
          return EvalOutput(average_score=0.0, eval_output_items=[])
 
      scores: list[dict[str, float]] = results_dataset.scores
@@ -169,7 +169,7 @@
                            _pbar=pbar)
      except Exception as e:
          # On exception we still continue with other evaluators. Log and return an avg_score of 0.0
-         logger.exception("Error evaluating ragas metric, Error: %s", e, exc_info=True)
+         logger.exception("Error evaluating ragas metric, Error: %s", e)
          results_dataset = None
      finally:
          pbar.close()
nat/eval/rag_evaluator/register.py CHANGED
@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
          raise ValueError(message) from e
      except AttributeError as e:
          message = f"Ragas metric {metric_name} not found {e}."
-         logger.error(message)
+         logger.exception(message)
          return None
 
      async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
nat/eval/remote_workflow.py CHANGED
@@ -74,7 +74,7 @@ class EvaluationRemoteWorkflowHandler:
                  if chunk_data.get("value"):
                      final_response = chunk_data.get("value")
              except json.JSONDecodeError as e:
-                 logger.error("Failed to parse generate response chunk: %s", e)
+                 logger.exception("Failed to parse generate response chunk: %s", e)
                  continue
          elif line.startswith(INTERMEDIATE_DATA_PREFIX):
              # This is an intermediate step
@@ -90,12 +90,12 @@
                      payload=payload)
                  intermediate_steps.append(intermediate_step)
              except (json.JSONDecodeError, ValidationError) as e:
-                 logger.error("Failed to parse intermediate step: %s", e)
+                 logger.exception("Failed to parse intermediate step: %s", e)
                  continue
 
      except aiohttp.ClientError as e:
          # Handle connection or HTTP-related errors
-         logger.error("Request failed for question %s: %s", question, e)
+         logger.exception("Request failed for question %s: %s", question, e)
          item.output_obj = None
          item.trajectory = []
          return
nat/eval/swe_bench_evaluator/evaluate.py CHANGED
@@ -69,13 +69,13 @@ class SweBenchEvaluator:
      try:
          shutil.move(swe_bench_report_file, report_dir)
      except Exception as e:
-         logger.exception("Error moving report file: %s", e, exc_info=True)
+         logger.exception("Error moving report file: %s", e)
 
      try:
          dest_logs_dir = os.path.join(report_dir, 'logs')
          shutil.move(logs_dir, dest_logs_dir)
      except Exception as e:
-         logger.exception("Error moving logs directory: %s", e, exc_info=True)
+         logger.exception("Error moving logs directory: %s", e)
 
  def is_repo_supported(self, repo: str, version: str) -> bool:
      """Check if the repo is supported by swebench"""
@@ -106,7 +106,7 @@
              self._model_name_or_path = swebench_output.model_name_or_path
 
          except Exception as e:
-             logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e, exc_info=True)
+             logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e)
 
      # Filter out repos/version not supported by SWEBench
      supported_inputs = [
@@ -114,7 +114,7 @@
      ]
 
      if not supported_inputs:
-         logger.error("No supported instances; nothing to evaluate")
+         logger.exception("No supported instances; nothing to evaluate")
          return None, None
 
      if len(supported_inputs) < len(swebench_inputs):
@@ -135,7 +135,7 @@
      filtered_outputs = [output for output in swebench_outputs if output.instance_id in valid_instance_ids]
 
      if not filtered_outputs:
-         logger.error("No supported outputs; nothing to evaluate")
+         logger.error("No supported outputs; nothing to evaluate", exc_info=True)
          return None, None
 
      # Write SWEBenchOutput to file
nat/eval/trajectory_evaluator/evaluate.py CHANGED
@@ -65,7 +65,7 @@ class TrajectoryEvaluator(BaseEvaluator):
              prediction=generated_answer,
          )
      except Exception as e:
-         logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e, exc_info=True)
+         logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e)
          return EvalOutputItem(id=item.id, score=0.0, reasoning=f"Error evaluating trajectory: {e}")
 
      reasoning = {
nat/eval/tunable_rag_evaluator/evaluate.py CHANGED
@@ -182,8 +182,8 @@ class TunableRagEvaluator(BaseEvaluator):
          relevance_score = parsed_response["relevance_score"]
          reasoning = parsed_response["reasoning"]
      except KeyError as e:
-         logger.error("Missing required keys in default scoring response: %s",
-                      ", ".join(str(arg) for arg in e.args))
+         logger.exception("Missing required keys in default scoring response: %s",
+                          ", ".join(str(arg) for arg in e.args))
          reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
 
      coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
@@ -215,7 +215,7 @@
          reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
          raise
      except (KeyError, ValueError) as e:
-         logger.error("Error parsing judge LLM response: %s", e)
+         logger.exception("Error parsing judge LLM response: %s", e)
          score = 0.0
          reasoning = "Error in evaluator from parsing judge LLM response."
 
nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py CHANGED
@@ -148,13 +148,13 @@ async def register_ttc_tool_orchestration_function(
              result = await fn.acall_invoke(item.output)
              return item, result, None
          except Exception as e:
-             logger.error(f"Error invoking function '{item.name}': {e}")
+             logger.exception(f"Error invoking function '{item.name}': {e}")
              return item, None, str(e)
 
      tasks = []
      for item in ttc_items:
          if item.name not in function_map:
-             logger.error(f"Function '{item.name}' not found in function map.")
+             logger.error(f"Function '{item.name}' not found in function map.", exc_info=True)
              item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
          else:
              fn = function_map[item.name]
nat/front_ends/fastapi/fastapi_front_end_controller.py CHANGED
@@ -47,11 +47,11 @@ class _FastApiFrontEndController:
          self._server_background_task = asyncio.create_task(self._server.serve())
      except asyncio.CancelledError as e:
          error_message = f"Task error occurred while starting API server: {str(e)}"
-         logger.error(error_message, exc_info=True)
+         logger.error(error_message)
          raise RuntimeError(error_message) from e
      except Exception as e:
          error_message = f"Unexpected error occurred while starting API server: {str(e)}"
-         logger.error(error_message, exc_info=True)
+         logger.exception(error_message)
          raise RuntimeError(error_message) from e
 
  async def stop_server(self) -> None:
@@ -63,6 +63,6 @@
          self._server.should_exit = True
          await self._server_background_task
      except asyncio.CancelledError as e:
-         logger.error("Server shutdown failed: %s", str(e), exc_info=True)
+         logger.exception("Server shutdown failed: %s", str(e))
      except Exception as e:
-         logger.error("Unexpected error occurred: %s", str(e), exc_info=True)
+         logger.exception("Unexpected error occurred: %s", str(e))
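Note the asymmetry in the controller hunks above: asyncio.CancelledError now logs with plain logger.error (no traceback) during startup, while unexpected failures switch to logger.exception. That reads as deliberate, since cancellation is ordinary asyncio control flow rather than a fault. A schematic, self-contained sketch of the pattern (not code from this package):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    async def run_server():
        try:
            await asyncio.sleep(1)  # stand-in for the server startup await
        except asyncio.CancelledError:
            # Expected control flow (e.g. shutdown): log without traceback noise.
            logger.error("Task error occurred while starting API server")
            raise
        except Exception:
            # A genuine fault: keep the traceback.
            logger.exception("Unexpected error occurred while starting API server")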
nat/front_ends/fastapi/fastapi_front_end_plugin.py CHANGED
@@ -113,4 +113,4 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
      try:
          os.remove(config_file_name)
      except OSError as e:
-         logger.error(f"Warning: Failed to delete temp file {config_file_name}: {e}")
+         logger.exception(f"Warning: Failed to delete temp file {config_file_name}: {e}")
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py CHANGED
@@ -215,7 +215,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
              job_store.cleanup_expired_jobs()
              logger.debug("Expired %s jobs cleaned up", name)
          except Exception as e:
-             logger.error("Error during %s job cleanup: %s", name, e)
+             logger.exception("Error during %s job cleanup: %s", name, e)
          await asyncio.sleep(sleep_time_sec)
 
      async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
@@ -301,7 +301,7 @@
 
              job_store.update_status(job_id, "success", output_path=str(parent_dir))
          except Exception as e:
-             logger.error("Error in evaluation job %s: %s", job_id, str(e))
+             logger.exception("Error in evaluation job %s: %s", job_id, str(e))
              job_store.update_status(job_id, "failure", error=str(e))
 
      async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
@@ -735,7 +735,7 @@
                      result_type=result_type)
              job_store.update_status(job_id, "success", output=result)
          except Exception as e:
-             logger.error("Error in evaluation job %s: %s", job_id, e)
+             logger.exception("Error in evaluation job %s: %s", job_id, e)
              job_store.update_status(job_id, "failure", error=str(e))
 
      def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
nat/front_ends/fastapi/message_handler.py CHANGED
@@ -170,7 +170,7 @@ class WebSocketMessageHandler:
                  self._workflow_schema_type])).add_done_callback(_done_callback)
 
      except ValueError as e:
-         logger.error("User message content not found: %s", str(e), exc_info=True)
+         logger.exception("User message content not found: %s", str(e))
          await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
                                                               message="User message content could not be found",
                                                               details=str(e)),
@@ -238,7 +238,7 @@
              f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")
 
      except (ValidationError, TypeError, ValueError) as e:
-         logger.error("A data vaidation error ocurred creating websocket message: %s", str(e), exc_info=True)
+         logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
          message = await self._message_validator.create_system_response_token_message(
              message_type=WebSocketMessageType.ERROR_MESSAGE,
              conversation_id=self._conversation_id,
nat/front_ends/fastapi/message_validator.py CHANGED
@@ -97,7 +97,7 @@ class MessageValidator:
          return validated_message
 
      except (ValidationError, TypeError, ValueError) as e:
-         logger.error("A data validation error %s occurred for message: %s", str(e), str(message), exc_info=True)
+         logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
          return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
                                                                 content=Error(code=ErrorTypes.INVALID_MESSAGE,
                                                                               message="Error validating message.",
@@ -119,7 +119,7 @@
          return schema
 
      except (TypeError, ValueError) as e:
-         logger.error("Error retrieving schema for message type '%s': %s", message_type, str(e), exc_info=True)
+         logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
          return Error
 
      async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
@@ -156,7 +156,7 @@
          return validated_message_content
 
      except ValueError as e:
-         logger.error("Input data could not be converted to validated message content: %s", str(e), exc_info=True)
+         logger.exception("Input data could not be converted to validated message content: %s", str(e))
          return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
 
      async def convert_text_content_to_human_response(self, text_content: TextContent,
@@ -191,7 +191,7 @@
          return human_response
 
      except ValueError as e:
-         logger.error("Error human response content not found: %s", str(e), exc_info=True)
+         logger.exception("Error human response content not found: %s", str(e))
          return HumanResponseText(text=str(e))
 
      async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
@@ -218,9 +218,7 @@
          return validated_message_type
 
      except ValueError as e:
-         logger.error("Error type not found converting data to validated websocket message content: %s",
-                      str(e),
-                      exc_info=True)
+         logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
          return WebSocketMessageType.ERROR_MESSAGE
 
      async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
@@ -269,7 +267,7 @@
                  timestamp=timestamp)
 
      except Exception as e:
-         logger.error("Error creating system response token message: %s", str(e), exc_info=True)
+         logger.exception("Error creating system response token message: %s", str(e))
          return None
 
      async def create_system_intermediate_step_message(
@@ -308,7 +306,7 @@
                  timestamp=timestamp)
 
      except Exception as e:
-         logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
+         logger.exception("Error creating system intermediate step message: %s", str(e))
          return None
 
      async def create_system_interaction_message(
@@ -348,5 +346,5 @@
                  timestamp=timestamp)
 
      except Exception as e:
-         logger.error("Error creating system interaction message: %s", str(e), exc_info=True)
+         logger.exception("Error creating system interaction message: %s", str(e))
          return None
nat/front_ends/fastapi/response_helpers.py CHANGED
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
                  yield item
              else:
                  yield ResponsePayloadOutput(payload=item)
-     except Exception as e:
+     except Exception:
          # Handle exceptions here
-         raise e
+         raise
      finally:
          await q.close()
 
@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
                  yield item
              else:
                  yield ResponsePayloadOutput(payload=item)
-     except Exception as e:
+     except Exception:
          # Handle exceptions here
-         raise e
+         raise
      finally:
          await q.close()
 
nat/front_ends/fastapi/step_adaptor.py CHANGED
@@ -314,6 +314,6 @@ class StepAdaptor:
              return self._handle_custom(payload, ancestry)
 
      except Exception as e:
-         logger.error("Error processing intermediate step: %s", e, exc_info=True)
+         logger.exception("Error processing intermediate step: %s", e)
 
      return None
nat/llm/aws_bedrock_llm.py CHANGED
@@ -23,9 +23,11 @@ from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
  from nat.data_models.retry_mixin import RetryMixin
  from nat.data_models.temperature_mixin import TemperatureMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
+ from nat.data_models.top_p_mixin import TopPMixin
 
 
- class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, name="aws_bedrock"):
+ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="aws_bedrock"):
      """An AWS Bedrock llm provider to be used with an LLM client."""
 
      model_config = ConfigDict(protected_namespaces=())
@@ -34,14 +36,13 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, name="a
      model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                              serialization_alias="model",
                              description="The model name for the hosted AWS Bedrock.")
-     max_tokens: int | None = Field(default=1024,
-                                    gt=0,
-                                    description="Maximum number of tokens to generate."
-                                    "This field is ONLY required when using AWS Bedrock with Langchain.")
-     context_size: int | None = Field(default=1024,
-                                      gt=0,
-                                      description="Maximum number of tokens to generate."
-                                      "This field is ONLY required when using AWS Bedrock with LlamaIndex.")
+     max_tokens: int | None = Field(default=1024, gt=0, description="Maximum number of tokens to generate.")
+     context_size: int | None = Field(
+         default=1024,
+         gt=0,
+         description="The maximum number of tokens available for input. This is only required for LlamaIndex. "
+         "This field is ignored for Langchain.",
+     )
 
      # Client parameters
      region_name: str | None = Field(default="None", description="AWS region to use.")
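The reworded Bedrock descriptions separate the output budget from the input window, which the old text conflated (both previously read "Maximum number of tokens to generate"). A minimal sketch (model id and values illustrative):

    from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig

    cfg = AWSBedrockModelConfig(
        model_name="anthropic.claude-3-5-sonnet-20240620-v1:0",  # illustrative model id
        max_tokens=2048,  # output: maximum number of tokens to generate
        context_size=200_000,  # input window: consumed by the LlamaIndex client, ignored by Langchain
    )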
nat/llm/azure_openai_llm.py CHANGED
@@ -23,10 +23,18 @@ from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
  from nat.data_models.retry_mixin import RetryMixin
  from nat.data_models.temperature_mixin import TemperatureMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
  from nat.data_models.top_p_mixin import TopPMixin
 
 
- class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="azure_openai"):
+ class AzureOpenAIModelConfig(
+         LLMBaseConfig,
+         RetryMixin,
+         TemperatureMixin,
+         TopPMixin,
+         ThinkingMixin,
+         name="azure_openai",
+ ):
      """An Azure OpenAI LLM provider to be used with an LLM client."""
 
      model_config = ConfigDict(protected_namespaces=(), extra="allow")
nat/llm/nim_llm.py CHANGED
@@ -24,10 +24,11 @@ from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
  from nat.data_models.retry_mixin import RetryMixin
  from nat.data_models.temperature_mixin import TemperatureMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
  from nat.data_models.top_p_mixin import TopPMixin
 
 
- class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="nim"):
+ class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="nim"):
      """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
      model_config = ConfigDict(protected_namespaces=())
nat/llm/openai_llm.py CHANGED
@@ -23,10 +23,11 @@ from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
  from nat.data_models.retry_mixin import RetryMixin
  from nat.data_models.temperature_mixin import TemperatureMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
  from nat.data_models.top_p_mixin import TopPMixin
 
 
- class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="openai"):
+ class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="openai"):
      """An OpenAI LLM provider to be used with an LLM client."""
 
      model_config = ConfigDict(protected_namespaces=(), extra="allow")
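Taken together, the provider hunks thread ThinkingMixin through the AWS Bedrock, Azure OpenAI, NIM, and OpenAI configs, so thinking can be set on any of them; it only produces a prompt when the model name matches a Nemotron pattern. A hypothetical sketch of how a caller might consume the derived prompt (the message assembly is illustrative, not an API from this package):

    from nat.llm.openai_llm import OpenAIModelConfig

    # An OpenAI-compatible config pointed at a Nemotron model (name illustrative):
    cfg = OpenAIModelConfig(model_name="nvidia/llama-3.3-nemotron-super-49b-v1", thinking=True)

    messages = []
    if cfg.thinking_system_prompt is not None:  # None on models without thinking support
        messages.append({"role": "system", "content": cfg.thinking_system_prompt})
    messages.append({"role": "user", "content": "Outline the migration plan step by step."})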