nvidia-nat 1.3.0a20250823__py3-none-any.whl → 1.3.0a20250826__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -57,6 +57,9 @@ class EvalOutputConfig(BaseModel):
57
57
  dir: Path = Path("./.tmp/nat/examples/default/")
58
58
  # S3 prefix for the workflow and evaluation results
59
59
  remote_dir: str | None = None
60
+ # Custom function to pre-evaluation process the eval input
61
+ # Format: "module.path.function_name"
62
+ custom_pre_eval_process_function: str | None = None
60
63
  # Custom scripts to run after the workflow and evaluation results are saved
61
64
  custom_scripts: dict[str, EvalCustomScriptConfig] = {}
62
65
  # S3 config for uploading the contents of the output directory
@@ -0,0 +1,244 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import Sequence
17
+ from dataclasses import dataclass
18
+ from re import Pattern
19
+ from typing import Generic
20
+ from typing import TypeVar
21
+
22
+ from pydantic import model_validator
23
+
24
+
25
+ @dataclass
26
+ class GatedFieldMixinConfig:
27
+ """Configuration for a gated field mixin."""
28
+
29
+ field_name: str
30
+ default_if_supported: object | None
31
+ unsupported: Sequence[Pattern[str]] | None
32
+ supported: Sequence[Pattern[str]] | None
33
+ keys: Sequence[str]
34
+
35
+
36
+ T = TypeVar("T")
37
+
38
+
39
+ class GatedFieldMixin(Generic[T]):
40
+ """
41
+ A mixin that gates a field based on specified keys.
42
+
43
+ This should be used to automatically validate a field based on a given key.
44
+
45
+ Parameters
46
+ ----------
47
+ field_name: `str`
48
+ The name of the field.
49
+ default_if_supported: `T | None`
50
+ The default value of the field if it is supported for the key.
51
+ keys: `Sequence[str]`
52
+ A sequence of keys that are used to validate the field.
53
+ unsupported: `Sequence[Pattern[str]] | None`
54
+ A sequence of regex patterns that match the key names NOT supported for the field.
55
+ Defaults to None.
56
+ supported: `Sequence[Pattern[str]] | None`
57
+ A sequence of regex patterns that match the key names supported for the field.
58
+ Defaults to None.
59
+ """
60
+
61
+ def __init_subclass__(
62
+ cls,
63
+ field_name: str | None = None,
64
+ default_if_supported: T | None = None,
65
+ keys: Sequence[str] | None = None,
66
+ unsupported: Sequence[Pattern[str]] | None = None,
67
+ supported: Sequence[Pattern[str]] | None = None,
68
+ ) -> None:
69
+ """Store the class variables for the field and define the gated field validator."""
70
+ super().__init_subclass__()
71
+
72
+ # Check if this class directly inherits from GatedFieldMixin
73
+ has_gated_field_mixin = GatedFieldMixin in cls.__bases__
74
+
75
+ if has_gated_field_mixin:
76
+ if keys is None:
77
+ raise ValueError("keys must be provided when subclassing GatedFieldMixin")
78
+ if field_name is None:
79
+ raise ValueError("field_name must be provided when subclassing GatedFieldMixin")
80
+
81
+ cls._setup_direct_mixin(field_name, default_if_supported, unsupported, supported, keys)
82
+
83
+ # Always try to collect mixins and create validators for multiple inheritance
84
+ # This handles both direct inheritance and deep inheritance chains
85
+ all_mixins = cls._collect_all_mixin_configs()
86
+ if all_mixins:
87
+ cls._create_combined_validator(all_mixins)
88
+
89
+ @classmethod
90
+ def _setup_direct_mixin(
91
+ cls,
92
+ field_name: str,
93
+ default_if_supported: T | None,
94
+ unsupported: Sequence[Pattern[str]] | None,
95
+ supported: Sequence[Pattern[str]] | None,
96
+ keys: Sequence[str],
97
+ ) -> None:
98
+ """Set up a class that directly inherits from GatedFieldMixin."""
99
+ cls._validate_mixin_parameters(unsupported, supported, keys)
100
+
101
+ # Create and store validator
102
+ validator = cls._create_gated_field_validator(field_name, default_if_supported, unsupported, supported, keys)
103
+ validator_name = f"_gated_field_validator_{field_name}"
104
+ setattr(cls, validator_name, validator)
105
+
106
+ # Store mixin info for multiple inheritance
107
+ if not hasattr(cls, "_gated_field_mixins"):
108
+ cls._gated_field_mixins = []
109
+
110
+ cls._gated_field_mixins.append(
111
+ GatedFieldMixinConfig(
112
+ field_name,
113
+ default_if_supported,
114
+ unsupported,
115
+ supported,
116
+ keys,
117
+ ))
118
+
119
+ @classmethod
120
+ def _validate_mixin_parameters(
121
+ cls,
122
+ unsupported: Sequence[Pattern[str]] | None,
123
+ supported: Sequence[Pattern[str]] | None,
124
+ keys: Sequence[str],
125
+ ) -> None:
126
+ """Validate that all required parameters are provided."""
127
+ if unsupported is None and supported is None:
128
+ raise ValueError("Either unsupported or supported must be provided")
129
+ if unsupported is not None and supported is not None:
130
+ raise ValueError("Only one of unsupported or supported must be provided")
131
+ if len(keys) == 0:
132
+ raise ValueError("keys must be provided and non-empty when subclassing GatedFieldMixin")
133
+
134
+ @classmethod
135
+ def _create_gated_field_validator(
136
+ cls,
137
+ field_name: str,
138
+ default_if_supported: T | None,
139
+ unsupported: Sequence[Pattern[str]] | None,
140
+ supported: Sequence[Pattern[str]] | None,
141
+ keys: Sequence[str],
142
+ ):
143
+ """Create the model validator function."""
144
+
145
+ @model_validator(mode="after")
146
+ def gated_field_validator(self):
147
+ """Validate the gated field."""
148
+ current_value = getattr(self, field_name, None)
149
+ is_supported = cls._check_field_support(self, unsupported, supported, keys)
150
+ if not is_supported:
151
+ if current_value is not None:
152
+ blocking_key = cls._find_blocking_key(self, unsupported, supported, keys)
153
+ value = getattr(self, blocking_key, "<unknown>")
154
+ raise ValueError(f"{field_name} is not supported for {blocking_key}: {value}")
155
+ elif current_value is None:
156
+ setattr(self, field_name, default_if_supported)
157
+ return self
158
+
159
+ return gated_field_validator
160
+
161
+ @classmethod
162
+ def _check_field_support(
163
+ cls,
164
+ instance: object,
165
+ unsupported: Sequence[Pattern[str]] | None,
166
+ supported: Sequence[Pattern[str]] | None,
167
+ keys: Sequence[str],
168
+ ) -> bool:
169
+ """Check if a specific field is supported based on its configuration and keys."""
170
+ for key in keys:
171
+ if not hasattr(instance, key):
172
+ continue
173
+ value = str(getattr(instance, key))
174
+ if supported is not None:
175
+ return any(p.search(value) for p in supported)
176
+ elif unsupported is not None:
177
+ return not any(p.search(value) for p in unsupported)
178
+ # Default to supported if no model keys found
179
+ return True
180
+
181
+ @classmethod
182
+ def _find_blocking_key(
183
+ cls,
184
+ instance: object,
185
+ unsupported: Sequence[Pattern[str]] | None,
186
+ supported: Sequence[Pattern[str]] | None,
187
+ keys: Sequence[str],
188
+ ) -> str:
189
+ """Find which key is blocking the field."""
190
+ for key in keys:
191
+ if not hasattr(instance, key):
192
+ continue
193
+ value = str(getattr(instance, key))
194
+ if supported is not None:
195
+ if not any(p.search(value) for p in supported):
196
+ return key
197
+ elif unsupported is not None:
198
+ if any(p.search(value) for p in unsupported):
199
+ return key
200
+
201
+ return "<unknown>"
202
+
203
+ @classmethod
204
+ def _collect_all_mixin_configs(cls) -> list[GatedFieldMixinConfig]:
205
+ """Collect all mixin configurations from base classes."""
206
+ all_mixins = []
207
+ for base in cls.__bases__:
208
+ if hasattr(base, "_gated_field_mixins"):
209
+ all_mixins.extend(base._gated_field_mixins)
210
+ return all_mixins
211
+
212
+ @classmethod
213
+ def _create_combined_validator(cls, all_mixins: list[GatedFieldMixinConfig]) -> None:
214
+ """Create a combined validator that handles all fields."""
215
+
216
+ @model_validator(mode="after")
217
+ def combined_gated_field_validator(self):
218
+ """Validate all gated fields."""
219
+ for mixin_config in all_mixins:
220
+ field_name_local = mixin_config.field_name
221
+ current_value = getattr(self, field_name_local, None)
222
+ if not self._check_field_support_instance(mixin_config):
223
+ if current_value is not None:
224
+ blocking_key = self._find_blocking_key_instance(mixin_config)
225
+ value = getattr(self, blocking_key, "<unknown>")
226
+ raise ValueError(f"{field_name_local} is not supported for {blocking_key}: {value}")
227
+ elif current_value is None:
228
+ setattr(self, field_name_local, mixin_config.default_if_supported)
229
+
230
+ return self
231
+
232
+ cls._combined_gated_field_validator = combined_gated_field_validator
233
+
234
+ # Add helper methods
235
+ def _check_field_support_instance(self, mixin_config: GatedFieldMixinConfig) -> bool:
236
+ """Check if a specific field is supported based on its configuration and keys."""
237
+ return cls._check_field_support(self, mixin_config.unsupported, mixin_config.supported, mixin_config.keys)
238
+
239
+ def _find_blocking_key_instance(self, mixin_config: GatedFieldMixinConfig) -> str:
240
+ """Find which key is blocking the field."""
241
+ return cls._find_blocking_key(self, mixin_config.unsupported, mixin_config.supported, mixin_config.keys)
242
+
243
+ cls._check_field_support_instance = _check_field_support_instance
244
+ cls._find_blocking_key_instance = _find_blocking_key_instance
@@ -18,19 +18,26 @@ import re
18
18
  from pydantic import BaseModel
