nvidia-nat 1.3.0a20250822__py3-none-any.whl → 1.3.0a20250824__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +0 -1
  3. nat/agent/react_agent/agent.py +21 -3
  4. nat/agent/react_agent/register.py +1 -1
  5. nat/agent/register.py +0 -1
  6. nat/agent/rewoo_agent/agent.py +0 -1
  7. nat/agent/rewoo_agent/register.py +1 -1
  8. nat/agent/tool_calling_agent/agent.py +0 -1
  9. nat/agent/tool_calling_agent/register.py +1 -1
  10. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  11. nat/authentication/register.py +0 -1
  12. nat/builder/builder.py +1 -1
  13. nat/builder/context.py +9 -1
  14. nat/builder/function_base.py +3 -3
  15. nat/builder/function_info.py +5 -7
  16. nat/builder/workflow_builder.py +0 -1
  17. nat/cli/commands/evaluate.py +1 -1
  18. nat/cli/commands/info/list_components.py +7 -8
  19. nat/cli/commands/info/list_mcp.py +3 -4
  20. nat/cli/commands/registry/search.py +14 -16
  21. nat/cli/commands/start.py +0 -1
  22. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  23. nat/cli/commands/workflow/workflow_commands.py +0 -1
  24. nat/cli/type_registry.py +3 -5
  25. nat/data_models/config.py +1 -1
  26. nat/data_models/evaluate.py +1 -1
  27. nat/data_models/function_dependencies.py +6 -6
  28. nat/data_models/intermediate_step.py +3 -3
  29. nat/data_models/model_gated_field_mixin.py +125 -0
  30. nat/data_models/swe_bench_model.py +1 -1
  31. nat/data_models/temperature_mixin.py +36 -0
  32. nat/data_models/top_p_mixin.py +36 -0
  33. nat/embedder/register.py +0 -1
  34. nat/eval/dataset_handler/dataset_handler.py +5 -6
  35. nat/eval/evaluate.py +7 -8
  36. nat/eval/rag_evaluator/register.py +2 -2
  37. nat/eval/register.py +0 -1
  38. nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
  39. nat/eval/utils/weave_eval.py +3 -3
  40. nat/experimental/test_time_compute/models/strategy_base.py +3 -2
  41. nat/experimental/test_time_compute/register.py +0 -1
  42. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
  43. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
  44. nat/front_ends/fastapi/message_handler.py +13 -14
  45. nat/front_ends/fastapi/message_validator.py +4 -4
  46. nat/front_ends/fastapi/step_adaptor.py +1 -1
  47. nat/front_ends/register.py +0 -1
  48. nat/llm/aws_bedrock_llm.py +3 -3
  49. nat/llm/azure_openai_llm.py +3 -4
  50. nat/llm/nim_llm.py +4 -4
  51. nat/llm/openai_llm.py +4 -4
  52. nat/llm/register.py +0 -1
  53. nat/llm/utils/env_config_value.py +2 -3
  54. nat/object_store/register.py +0 -1
  55. nat/observability/exporter/file_exporter.py +1 -1
  56. nat/observability/register.py +3 -3
  57. nat/profiler/callbacks/langchain_callback_handler.py +1 -1
  58. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  59. nat/profiler/data_frame_row.py +1 -1
  60. nat/profiler/decorators/framework_wrapper.py +1 -4
  61. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  62. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  63. nat/profiler/inference_optimization/data_models.py +3 -3
  64. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  65. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  66. nat/profiler/profile_runner.py +13 -8
  67. nat/registry_handlers/package_utils.py +0 -1
  68. nat/registry_handlers/pypi/pypi_handler.py +20 -23
  69. nat/registry_handlers/register.py +3 -4
  70. nat/registry_handlers/rest/rest_handler.py +8 -9
  71. nat/retriever/register.py +0 -1
  72. nat/runtime/session.py +23 -8
  73. nat/settings/global_settings.py +0 -1
  74. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  75. nat/tool/document_search.py +1 -1
  76. nat/tool/mcp/mcp_tool.py +1 -1
  77. nat/tool/register.py +0 -1
  78. nat/utils/data_models/schema_validator.py +2 -2
  79. nat/utils/exception_handlers/automatic_retries.py +0 -2
  80. nat/utils/exception_handlers/schemas.py +1 -1
  81. nat/utils/reactive/base/observable_base.py +2 -2
  82. nat/utils/reactive/base/observer_base.py +1 -1
  83. nat/utils/reactive/observable.py +2 -2
  84. nat/utils/reactive/observer.py +2 -2
  85. nat/utils/reactive/subscription.py +1 -1
  86. nat/utils/settings/global_settings.py +4 -6
  87. nat/utils/type_utils.py +4 -4
  88. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/METADATA +1 -1
  89. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/RECORD +94 -91
  90. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/WHEEL +0 -0
  91. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/entry_points.txt +0 -0
  92. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  93. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/licenses/LICENSE.md +0 -0
  94. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/top_level.txt +0 -0
aiq/__init__.py CHANGED
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import importlib
 import importlib.abc
 import importlib.util
+import sys
 import warnings
 
 
@@ -26,7 +26,7 @@ class CompatFinder(importlib.abc.MetaPathFinder):
         self.alias_prefix = alias_prefix
         self.target_prefix = target_prefix
 
-    def find_spec(self, fullname, path, target=None):  # pylint: disable=unused-argument
+    def find_spec(self, fullname, path, target=None):
         if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
             # Map aiq.something -> nat.something
             target_name = self.target_prefix + fullname[len(self.alias_prefix):]
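
Note: the aiq package is only a compatibility shim; CompatFinder is a meta-path finder that redirects aiq.* imports to the corresponding nat.* modules. Only the first lines of find_spec appear in this hunk, so the short Python sketch below illustrates the general prefix-aliasing technique under stated assumptions (the AliasLoader approach and names are illustrative, not the packaged implementation):

import importlib
import importlib.abc
import importlib.util
import sys


class AliasLoader(importlib.abc.Loader):
    """Load an aliased module by reusing the already-initialized target module."""

    def __init__(self, target_name: str):
        self.target_name = target_name

    def create_module(self, spec):
        # Return the real module object so the alias shares its state.
        return importlib.import_module(self.target_name)

    def exec_module(self, module):
        # The target module is already executed; nothing further to do.
        pass


class AliasFinder(importlib.abc.MetaPathFinder):
    """Resolve imports of alias_prefix.* against target_prefix.*."""

    def __init__(self, alias_prefix: str, target_prefix: str):
        self.alias_prefix = alias_prefix
        self.target_prefix = target_prefix

    def find_spec(self, fullname, path, target=None):
        if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
            # Map aiq.something -> nat.something, mirroring CompatFinder above.
            target_name = self.target_prefix + fullname[len(self.alias_prefix):]
            return importlib.util.spec_from_loader(fullname, AliasLoader(target_name))
        return None


# Installing the finder makes "import aiq.runtime" resolve to the "nat.runtime" module object.
sys.meta_path.insert(0, AliasFinder("aiq", "nat"))

Reusing the already-imported target module keeps both import paths pointing at the same module object, so state is shared between the old and new package names.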
nat/agent/base.py CHANGED
@@ -179,7 +179,6 @@ class BaseAgent(ABC):
                 logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
                 await asyncio.sleep(sleep_time)
 
-        # pylint: disable=C0209
         # All retries exhausted, return error message
         error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
         logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
nat/agent/react_agent/agent.py CHANGED
@@ -14,8 +14,8 @@
 # limitations under the License.
 
 import json
-# pylint: disable=R0917
 import logging
+import re
 import typing
 from json import JSONDecodeError
 
@@ -23,12 +23,14 @@ from langchain_core.agents import AgentAction
 from langchain_core.agents import AgentFinish
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.messages.human import HumanMessage
 from langchain_core.messages.tool import ToolMessage
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.prompts import MessagesPlaceholder
+from langchain_core.runnables import Runnable
 from langchain_core.runnables.config import RunnableConfig
 from langchain_core.tools import BaseTool
 from pydantic import BaseModel
@@ -97,11 +99,27 @@ class ReActAgentGraph(DualNodeAgent):
                                f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
         prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
         # construct the ReAct Agent
-        bound_llm = llm.bind(stop=["Observation:"])  # type: ignore
-        self.agent = prompt | bound_llm
+        self.agent = prompt | self._maybe_bind_llm_and_yield()
         self.tools_dict = {tool.name: tool for tool in tools}
         logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
 
+    def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
+        """
+        Bind additional parameters to the LLM if needed
+        - if the LLM is a smart model, no need to bind any additional parameters
+        - if the LLM is a non-smart model, bind a stop sequence to the LLM
+
+        Returns:
+            Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
+        """
+        # models that don't need (or don't support)a stop sequence
+        smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
+        if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
+            # no need to bind any additional parameters to the LLM
+            return self.llm
+        # add a stop sequence to the LLM
+        return self.llm.bind(stop=["Observation:"])
+
     def _get_tool(self, tool_name: str):
        try:
            return self.tools_dict.get(tool_name)
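
Note: the new _maybe_bind_llm_and_yield helper skips the stop=["Observation:"] binding when the configured model looks like a GPT-5 family model. A quick check of the regex it uses (copied verbatim from the hunk above; the model names below are just illustrative inputs) shows which names it treats as "smart":

import re

smart_models = re.compile(r"gpt-?5", re.IGNORECASE)

for name in ["gpt-5", "GPT5-mini", "gpt-4o", "meta/llama-3.1-70b-instruct"]:
    bound = "returned as-is" if smart_models.search(name) else 'bound with stop=["Observation:"]'
    print(f"{name}: {bound}")

# gpt-5 and GPT5-mini match, so no stop sequence is bound; the other two keep the
# "Observation:" stop sequence, matching the previous ReAct agent behavior.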
nat/agent/react_agent/register.py CHANGED
@@ -125,7 +125,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
 
             # get and return the output from the state
             state = ReActGraphState(**state)
-            output_message = state.messages[-1]  # pylint: disable=E1136
+            output_message = state.messages[-1]
             return ChatResponse.from_string(str(output_message.content))
 
         except Exception as ex:
nat/agent/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 
 # Import any workflows which need to be automatically registered here
nat/agent/rewoo_agent/agent.py CHANGED
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import json
-# pylint: disable=R0917
 import logging
 from json import JSONDecodeError
 
nat/agent/rewoo_agent/register.py CHANGED
@@ -133,7 +133,7 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
 
             # get and return the output from the state
             state = ReWOOGraphState(**state)
-            output_message = state.result.content  # pylint: disable=E1101
+            output_message = state.result.content
             return ChatResponse.from_string(output_message)
 
         except Exception as ex:
nat/agent/tool_calling_agent/agent.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=R0917
 import logging
 import typing
 
nat/agent/tool_calling_agent/register.py CHANGED
@@ -86,7 +86,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
 
             # get and return the output from the state
             state = ToolCallAgentGraphState(**state)
-            output_message = state.messages[-1]  # pylint: disable=E1136
+            output_message = state.messages[-1]
             return output_message.content
         except Exception as ex:
             logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
nat/authentication/api_key/api_key_auth_provider.py CHANGED
@@ -31,7 +31,7 @@ class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
     # fmt: off
     def __init__(self,
                  config: APIKeyAuthProviderConfig,
-                 config_name: str | None = None) -> None:  # pylint: disable=unused-argument
+                 config_name: str | None = None) -> None:
         assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyAuthProviderConfig")
         super().__init__(config)
     # fmt: on
nat/authentication/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 
 from nat.authentication.api_key import register as register_api_key
nat/builder/builder.py CHANGED
@@ -58,7 +58,7 @@ class UserManagerHolder():
         return self._context.user_manager.get_id()
 
 
-class Builder(ABC):  # pylint: disable=too-many-public-methods
+class Builder(ABC):
 
     @abstractmethod
     async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
nat/builder/context.py CHANGED
@@ -38,7 +38,7 @@ from nat.utils.reactive.subject import Subject
 
 class Singleton(type):
 
-    def __init__(cls, name, bases, dict):  # pylint: disable=W0622
+    def __init__(cls, name, bases, dict):
         super(Singleton, cls).__init__(name, bases, dict)
         cls.instance = None
 
@@ -65,6 +65,7 @@ class ContextState(metaclass=Singleton):
 
     def __init__(self):
         self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
+        self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
         self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
         self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
         self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
@@ -165,6 +166,13 @@ class Context:
         """
         return self._context_state.conversation_id.get()
 
+    @property
+    def user_message_id(self) -> str | None:
+        """
+        This property retrieves the user message ID which is the unique identifier for the current user message.
+        """
+        return self._context_state.user_message_id.get()
+
     @contextmanager
     def push_active_function(self, function_name: str, input_data: typing.Any | None):
         """
nat/builder/function_base.py CHANGED
@@ -111,7 +111,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
         ValueError
             If the input type cannot be determined from the class definition
         """
-        for base_cls in self.__class__.__orig_bases__:  # pylint: disable=no-member # type: ignore
+        for base_cls in self.__class__.__orig_bases__:
 
             base_cls_args = typing.get_args(base_cls)
 
@@ -196,7 +196,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
         ValueError
             If the streaming output type cannot be determined from the class definition
         """
-        for base_cls in self.__class__.__orig_bases__:  # pylint: disable=no-member # type: ignore
+        for base_cls in self.__class__.__orig_bases__:
 
             base_cls_args = typing.get_args(base_cls)
 
@@ -269,7 +269,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
         ValueError
             If the single output type cannot be determined from the class definition
         """
-        for base_cls in self.__class__.__orig_bases__:  # pylint: disable=no-member # type: ignore
+        for base_cls in self.__class__.__orig_bases__:
 
             base_cls_args = typing.get_args(base_cls)
 
nat/builder/function_info.py CHANGED
@@ -231,7 +231,7 @@ class FunctionDescriptor:
         else:
             annotations = [param.annotation for param in sig.parameters.values()]
 
-        is_input_typed = all([a != sig.empty for a in annotations])  # pylint: disable=use-a-generator
+        is_input_typed = all([a != sig.empty for a in annotations])
 
         input_type = tuple[*annotations] if is_input_typed else None  # noqa: syntax-error
 
@@ -372,18 +372,16 @@ class FunctionInfo:
 
             if (stream_to_single_fn is not None):
                 raise ValueError("Cannot provide both single_fn and stream_to_single_fn")
-        else:
-            if (stream_to_single_fn is not None and stream_fn is None):
-                raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")
+        elif (stream_to_single_fn is not None and stream_fn is None):
+            raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")
 
         if (stream_fn is not None):
             final_stream_fn = stream_fn
 
             if (single_to_stream_fn is not None):
                 raise ValueError("Cannot provide both stream_fn and single_to_stream_fn")
-        else:
-            if (single_to_stream_fn is not None and single_fn is None):
-                raise ValueError("single_fn must be provided if single_to_stream_fn is provided")
+        elif (single_to_stream_fn is not None and single_fn is None):
+            raise ValueError("single_fn must be provided if single_to_stream_fn is provided")
 
         if (single_fn is None and stream_fn is None):
             raise ValueError("At least one of single_fn or stream_fn must be provided")
nat/builder/workflow_builder.py CHANGED
@@ -127,7 +127,6 @@ class ConfiguredTTCStrategy:
     instance: StrategyBase
 
 
-# pylint: disable=too-many-public-methods
 class WorkflowBuilder(Builder, AbstractAsyncContextManager):
 
     def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
nat/cli/commands/evaluate.py CHANGED
@@ -97,7 +97,7 @@ async def run_and_evaluate(config: EvaluationRunConfig):
 
 @eval_command.result_callback(replace=True)
 def process_nat_eval(
-        processors,  # pylint: disable=unused-argument
+        processors,
         *,
         config_file: Path,
         dataset: Path,
nat/cli/commands/info/list_components.py CHANGED
@@ -26,14 +26,13 @@ from nat.registry_handlers.schemas.search import SearchFields
 logger = logging.getLogger(__name__)
 
 
-async def search_artifacts(  # pylint: disable=R0917
-        registry_handler_config: RegistryHandlerBaseConfig,
-        component_types: list[ComponentEnum],
-        visualize: bool,
-        query: str,
-        num_results: int,
-        query_fields: list[SearchFields],
-        save_path: str | None) -> None:
+async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
+                           component_types: list[ComponentEnum],
+                           visualize: bool,
+                           query: str,
+                           num_results: int,
+                           query_fields: list[SearchFields],
+                           save_path: str | None) -> None:
 
     from nat.cli.type_registry import GlobalTypeRegistry
     from nat.registry_handlers.schemas.search import SearchQuery
nat/cli/commands/info/list_mcp.py CHANGED
@@ -297,8 +297,7 @@ def ping(url: str, timeout: int, json_output: bool) -> None:
 
     if json_output:
         click.echo(result.model_dump_json(indent=2))
+    elif result.status == "healthy":
+        click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
     else:
-        if result.status == "healthy":
-            click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
-        else:
-            click.echo(f"Server at {result.url} {result.status}: {result.error}")
+        click.echo(f"Server at {result.url} {result.status}: {result.error}")
nat/cli/commands/registry/search.py CHANGED
@@ -29,14 +29,13 @@ from nat.utils.data_models.schema_validator import validate_yaml
 logger = logging.getLogger(__name__)
 
 
-async def search_artifacts(  # pylint: disable=R0917
-        registry_handler_config: RegistryHandlerBaseConfig,
-        query: str,
-        search_fields: list[SearchFields],
-        visualize: bool,
-        component_types: list[ComponentEnum],
-        save_path: str | None = None,
-        n_results: int = 10) -> None:
+async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
+                           query: str,
+                           search_fields: list[SearchFields],
+                           visualize: bool,
+                           component_types: list[ComponentEnum],
+                           save_path: str | None = None,
+                           n_results: int = 10) -> None:
 
     from nat.cli.type_registry import GlobalTypeRegistry
     from nat.registry_handlers.schemas.search import SearchQuery
@@ -116,14 +115,13 @@ async def search_artifacts(  # pylint: disable=R0917
     required=False,
     help=("The component types to include in search."),
 )
-def search(  # pylint: disable=R0917
-        config_file: str,
-        channel: str,
-        fields: list[str],
-        query: str,
-        component_types: list[ComponentEnum],
-        n_results: int,
-        output_path: str) -> None:
+def search(config_file: str,
+           channel: str,
+           fields: list[str],
+           query: str,
+           component_types: list[ComponentEnum],
+           n_results: int,
+           output_path: str) -> None:
     """
     Search for NAT artifacts with the specified configuration.
     """
nat/cli/commands/start.py CHANGED
@@ -35,7 +35,6 @@ logger = logging.getLogger(__name__)
 
 class StartCommandGroup(click.Group):
 
-    # pylint: disable=too-many-positional-arguments
     def __init__(
         self,
         name: str | None = None,
nat/cli/commands/workflow/templates/register.py.j2 CHANGED
@@ -1,4 +1,3 @@
-# pylint: disable=unused-import
 # flake8: noqa
 
 # Import any tools which need to be automatically registered here
nat/cli/commands/workflow/workflow_commands.py CHANGED
@@ -161,7 +161,6 @@ def get_workflow_path_from_name(workflow_name: str):
     default="NAT function template. Please update the description.",
     help="""A description of the component being created. Will be used to populate the docstring and will describe the
     component when inspecting installed components using 'nat info component'""")
-# pylint: disable=missing-param-doc
 def create_command(workflow_name: str, install: bool, workflow_dir: str, description: str):
     """
     Create a new NAT workflow using templates.
nat/cli/type_registry.py CHANGED
@@ -298,7 +298,7 @@ class RegisteredPackage(BaseModel):
     discovery_metadata: DiscoveryMetadata
 
 
-class TypeRegistry:  # pylint: disable=too-many-public-methods
+class TypeRegistry:
 
     def __init__(self) -> None:
         # Telemetry Exporters
@@ -779,7 +779,7 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods
 
         self._registration_changed()
 
-    def get_infos_by_type(self, component_type: ComponentEnum) -> dict:  # pylint: disable=R0911
+    def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
 
         if component_type == ComponentEnum.FRONT_END:
             return self._registered_front_end_infos
@@ -849,8 +849,7 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods
 
         raise ValueError(f"Supplied an unsupported component type {component_type}")
 
-    def get_registered_types_by_component_type(  # pylint: disable=R0911
-            self, component_type: ComponentEnum) -> list[str]:
+    def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
 
         if component_type == ComponentEnum.FUNCTION:
             return [i.static_type() for i in self._registered_functions]
@@ -925,7 +924,6 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods
             if (short_names[key.local_name] == 1):
                 type_list.append((key.local_name, key.config_type))
 
-        # pylint: disable=consider-alternative-union-syntax
         return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
 
     def compute_annotation(self, cls: type[TypedBaseModelT]):
nat/data_models/config.py CHANGED
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
 
 
 def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
-    from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import
+    from nat.cli.type_registry import GlobalTypeRegistry
 
     new_errors = []
     logged_once = False
@@ -108,7 +108,7 @@ class EvalConfig(BaseModel):
     @classmethod
     def rebuild_annotations(cls):
 
-        from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import
+        from nat.cli.type_registry import GlobalTypeRegistry
 
         type_registry = GlobalTypeRegistry.get()
 
nat/data_models/function_dependencies.py CHANGED
@@ -54,19 +54,19 @@ class FunctionDependencies(BaseModel):
         return list(v)
 
     def add_function(self, function: str):
-        self.functions.add(function)  # pylint: disable=no-member
+        self.functions.add(function)
 
     def add_llm(self, llm: str):
-        self.llms.add(llm)  # pylint: disable=no-member
+        self.llms.add(llm)
 
     def add_embedder(self, embedder: str):
-        self.embedders.add(embedder)  # pylint: disable=no-member
+        self.embedders.add(embedder)
 
     def add_memory_client(self, memory_client: str):
-        self.memory_clients.add(memory_client)  # pylint: disable=no-member
+        self.memory_clients.add(memory_client)
 
     def add_object_store(self, object_store: str):
-        self.object_stores.add(object_store)  # pylint: disable=no-member
+        self.object_stores.add(object_store)
 
     def add_retriever(self, retriever: str):
-        self.retrievers.add(retriever)  # pylint: disable=no-member
+        self.retrievers.add(retriever)
nat/data_models/intermediate_step.py CHANGED
@@ -142,7 +142,7 @@ class IntermediateStepPayload(BaseModel):
     UUID: str = Field(default_factory=lambda: str(uuid.uuid4()))
 
     @property
-    def event_category(self) -> IntermediateStepCategory:  # pylint: disable=too-many-return-statements
+    def event_category(self) -> IntermediateStepCategory:
         match self.event_type:
             case IntermediateStepType.LLM_START:
                 return IntermediateStepCategory.LLM
@@ -180,7 +180,7 @@ class IntermediateStepPayload(BaseModel):
         raise ValueError(f"Unknown event type: {self.event_type}")
 
     @property
-    def event_state(self) -> IntermediateStepState:  # pylint: disable=too-many-return-statements
+    def event_state(self) -> IntermediateStepState:
         match self.event_type:
             case IntermediateStepType.LLM_START:
                 return IntermediateStepState.START
@@ -290,7 +290,7 @@ class IntermediateStep(BaseModel):
         return self.payload.usage_info
 
     @property
-    def UUID(self) -> str:  # pylint: disable=invalid-name
+    def UUID(self) -> str:
         return self.payload.UUID
 
     @property
nat/data_models/model_gated_field_mixin.py ADDED
@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+from re import Pattern
+from typing import Generic
+from typing import TypeVar
+
+from pydantic import BaseModel
+from pydantic import model_validator
+
+T = TypeVar("T")
+
+
+class ModelGatedFieldMixin(Generic[T]):
+    """
+    A mixin that gates a field based on model support.
+
+    This should be used to automatically validate a field based on a given model.
+
+    Parameters
+    ----------
+    field_name: `str`
+        The name of the field.
+    default_if_supported: `T`
+        The default value of the field if it is supported for the model.
+    unsupported_models: `Sequence[Pattern[str]] | None`
+        A sequence of regex patterns that match the model names NOT supported for the field.
+        Defaults to None.
+    supported_models: `Sequence[Pattern[str]] | None`
+        A sequence of regex patterns that match the model names supported for the field.
+        Defaults to None.
+    model_keys: `Sequence[str]`
+        A sequence of keys that are used to validate the field.
+        Defaults to ("model_name", "model", "azure_deployment",)
+    """
+
+    def __init_subclass__(
+        cls,
+        field_name: str | None = None,
+        default_if_supported: T | None = None,
+        unsupported_models: Sequence[Pattern[str]] | None = None,
+        supported_models: Sequence[Pattern[str]] | None = None,
+        model_keys: Sequence[str] = ("model_name", "model", "azure_deployment"),
+    ) -> None:
+        """
+        Store the class variables for the field and define the model validator.
+        """
+        super().__init_subclass__()
+        if ModelGatedFieldMixin in cls.__bases__:
+            if field_name is None:
+                raise ValueError("field_name must be provided when subclassing ModelGatedFieldMixin")
+            if default_if_supported is None:
+                raise ValueError("default_if_supported must be provided when subclassing ModelGatedFieldMixin")
+            if unsupported_models is None and supported_models is None:
+                raise ValueError("Either unsupported_models or supported_models must be provided")
+            if unsupported_models is not None and supported_models is not None:
+                raise ValueError("Only one of unsupported_models or supported_models must be provided")
+            if model_keys is not None and len(model_keys) == 0:
+                raise ValueError("model_keys must be provided and non-empty when subclassing ModelGatedFieldMixin")
+            cls.field_name = field_name
+            cls.default_if_supported = default_if_supported
+            cls.unsupported_models = unsupported_models
+            cls.supported_models = supported_models
+            if model_keys is not None:
+                cls.model_keys = model_keys
+
+        @classmethod
+        def check_model(cls, model_name: str) -> bool:
+            """
+            Check if a model is supported for a given field.
+
+            Args:
+                model_name: The name of the model to check.
+            """
+            unsupported = getattr(cls, "unsupported_models", None)
+            supported = getattr(cls, "supported_models", None)
+            if unsupported is not None:
+                return not any(p.search(model_name) for p in unsupported)
+            if supported is not None:
+                return any(p.search(model_name) for p in supported)
+            return False
+
+        cls._model_gated_field_check_model = check_model
+
+        @classmethod
+        def detect_support(cls, instance: BaseModel) -> str | None:
+            for key in getattr(cls, "model_keys"):
+                if hasattr(instance, key):
+                    model_name_value = getattr(instance, key)
+                    is_supported = getattr(cls, "_model_gated_field_check_model")(str(model_name_value))
+                    return key if not is_supported else None
+            return None
+
+        cls._model_gated_field_detect_support = detect_support
+
+        @model_validator(mode="after")
+        def model_validate(self):
+            klass = self.__class__
+
+            field_name_local = getattr(klass, "field_name")
+            current_value = getattr(self, field_name_local, None)
+
+            found_key = klass._model_gated_field_detect_support(self)
+            if found_key is not None:
+                if current_value is not None:
+                    raise ValueError(
+                        f"{field_name_local} is not supported for {found_key}: {getattr(self, found_key)}")
+            elif current_value is None:
+                setattr(self, field_name_local, getattr(klass, "default_if_supported", None))
+            return self
+
+        cls._model_gated_field_model_validator = model_validate
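
Note: the file list above also adds nat/data_models/temperature_mixin.py and nat/data_models/top_p_mixin.py (their contents are not shown in this diff); presumably they subclass ModelGatedFieldMixin with the parameters documented in the docstring to gate temperature and top_p. The gating rule itself, as implemented in check_model, can be illustrated standalone in the short Python sketch below (the helper function, patterns, and model names are illustrative only, not part of the package):

import re

def is_field_supported(model_name, unsupported_models=None, supported_models=None):
    # With unsupported_models, the field is allowed unless a pattern matches;
    # with supported_models, it is allowed only when a pattern matches.
    if unsupported_models is not None:
        return not any(p.search(model_name) for p in unsupported_models)
    if supported_models is not None:
        return any(p.search(model_name) for p in supported_models)
    return False

no_temperature = (re.compile(r"gpt-?5", re.IGNORECASE),)
print(is_field_supported("gpt-5-mini", unsupported_models=no_temperature))         # False
print(is_field_supported("meta/llama-3.1-8b", unsupported_models=no_temperature))  # True

When the field is unsupported for the detected model key, an explicitly set value raises a ValueError; when it is supported and unset, the validator fills in default_if_supported.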
nat/data_models/swe_bench_model.py CHANGED
@@ -39,7 +39,7 @@ class SWEBenchInput(BaseModel):
 
     # Handle improperly formatted JSON strings for list fields
     @field_validator("FAIL_TO_PASS", "PASS_TO_PASS", mode="before")
-    def parse_list_fields(cls, value):  # pylint: disable=no-self-argument
+    def parse_list_fields(cls, value):
         if isinstance(value, str):
             # Attempt to parse the string as a list
             return json.loads(value)