19
19
  from pydantic import Field
20
20
 
21
- from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
22
-
23
- _UNSUPPORTED_TEMPERATURE_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
21
+ from nat.data_models.gated_field_mixin import GatedFieldMixin
24
22
 
25
23
 
26
24
  class TemperatureMixin(
27
25
  BaseModel,
28
- ModelGatedFieldMixin[float],
26
+ GatedFieldMixin[float],
29
27
  field_name="temperature",
30
28
  default_if_supported=0.0,
31
- unsupported_models=_UNSUPPORTED_TEMPERATURE_MODELS,
29
+ keys=("model_name", "model", "azure_deployment"),
30
+ unsupported=(re.compile(r"gpt-?5", re.IGNORECASE), ),
32
31
  ):
33
32
  """
34
- Mixin class for temperature configuration.
33
+ Mixin class for temperature configuration. Unsupported on models like gpt-5.
34
+
35
+ Attributes:
36
+ temperature: Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.
35
37
  """
36
- temperature: float | None = Field(default=None, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
38
+ temperature: float | None = Field(
39
+ default=None,
40
+ ge=0.0,
41
+ le=1.0,
42
+ description="Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.",
43
+ )
@@ -18,19 +18,26 @@ import re
18
18
  from pydantic import BaseModel
19
19
  from pydantic import Field
20
20
 
21
- from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
22
-
23
- _UNSUPPORTED_TOP_P_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
21
+ from nat.data_models.gated_field_mixin import GatedFieldMixin
24
22
 
25
23
 
26
24
  class TopPMixin(
27
25
  BaseModel,
28
- ModelGatedFieldMixin[float],
26
+ GatedFieldMixin[float],
29
27
  field_name="top_p",
30
28
  default_if_supported=1.0,
31
- unsupported_models=_UNSUPPORTED_TOP_P_MODELS,
29
+ keys=("model_name", "model", "azure_deployment"),
30
+ unsupported=(re.compile(r"gpt-?5", re.IGNORECASE), ),
32
31
  ):
33
32
  """
34
- Mixin class for top-p configuration.
33
+ Mixin class for top-p configuration. Unsupported on models like gpt-5.
34
+
35
+ Attributes:
36
+ top_p: Top-p for distribution sampling. Defaults to 1.0 when supported on the model.
35
37
  """
36
- top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Top-p for distribution sampling.")
38
+ top_p: float | None = Field(
39
+ default=None,
40
+ ge=0.0,
41
+ le=1.0,
42
+ description="Top-p for distribution sampling. Defaults to 1.0 when supported on the model.",
43
+ )
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import importlib
16
17
  import json
17
18
  import math
18
19
  from pathlib import Path
@@ -41,7 +42,8 @@ class DatasetHandler:
41
42
  reps: int,
42
43
  concurrency: int,
43
44
  num_passes: int = 1,
44
- adjust_dataset_size: bool = False):
45
+ adjust_dataset_size: bool = False,
46
+ custom_pre_eval_process_function: str | None = None):
45
47
  from nat.eval.intermediate_step_adapter import IntermediateStepAdapter
46
48
 
47
49
  self.dataset_config = dataset_config
@@ -53,6 +55,9 @@ class DatasetHandler:
53
55
  self.num_passes = num_passes
54
56
  self.adjust_dataset_size = adjust_dataset_size
55
57
 
58
+ # Custom pre-evaluation process function
59
+ self.custom_pre_eval_process_function = custom_pre_eval_process_function
60
+
56
61
  # Helpers
57
62
  self.intermediate_step_adapter = IntermediateStepAdapter()
58
63
 
@@ -330,6 +335,66 @@ class DatasetHandler:
330
335
  filtered_steps = self.intermediate_step_adapter.filter_intermediate_steps(intermediate_steps, event_filter)
331
336
  return self.intermediate_step_adapter.serialize_intermediate_steps(filtered_steps)
332
337
 
338
+ def pre_eval_process_eval_input(self, eval_input: EvalInput) -> EvalInput:
339
+ """
340
+ Pre-evaluation process the eval input using custom function if provided.
341
+
342
+ The custom pre-evaluation process function should have the signature:
343
+ def custom_pre_eval_process(item: EvalInputItem) -> EvalInputItem
344
+
345
+ The framework will iterate through all items and call this function on each one.
346
+
347
+ Args:
348
+ eval_input: The EvalInput object to pre-evaluation process
349
+
350
+ Returns:
351
+ The pre-evaluation processed EvalInput object
352
+ """
353
+ if self.custom_pre_eval_process_function:
354
+ try:
355
+ custom_function = self._load_custom_pre_eval_process_function()
356
+ processed_items = []
357
+
358
+ for item in eval_input.eval_input_items:
359
+ processed_item = custom_function(item)
360
+ if not isinstance(processed_item, EvalInputItem):
361
+ raise TypeError(f"Custom pre-evaluation '{self.custom_pre_eval_process_function}' must return "
362
+ f"EvalInputItem, got {type(processed_item)}")
363
+ processed_items.append(processed_item)
364
+
365
+ return EvalInput(eval_input_items=processed_items)
366
+ except Exception as e:
367
+ raise RuntimeError(f"Error calling custom pre-evaluation process function "
368
+ f"'{self.custom_pre_eval_process_function}': {e}") from e
369
+
370
+ return eval_input
371
+
372
+ def _load_custom_pre_eval_process_function(self):
373
+ """
374
+ Import and return the custom pre-evaluation process function using standard Python import path.
375
+
376
+ The function should process individual EvalInputItem objects.
377
+ """
378
+ # Split the function path to get module and function name
379
+ if "." not in self.custom_pre_eval_process_function:
380
+ raise ValueError(f"Invalid custom_pre_eval_process_function '{self.custom_pre_eval_process_function}'. "
381
+ "Expected format: '<module_path>.<function_name>'")
382
+ module_path, function_name = self.custom_pre_eval_process_function.rsplit(".", 1)
383
+
384
+ # Import the module
385
+ module = importlib.import_module(module_path)
386
+
387
+ # Get the function from the module
388
+ if not hasattr(module, function_name):
389
+ raise AttributeError(f"Function '{function_name}' not found in module '{module_path}'")
390
+
391
+ custom_function = getattr(module, function_name)
392
+
393
+ if not callable(custom_function):
394
+ raise ValueError(f"'{self.custom_pre_eval_process_function}' is not callable")
395
+
396
+ return custom_function
397
+
333
398
  def publish_eval_input(self,
334
399
  eval_input,
335
400
  workflow_output_step_filter: list[IntermediateStepType] | None = None) -> str:
nat/eval/evaluate.py CHANGED
@@ -486,11 +486,14 @@ class EvaluationRun:
486
486
  usage_stats=UsageStats(),
487
487
  profiler_results=ProfilerResults())
488
488
 
489
+ custom_pre_eval_process_function = self.eval_config.general.output.custom_pre_eval_process_function \
490
+ if self.eval_config.general.output else None
489
491
  dataset_handler = DatasetHandler(dataset_config=dataset_config,
490
492
  reps=self.config.reps,
491
493
  concurrency=self.eval_config.general.max_concurrency,
492
494
  num_passes=self.config.num_passes,
493
- adjust_dataset_size=self.config.adjust_dataset_size)
495
+ adjust_dataset_size=self.config.adjust_dataset_size,
496
+ custom_pre_eval_process_function=custom_pre_eval_process_function)
494
497
  self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
495
498
  if not self.eval_input.eval_input_items:
496
499
  logger.info("Dataset is empty. Nothing to evaluate.")
@@ -507,8 +510,8 @@ class EvaluationRun:
507
510
  # Initialize Weave integration
508
511
  self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
509
512
 
510
- # Run workflow
511
513
  with self.eval_trace_context.evaluation_context():
514
+ # Run workflow
512
515
  if self.config.endpoint:
513
516
  await self.run_workflow_remote()
514
517
  elif not self.config.skip_workflow:
@@ -517,6 +520,9 @@ class EvaluationRun:
517
520
  max_concurrency=self.eval_config.general.max_concurrency)
518
521
  await self.run_workflow_local(session_manager)
519
522
 
523
+ # Pre-evaluation process the workflow output
524
+ self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)
525
+
520
526
  # Evaluate
521
527
  evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
522
528
  await self.run_evaluators(evaluators)
@@ -29,6 +29,19 @@ class EvalInputItem(BaseModel):
29
29
  trajectory: list[IntermediateStep] = [] # populated by the workflow
30
30
  full_dataset_entry: typing.Any
31
31
 
32
+ def copy_with_updates(self, **updates) -> "EvalInputItem":
33
+ """
34
+ Copy EvalInputItem with optional field updates.
35
+ """
36
+ # Get all current fields
37
+ item_data = self.model_dump()
38
+
39
+ # Apply any updates
40
+ item_data.update(updates)
41
+
42
+ # Create new item with all fields
43
+ return EvalInputItem(**item_data)
44
+
32
45
 
33
46
  class EvalInput(BaseModel):
34
47
  eval_input_items: list[EvalInputItem]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat
3
- Version: 1.3.0a20250823
3
+ Version: 1.3.0a20250826
4
4
  Summary: NVIDIA NeMo Agent toolkit
5
5
  Author: NVIDIA Corporation
6
6
  Maintainer: NVIDIA Corporation
@@ -103,18 +103,18 @@ nat/data_models/config.py,sha256=VBin3qeWxX8DUkw5lJ6RD7dStR7qPp52XIsy0cJkUP4,171
103
103
  nat/data_models/dataset_handler.py,sha256=UDQLpStUwLEMm0e6BvuOFEsHYSopWu5DhinxEv2U6og,5523
104
104
  nat/data_models/discovery_metadata.py,sha256=cIAheyVuWksB5Tnv8yB65k4kPaCSUXrbcO75O7puUTc,13492
105
105
  nat/data_models/embedder.py,sha256=nPhthEQDtzAMGd8gFRB1ZfJpN5M9DJvv0h28ohHnTmI,1002
106
- nat/data_models/evaluate.py,sha256=4d9tbLXdSqQ8YUIXB3xbZ0nYnGUC1n2ioFbR_RVGQoI,4675
106
+ nat/data_models/evaluate.py,sha256=L0GdNh_c8jii-MiK8oHW9sUUsGO3l1FMsprr-UazT5c,4836
107
107
  nat/data_models/evaluator.py,sha256=bd2njsyQB2t6ClJ66gJiCjYHsQpWZwPD7rsU0J109TI,939
108
108
  nat/data_models/front_end.py,sha256=z8k6lSWjt1vMOYFbjWQxodpwAqPeuGS0hRBjsriDW2s,932
109
109
  nat/data_models/function.py,sha256=M_duXVXL5MvYe0WVLvqEgEzXs0UAYNSMfy9ZTpxuKPA,1013
110
110
  nat/data_models/function_dependencies.py,sha256=pMTSZH2vFfW8DEmz60SAUZUeJtWsjRKtN5elx0ACtY0,2558
111
+ nat/data_models/gated_field_mixin.py,sha256=94jmbB6dYciBuqb2BODKh6F73LHsUN-aInOY_xNlcQ8,9793
111
112
  nat/data_models/interactive.py,sha256=qOkxyYPQYEBIBMDAA1rjfYcdvf6-iCM4qPV8Awc4VGw,8794
112
113
  nat/data_models/intermediate_step.py,sha256=OhdGAtSBiMiCfT1R-BjihUvcVoQsRsxHDQcJ1gMDVXA,11585
113
114
  nat/data_models/invocation_node.py,sha256=nDRylgzBfJduGA-lme9xN4P6BdOYj0L6ytLHnTuetMI,1618
114
115
  nat/data_models/llm.py,sha256=PCTeCPg2YnYk5tvSu7XOEVr2Cg_wSkjAdVsoiTg8Joo,968
115
116
  nat/data_models/logging.py,sha256=1QtVjIQ99PgMYUuzw4h1FAoPRteZY7uf3oFTqV3ONgA,940
116
117
  nat/data_models/memory.py,sha256=IKwe7CflCto30j4yI5yQtq8DXfMilAJ17S5NcsSDrOQ,1052
117
- nat/data_models/model_gated_field_mixin.py,sha256=AvhYO7GNciRmGLfWu9QOhYLipwRcJheCxHBxcDmsp_k,5653
118
118
  nat/data_models/object_store.py,sha256=S8YY6i8ALgRPuggUI1FCG-xbvwPWuaCg1lJnZOx5scM,1515
119
119
  nat/data_models/profiler.py,sha256=z3IlEhj-veB4Yz85271bTkScSUkVwK50tR3dwlDRgcE,1781
120
120
  nat/data_models/registry_handler.py,sha256=g1rFaz4uSydMJn7qpdX-DNHJd_rNf8tXYN49dLDYHPo,968
@@ -125,8 +125,8 @@ nat/data_models/step_adaptor.py,sha256=1qn56wB_nIYBM5IjN4ftsltCAkqaJd3Sqe5v0TRR4
125
125
  nat/data_models/streaming.py,sha256=sSqJqLqb70qyw69_4R9QC2RMbRw7UjTLPdo3FYBUGxE,1159
126
126
  nat/data_models/swe_bench_model.py,sha256=uZs-hLFuT1B5CiPFwFg1PHinDW8PHne8TBzu7tHFv_k,1718
127
127
  nat/data_models/telemetry_exporter.py,sha256=P7kqxIQnFVuvo_UFpH9QSB8fACy_0U2Uzkw_IfWXagE,998
128
- nat/data_models/temperature_mixin.py,sha256=qra4YH6xmrwqMPo27zylbEFJz6z8XAL9b1U1-NyJ3_s,1291
129
- nat/data_models/top_p_mixin.py,sha256=4A3xa9gp2qv41jh6OqRL75q_iCpGa-GaKEExEW91X-s,1255
128
+ nat/data_models/temperature_mixin.py,sha256=HI6xJscmAv10JUmK8gk909mMifN1GW-qkswk9stC5ks,1492
129
+ nat/data_models/top_p_mixin.py,sha256=ld25YD4Xf70Odbl9qSHuaf391bmB5qgvON2LsuOfhlM,1463
130
130
  nat/data_models/ttc_strategy.py,sha256=tAkKWcyEBmBOOYtHMtQTgeCbHxFTk5SEkmFunNVnfyE,1114
131
131
  nat/embedder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
132
132
  nat/embedder/azure_openai_embedder.py,sha256=8OM54DCVCmWuQ8eZ5lXxsKYQktRmHMoIelRHZmKurUk,2381
@@ -135,7 +135,7 @@ nat/embedder/openai_embedder.py,sha256=uxyEe_2R02RPvEDGdzO9HUNaekm0MVRxJPCPB5-CS
135
135
  nat/embedder/register.py,sha256=TM_LKuSlJr3tEceNVuHfAx_yrCzf1sryD5Ycep5rNGo,883
136
136
  nat/eval/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
137
137
  nat/eval/config.py,sha256=de_DrLalzobILLJDTevGhbY_zjaoMyzMqVhDjbtpjk8,2365
138
- nat/eval/evaluate.py,sha256=vF1LmZM7ixQXqb4fj2kF6qI_8dQaknZpccfu-5OdXtY,26115
138
+ nat/eval/evaluate.py,sha256=a8rH1YEzS3Z7yzPWqWPDZpvR9C4ZHxoTXbc2lxPDmso,26551
139
139
  nat/eval/intermediate_step_adapter.py,sha256=E1AqH0hO_NNUT7nRGkIllitYJAszjjubul0IUwq-cyc,4498
140
140
  nat/eval/register.py,sha256=p8nM_P9yLWrglwZtTKAVjonAXfqzgBEwvYk1zlg4uw0,1044
141
141
  nat/eval/remote_workflow.py,sha256=IUsWHbAShBY5GgTsG9qTwK-SqRwrSNZM9CUfzBhlWo8,6012
@@ -144,10 +144,10 @@ nat/eval/usage_stats.py,sha256=_d_ErIvSKDN8wuvOyPn01kqM7zlFpwVxqOonaCV5XtY,1319
144
144
  nat/eval/dataset_handler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
145
145
  nat/eval/dataset_handler/dataset_downloader.py,sha256=w5Kjj7Dn2gQ14ekgt3rtGT8EPGJtXXXlckbtby9rJ2k,4553
146
146
  nat/eval/dataset_handler/dataset_filter.py,sha256=HrK0m3SQnPX6zOcKwJwYMrD9O-f6aOw8Vo3j9RKszqE,2115
147
- nat/eval/dataset_handler/dataset_handler.py,sha256=ii-y3RtHO7EZVajerSkW68UTH38nJbdou1Cuv1C9o4I,15233
147
+ nat/eval/dataset_handler/dataset_handler.py,sha256=my28rKvQiiRN_h2TJz6fdKeMOjP3LC3_e2aJNnPPYhE,18159
148
148
  nat/eval/evaluator/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
149
149
  nat/eval/evaluator/base_evaluator.py,sha256=pubxUm2DXkruSiYPkzezPj3-gzFEDAFGTYYTw7iXglU,3233
150
- nat/eval/evaluator/evaluator_model.py,sha256=Mjf6kVpL__zCoL2B7vRkXxdNXRdFC3bJ6giRPZPLEzI,1515
150
+ nat/eval/evaluator/evaluator_model.py,sha256=riGCcDW8YwC3Kd1yoVmbMdJE1Yf2kVmO8uhsGsKKJA4,1878
151
151
  nat/eval/rag_evaluator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
152
152
  nat/eval/rag_evaluator/evaluate.py,sha256=qvpa6FdQNQmChU6op1aELw9_Dabd81cSQ-HHQVOBjwc,8399
153
153
  nat/eval/rag_evaluator/register.py,sha256=3OZcQ0Q7K_2zxNTD7p6EY29Zv--qO4ExQkGdJszDKf0,5831
@@ -432,10 +432,10 @@ nat/utils/reactive/base/observer_base.py,sha256=6BiQfx26EMumotJ3KoVcdmFBYR_fnAss
432
432
  nat/utils/reactive/base/subject_base.py,sha256=UQOxlkZTIeeyYmG5qLtDpNf_63Y7p-doEeUA08_R8ME,2521
433
433
  nat/utils/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
434
434
  nat/utils/settings/global_settings.py,sha256=4tVWrZMDAuH_D6pdyiWrE1yAs61PBY99ssCHkDcbs7g,7324
435
- nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
436
- nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
437
- nvidia_nat-1.3.0a20250823.dist-info/METADATA,sha256=jdTUL524mA1IjR55AW48aydzoct5x5oMepviLo32VhQ,21762
438
- nvidia_nat-1.3.0a20250823.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
439
- nvidia_nat-1.3.0a20250823.dist-info/entry_points.txt,sha256=FNh4pZVSe_61s29zdks66lmXBPtsnko8KSZ4ffv7WVE,653
440
- nvidia_nat-1.3.0a20250823.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
441
- nvidia_nat-1.3.0a20250823.dist-info/RECORD,,
435
+ nvidia_nat-1.3.0a20250826.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
436
+ nvidia_nat-1.3.0a20250826.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
437
+ nvidia_nat-1.3.0a20250826.dist-info/METADATA,sha256=bGIcgBdPnY_qUO-k11S-YYBNG_k5wIQ1b3HX-FeZVQo,21762
438
+ nvidia_nat-1.3.0a20250826.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
439
+ nvidia_nat-1.3.0a20250826.dist-info/entry_points.txt,sha256=FNh4pZVSe_61s29zdks66lmXBPtsnko8KSZ4ffv7WVE,653
440
+ nvidia_nat-1.3.0a20250826.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
441
+ nvidia_nat-1.3.0a20250826.dist-info/RECORD,,
@@ -1,125 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from collections.abc import Sequence
17
- from re import Pattern
18
- from typing import Generic
19
- from typing import TypeVar
20
-
21
- from pydantic import BaseModel
22
- from pydantic import model_validator
23
-
24
- T = TypeVar("T")
25
-
26
-
27
- class ModelGatedFieldMixin(Generic[T]):
28
- """
29
- A mixin that gates a field based on model support.
30
-
31
- This should be used to automatically validate a field based on a given model.
32
-
33
- Parameters
34
- ----------
35
- field_name: `str`
36
- The name of the field.
37
- default_if_supported: `T`
38
- The default value of the field if it is supported for the model.
39
- unsupported_models: `Sequence[Pattern[str]] | None`
40
- A sequence of regex patterns that match the model names NOT supported for the field.
41
- Defaults to None.
42
- supported_models: `Sequence[Pattern[str]] | None`
43
- A sequence of regex patterns that match the model names supported for the field.
44
- Defaults to None.
45
- model_keys: `Sequence[str]`
46
- A sequence of keys that are used to validate the field.
47
- Defaults to ("model_name", "model", "azure_deployment",)
48
- """
49
-
50
- def __init_subclass__(
51
- cls,
52
- field_name: str | None = None,
53
- default_if_supported: T | None = None,
54
- unsupported_models: Sequence[Pattern[str]] | None = None,
55
- supported_models: Sequence[Pattern[str]] | None = None,
56
- model_keys: Sequence[str] = ("model_name", "model", "azure_deployment"),
57
- ) -> None:
58
- """
59
- Store the class variables for the field and define the model validator.
60
- """
61
- super().__init_subclass__()
62
- if ModelGatedFieldMixin in cls.__bases__:
63
- if field_name is None:
64
- raise ValueError("field_name must be provided when subclassing ModelGatedFieldMixin")
65
- if default_if_supported is None:
66
- raise ValueError("default_if_supported must be provided when subclassing ModelGatedFieldMixin")
67
- if unsupported_models is None and supported_models is None:
68
- raise ValueError("Either unsupported_models or supported_models must be provided")
69
- if unsupported_models is not None and supported_models is not None:
70
- raise ValueError("Only one of unsupported_models or supported_models must be provided")
71
- if model_keys is not None and len(model_keys) == 0:
72
- raise ValueError("model_keys must be provided and non-empty when subclassing ModelGatedFieldMixin")
73
- cls.field_name = field_name
74
- cls.default_if_supported = default_if_supported
75
- cls.unsupported_models = unsupported_models
76
- cls.supported_models = supported_models
77
- if model_keys is not None:
78
- cls.model_keys = model_keys
79
-
80
- @classmethod
81
- def check_model(cls, model_name: str) -> bool:
82
- """
83
- Check if a model is supported for a given field.
84
-
85
- Args:
86
- model_name: The name of the model to check.
87
- """
88
- unsupported = getattr(cls, "unsupported_models", None)
89
- supported = getattr(cls, "supported_models", None)
90
- if unsupported is not None:
91
- return not any(p.search(model_name) for p in unsupported)
92
- if supported is not None:
93
- return any(p.search(model_name) for p in supported)
94
- return False
95
-
96
- cls._model_gated_field_check_model = check_model
97
-
98
- @classmethod
99
- def detect_support(cls, instance: BaseModel) -> str | None:
100
- for key in getattr(cls, "model_keys"):
101
- if hasattr(instance, key):
102
- model_name_value = getattr(instance, key)
103
- is_supported = getattr(cls, "_model_gated_field_check_model")(str(model_name_value))
104
- return key if not is_supported else None
105
- return None
106
-
107
- cls._model_gated_field_detect_support = detect_support
108
-
109
- @model_validator(mode="after")
110
- def model_validate(self):
111
- klass = self.__class__
112
-
113
- field_name_local = getattr(klass, "field_name")
114
- current_value = getattr(self, field_name_local, None)
115
-
116
- found_key = klass._model_gated_field_detect_support(self)
117
- if found_key is not None:
118
- if current_value is not None:
119
- raise ValueError(
120
- f"{field_name_local} is not supported for {found_key}: {getattr(self, found_key)}")
121
- elif current_value is None:
122
- setattr(self, field_name_local, getattr(klass, "default_if_supported", None))
123
- return self
124
-
125
- cls._model_gated_field_model_validator = model_validate