nvidia-nat 1.4.0a20251022__py3-none-any.whl → 1.4.0a20251024__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/data_models/api_server.py +2 -0
- nat/data_models/optimizable.py +89 -1
- nat/data_models/optimizer.py +12 -0
- nat/eval/rag_evaluator/evaluate.py +7 -4
- nat/front_ends/mcp/mcp_front_end_config.py +5 -2
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +4 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +27 -1
- nat/profiler/parameter_optimization/pareto_visualizer.py +31 -16
- nat/utils/__init__.py +72 -0
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/METADATA +1 -1
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/RECORD +16 -16
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/top_level.txt +0 -0
nat/data_models/api_server.py
CHANGED
|
@@ -608,6 +608,8 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
|
608
608
|
type: typing.Literal[WebSocketMessageType.USER_INTERACTION_MESSAGE]
|
|
609
609
|
id: str = "default"
|
|
610
610
|
thread_id: str = "default"
|
|
611
|
+
parent_id: str = "default"
|
|
612
|
+
conversation_id: str | None = None
|
|
611
613
|
content: UserMessageContent
|
|
612
614
|
user: User = User()
|
|
613
615
|
security: Security = Security()
|
nat/data_models/optimizable.py
CHANGED
|
@@ -18,6 +18,7 @@ from typing import Any
|
|
|
18
18
|
from typing import Generic
|
|
19
19
|
from typing import TypeVar
|
|
20
20
|
|
|
21
|
+
import numpy as np
|
|
21
22
|
from optuna import Trial
|
|
22
23
|
from pydantic import BaseModel
|
|
23
24
|
from pydantic import ConfigDict
|
|
@@ -45,13 +46,34 @@ class SearchSpace(BaseModel, Generic[T]):
|
|
|
45
46
|
|
|
46
47
|
@model_validator(mode="after")
|
|
47
48
|
def validate_search_space_parameters(self):
|
|
48
|
-
"""Validate
|
|
49
|
+
"""Validate SearchSpace configuration."""
|
|
50
|
+
# 1. Prompt-specific validation
|
|
51
|
+
if self.is_prompt:
|
|
52
|
+
# When optimizing prompts, numeric parameters don't make sense
|
|
53
|
+
if self.low is not None or self.high is not None:
|
|
54
|
+
raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'low' or 'high' parameters")
|
|
55
|
+
if self.log:
|
|
56
|
+
raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'log=True'")
|
|
57
|
+
if self.step is not None:
|
|
58
|
+
raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'step' parameter")
|
|
59
|
+
return self
|
|
60
|
+
|
|
61
|
+
# 2. Values-based validation
|
|
49
62
|
if self.values is not None:
|
|
50
63
|
# If values is provided, we don't need high/low
|
|
51
64
|
if self.high is not None or self.low is not None:
|
|
52
65
|
raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
|
|
66
|
+
# Ensure values is not empty
|
|
67
|
+
if len(self.values) == 0:
|
|
68
|
+
raise ValueError("SearchSpace 'values' must not be empty")
|
|
53
69
|
return self
|
|
54
70
|
|
|
71
|
+
# 3. Range-based validation
|
|
72
|
+
if (self.low is None) != (self.high is None): # XOR using !=
|
|
73
|
+
raise ValueError(f"SearchSpace range requires both 'low' and 'high'; got low={self.low}, high={self.high}")
|
|
74
|
+
if self.low is not None and self.high is not None and self.low >= self.high:
|
|
75
|
+
raise ValueError(f"SearchSpace 'low' must be less than 'high'; got low={self.low}, high={self.high}")
|
|
76
|
+
|
|
55
77
|
return self
|
|
56
78
|
|
|
57
79
|
# Helper for Optuna Trials
|
|
@@ -65,6 +87,72 @@ class SearchSpace(BaseModel, Generic[T]):
|
|
|
65
87
|
return trial.suggest_int(name, self.low, self.high, log=self.log, step=self.step)
|
|
66
88
|
return trial.suggest_float(name, self.low, self.high, log=self.log, step=self.step)
|
|
67
89
|
|
|
90
|
+
def to_grid_values(self) -> list[Any]:
|
|
91
|
+
"""
|
|
92
|
+
Convert SearchSpace to a list of values for GridSampler.
|
|
93
|
+
|
|
94
|
+
Grid search requires explicit values. This can be provided in two ways:
|
|
95
|
+
1. Explicit values: SearchSpace(values=[0.1, 0.5, 0.9])
|
|
96
|
+
2. Range with step: SearchSpace(low=0.1, high=0.9, step=0.2)
|
|
97
|
+
|
|
98
|
+
For ranges, step is required (no default will be applied) to avoid
|
|
99
|
+
unintentional combinatorial explosion.
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
if self.is_prompt:
|
|
103
|
+
raise ValueError("Prompt optimization not currently supported using Optuna. "
|
|
104
|
+
"Use the genetic algorithm implementation instead.")
|
|
105
|
+
|
|
106
|
+
# Option 1: Explicit values provided
|
|
107
|
+
if self.values is not None:
|
|
108
|
+
return list(self.values)
|
|
109
|
+
|
|
110
|
+
# Option 2: Range with required step
|
|
111
|
+
if self.low is None or self.high is None:
|
|
112
|
+
raise ValueError("Grid search requires either 'values' or both 'low' and 'high' to be defined")
|
|
113
|
+
|
|
114
|
+
if self.step is None:
|
|
115
|
+
raise ValueError(
|
|
116
|
+
f"Grid search with range (low={self.low}, high={self.high}) requires 'step' to be specified. "
|
|
117
|
+
"Please define the step size to discretize the range, for example: step=0.1")
|
|
118
|
+
|
|
119
|
+
# Validate step is positive
|
|
120
|
+
step_float = float(self.step)
|
|
121
|
+
if step_float <= 0:
|
|
122
|
+
raise ValueError(f"Grid search step must be positive; got step={self.step}")
|
|
123
|
+
|
|
124
|
+
# Generate grid values from range with step
|
|
125
|
+
# Use integer range only if low, high, and step are all integral
|
|
126
|
+
if (isinstance(self.low, int) and isinstance(self.high, int) and step_float.is_integer()):
|
|
127
|
+
step = int(step_float)
|
|
128
|
+
|
|
129
|
+
if self.log:
|
|
130
|
+
raise ValueError("Log scale is not supported for integer ranges in grid search. "
|
|
131
|
+
"Please use linear scale or provide explicit values.")
|
|
132
|
+
values = list(range(self.low, self.high + 1, step))
|
|
133
|
+
if values and values[-1] != self.high:
|
|
134
|
+
values.append(self.high)
|
|
135
|
+
return values
|
|
136
|
+
|
|
137
|
+
# Float range (including integer low/high with float step)
|
|
138
|
+
low_val = float(self.low)
|
|
139
|
+
high_val = float(self.high)
|
|
140
|
+
step_val = step_float
|
|
141
|
+
|
|
142
|
+
if self.log:
|
|
143
|
+
raise ValueError("Log scale is not yet supported for grid search with ranges. "
|
|
144
|
+
"Please provide explicit values using the 'values' field.")
|
|
145
|
+
|
|
146
|
+
# Use arange to respect step size
|
|
147
|
+
values = np.arange(low_val, high_val, step_val).tolist()
|
|
148
|
+
|
|
149
|
+
# Always include the high endpoint if not already present (within tolerance)
|
|
150
|
+
# This ensures the full range is explored in grid search
|
|
151
|
+
if not values or abs(values[-1] - high_val) > 1e-9:
|
|
152
|
+
values.append(high_val)
|
|
153
|
+
|
|
154
|
+
return values
|
|
155
|
+
|
|
68
156
|
|
|
69
157
|
def OptimizableField(
|
|
70
158
|
default: Any = PydanticUndefined,
|
nat/data_models/optimizer.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from enum import Enum
|
|
16
17
|
from pathlib import Path
|
|
17
18
|
|
|
18
19
|
from pydantic import BaseModel
|
|
@@ -28,12 +29,23 @@ class OptimizerMetric(BaseModel):
|
|
|
28
29
|
weight: float = Field(description="Weight of the metric in the optimization process.", default=1.0)
|
|
29
30
|
|
|
30
31
|
|
|
32
|
+
class SamplerType(str, Enum):
|
|
33
|
+
BAYESIAN = "bayesian"
|
|
34
|
+
GRID = "grid"
|
|
35
|
+
|
|
36
|
+
|
|
31
37
|
class NumericOptimizationConfig(BaseModel):
|
|
32
38
|
"""
|
|
33
39
|
Configuration for numeric/enum optimization (Optuna).
|
|
34
40
|
"""
|
|
35
41
|
enabled: bool = Field(default=True, description="Enable numeric optimization")
|
|
36
42
|
n_trials: int = Field(description="Number of trials for numeric optimization.", default=20)
|
|
43
|
+
sampler: SamplerType | None = Field(
|
|
44
|
+
default=None,
|
|
45
|
+
description="Sampling strategy for numeric optimization. Options: None or 'bayesian' uses \
|
|
46
|
+
the Optuna default (TPE for single-objective, NSGA-II for multi-objective) or 'grid' performs \
|
|
47
|
+
exhaustive grid search over parameter combinations. Defaults to None.",
|
|
48
|
+
)
|
|
37
49
|
|
|
38
50
|
|
|
39
51
|
class PromptGAOptimizationConfig(BaseModel):
|
|
@@ -116,11 +116,14 @@ class RAGEvaluator:
|
|
|
116
116
|
"""Convert NaN or None to 0.0 for safe arithmetic/serialization."""
|
|
117
117
|
return 0.0 if v is None or (isinstance(v, float) and math.isnan(v)) else v
|
|
118
118
|
|
|
119
|
-
#
|
|
119
|
+
# Keep original scores (preserving NaN/None) for output
|
|
120
|
+
original_scores_dict = {metric: [score.get(metric) for score in scores] for metric in scores[0]}
|
|
121
|
+
|
|
122
|
+
# Convert from list of dicts to dict of lists, coercing NaN/None to 0.0 for average calculation
|
|
120
123
|
scores_dict = {metric: [_nan_to_zero(score.get(metric)) for score in scores] for metric in scores[0]}
|
|
121
124
|
first_metric_name = list(scores_dict.keys())[0] if scores_dict else None
|
|
122
125
|
|
|
123
|
-
# Compute the average of each metric
|
|
126
|
+
# Compute the average of each metric using cleaned scores (NaN/None -> 0.0)
|
|
124
127
|
average_scores = {
|
|
125
128
|
metric: (sum(values) / len(values) if values else 0.0)
|
|
126
129
|
for metric, values in scores_dict.items()
|
|
@@ -137,11 +140,11 @@ class RAGEvaluator:
|
|
|
137
140
|
else:
|
|
138
141
|
ids = df["user_input"].tolist() # Use "user_input" as ID fallback
|
|
139
142
|
|
|
140
|
-
# Construct EvalOutputItem list
|
|
143
|
+
# Construct EvalOutputItem list using original scores (preserving NaN/None)
|
|
141
144
|
eval_output_items = [
|
|
142
145
|
EvalOutputItem(
|
|
143
146
|
id=ids[i],
|
|
144
|
-
score=
|
|
147
|
+
score=original_scores_dict[first_metric_name][i] if first_metric_name else None,
|
|
145
148
|
reasoning={
|
|
146
149
|
key:
|
|
147
150
|
getattr(row, key, None) # Use getattr to safely access attributes
|
|
@@ -37,8 +37,11 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
37
37
|
port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
|
|
38
38
|
debug: bool = Field(default=False, description="Enable debug mode (default: False)")
|
|
39
39
|
log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
|
|
40
|
-
tool_names: list[str] = Field(
|
|
41
|
-
|
|
40
|
+
tool_names: list[str] = Field(
|
|
41
|
+
default_factory=list,
|
|
42
|
+
description="The list of tools MCP server will expose (default: all tools)."
|
|
43
|
+
"Tool names can be functions or function groups",
|
|
44
|
+
)
|
|
42
45
|
transport: Literal["sse", "streamable-http"] = Field(
|
|
43
46
|
default="streamable-http",
|
|
44
47
|
description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
|
|
@@ -251,6 +251,10 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
251
251
|
filtered_functions: dict[str, Function] = {}
|
|
252
252
|
for function_name, function in functions.items():
|
|
253
253
|
if function_name in self.front_end_config.tool_names:
|
|
254
|
+
# Treat current tool_names as function names, so check if the function name is in the list
|
|
255
|
+
filtered_functions[function_name] = function
|
|
256
|
+
elif any(function_name.startswith(f"{group_name}.") for group_name in self.front_end_config.tool_names):
|
|
257
|
+
# Treat tool_names as function group names, so check if the function name starts with the group name
|
|
254
258
|
filtered_functions[function_name] = function
|
|
255
259
|
else:
|
|
256
260
|
logger.debug("Skipping function %s as it's not in tool_names", function_name)
|
|
@@ -24,6 +24,7 @@ from nat.data_models.config import Config
|
|
|
24
24
|
from nat.data_models.optimizable import SearchSpace
|
|
25
25
|
from nat.data_models.optimizer import OptimizerConfig
|
|
26
26
|
from nat.data_models.optimizer import OptimizerRunConfig
|
|
27
|
+
from nat.data_models.optimizer import SamplerType
|
|
27
28
|
from nat.eval.evaluate import EvaluationRun
|
|
28
29
|
from nat.eval.evaluate import EvaluationRunConfig
|
|
29
30
|
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
@@ -59,7 +60,21 @@ def optimize_parameters(
|
|
|
59
60
|
eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
|
|
60
61
|
weights = [v.weight for v in metric_cfg.values()]
|
|
61
62
|
|
|
62
|
-
|
|
63
|
+
# Create appropriate sampler based on configuration
|
|
64
|
+
sampler_type = optimizer_config.numeric.sampler
|
|
65
|
+
|
|
66
|
+
if sampler_type == SamplerType.GRID:
|
|
67
|
+
# For grid search, convert the existing space to value sequences
|
|
68
|
+
grid_search_space = {param_name: search_space.to_grid_values() for param_name, search_space in space.items()}
|
|
69
|
+
sampler = optuna.samplers.GridSampler(grid_search_space)
|
|
70
|
+
logger.info("Using Grid sampler for numeric optimization")
|
|
71
|
+
else:
|
|
72
|
+
# None or BAYESIAN: let Optuna choose defaults
|
|
73
|
+
sampler = None
|
|
74
|
+
logger.info(
|
|
75
|
+
"Using Optuna default sampler types: TPESampler for single-objective, NSGAIISampler for multi-objective")
|
|
76
|
+
|
|
77
|
+
study = optuna.create_study(directions=directions, sampler=sampler)
|
|
63
78
|
|
|
64
79
|
# Create output directory for intermediate files
|
|
65
80
|
out_dir = optimizer_config.output_path
|
|
@@ -121,6 +136,17 @@ def optimize_parameters(
|
|
|
121
136
|
with (out_dir / "trials_dataframe_params.csv").open("w") as fh:
|
|
122
137
|
# Export full trials DataFrame (values, params, timings, etc.).
|
|
123
138
|
df = study.trials_dataframe()
|
|
139
|
+
|
|
140
|
+
# Rename values_X columns to actual metric names
|
|
141
|
+
metric_names = list(metric_cfg.keys())
|
|
142
|
+
rename_mapping = {}
|
|
143
|
+
for i, metric_name in enumerate(metric_names):
|
|
144
|
+
old_col = f"values_{i}"
|
|
145
|
+
if old_col in df.columns:
|
|
146
|
+
rename_mapping[old_col] = f"values_{metric_name}"
|
|
147
|
+
if rename_mapping:
|
|
148
|
+
df = df.rename(columns=rename_mapping)
|
|
149
|
+
|
|
124
150
|
# Normalise rep_scores column naming for convenience.
|
|
125
151
|
if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
|
|
126
152
|
df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
|
|
@@ -46,9 +46,13 @@ class ParetoVisualizer:
|
|
|
46
46
|
|
|
47
47
|
fig, ax = plt.subplots(figsize=figsize)
|
|
48
48
|
|
|
49
|
-
# Extract metric values
|
|
50
|
-
|
|
51
|
-
|
|
49
|
+
# Extract metric values - support both old (values_0) and new (values_metricname) formats
|
|
50
|
+
x_col = f"values_{self.metric_names[0]}" \
|
|
51
|
+
if f"values_{self.metric_names[0]}" in trials_df.columns else f"values_{0}"
|
|
52
|
+
y_col = f"values_{self.metric_names[1]}"\
|
|
53
|
+
if f"values_{self.metric_names[1]}" in trials_df.columns else f"values_{1}"
|
|
54
|
+
x_vals = trials_df[x_col].values
|
|
55
|
+
y_vals = trials_df[y_col].values
|
|
52
56
|
|
|
53
57
|
# Plot all trials
|
|
54
58
|
ax.scatter(x_vals,
|
|
@@ -62,8 +66,8 @@ class ParetoVisualizer:
|
|
|
62
66
|
|
|
63
67
|
# Plot Pareto optimal trials if provided
|
|
64
68
|
if pareto_trials_df is not None and not pareto_trials_df.empty:
|
|
65
|
-
pareto_x = pareto_trials_df[
|
|
66
|
-
pareto_y = pareto_trials_df[
|
|
69
|
+
pareto_x = pareto_trials_df[x_col].values
|
|
70
|
+
pareto_y = pareto_trials_df[y_col].values
|
|
67
71
|
|
|
68
72
|
ax.scatter(pareto_x,
|
|
69
73
|
pareto_y,
|
|
@@ -98,8 +102,8 @@ class ParetoVisualizer:
|
|
|
98
102
|
ax.grid(True, alpha=0.3)
|
|
99
103
|
|
|
100
104
|
# Add direction annotations
|
|
101
|
-
x_annotation = (f"Better {self.metric_names[0]}
|
|
102
|
-
if self.directions[0] == "minimize" else f"
|
|
105
|
+
x_annotation = (f"Better {self.metric_names[0]} ←"
|
|
106
|
+
if self.directions[0] == "minimize" else f"→ Better {self.metric_names[0]}")
|
|
103
107
|
ax.annotate(x_annotation,
|
|
104
108
|
xy=(0.02, 0.98),
|
|
105
109
|
xycoords='axes fraction',
|
|
@@ -109,8 +113,8 @@ class ParetoVisualizer:
|
|
|
109
113
|
style='italic',
|
|
110
114
|
bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7))
|
|
111
115
|
|
|
112
|
-
y_annotation = (f"Better {self.metric_names[1]}
|
|
113
|
-
if self.directions[1] == "minimize" else f"Better {self.metric_names[1]}
|
|
116
|
+
y_annotation = (f"Better {self.metric_names[1]} ↓"
|
|
117
|
+
if self.directions[1] == "minimize" else f"Better {self.metric_names[1]} ↑")
|
|
114
118
|
ax.annotate(y_annotation,
|
|
115
119
|
xy=(0.02, 0.02),
|
|
116
120
|
xycoords='axes fraction',
|
|
@@ -145,7 +149,10 @@ class ParetoVisualizer:
|
|
|
145
149
|
# Normalize values for better visualization
|
|
146
150
|
all_values = []
|
|
147
151
|
for i in range(n_metrics):
|
|
148
|
-
|
|
152
|
+
# Support both old (values_0) and new (values_metricname) formats
|
|
153
|
+
col_name = f"values_{self.metric_names[i]}"\
|
|
154
|
+
if f"values_{self.metric_names[i]}" in trials_df.columns else f"values_{i}"
|
|
155
|
+
all_values.append(trials_df[col_name].values)
|
|
149
156
|
|
|
150
157
|
# Normalize each metric to [0, 1] for parallel coordinates
|
|
151
158
|
normalized_values = []
|
|
@@ -221,23 +228,31 @@ class ParetoVisualizer:
|
|
|
221
228
|
|
|
222
229
|
if i == j:
|
|
223
230
|
# Diagonal: histograms
|
|
224
|
-
|
|
231
|
+
# Support both old (values_0) and new (values_metricname) formats
|
|
232
|
+
col_name = f"values_{self.metric_names[i]}"\
|
|
233
|
+
if f"values_{self.metric_names[i]}" in trials_df.columns else f"values_{i}"
|
|
234
|
+
values = trials_df[col_name].values
|
|
225
235
|
ax.hist(values, bins=20, alpha=0.7, color='lightblue', edgecolor='navy')
|
|
226
236
|
if pareto_trials_df is not None and not pareto_trials_df.empty:
|
|
227
|
-
pareto_values = pareto_trials_df[
|
|
237
|
+
pareto_values = pareto_trials_df[col_name].values
|
|
228
238
|
ax.hist(pareto_values, bins=20, alpha=0.8, color='red', edgecolor='darkred')
|
|
229
239
|
ax.set_xlabel(f"{self.metric_names[i]}")
|
|
230
240
|
ax.set_ylabel("Frequency")
|
|
231
241
|
else:
|
|
232
242
|
# Off-diagonal: scatter plots
|
|
233
|
-
|
|
234
|
-
|
|
243
|
+
# Support both old (values_0) and new (values_metricname) formats
|
|
244
|
+
x_col = f"values_{self.metric_names[j]}"\
|
|
245
|
+
if f"values_{self.metric_names[j]}" in trials_df.columns else f"values_{j}"
|
|
246
|
+
y_col = f"values_{self.metric_names[i]}"\
|
|
247
|
+
if f"values_{self.metric_names[i]}" in trials_df.columns else f"values_{i}"
|
|
248
|
+
x_vals = trials_df[x_col].values
|
|
249
|
+
y_vals = trials_df[y_col].values
|
|
235
250
|
|
|
236
251
|
ax.scatter(x_vals, y_vals, alpha=0.6, s=30, c='lightblue', edgecolors='navy', linewidths=0.5)
|
|
237
252
|
|
|
238
253
|
if pareto_trials_df is not None and not pareto_trials_df.empty:
|
|
239
|
-
pareto_x = pareto_trials_df[
|
|
240
|
-
pareto_y = pareto_trials_df[
|
|
254
|
+
pareto_x = pareto_trials_df[x_col].values
|
|
255
|
+
pareto_y = pareto_trials_df[y_col].values
|
|
241
256
|
ax.scatter(pareto_x,
|
|
242
257
|
pareto_y,
|
|
243
258
|
alpha=0.9,
|
nat/utils/__init__.py
CHANGED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
if typing.TYPE_CHECKING:
|
|
20
|
+
|
|
21
|
+
from nat.data_models.config import Config
|
|
22
|
+
|
|
23
|
+
from .type_utils import StrPath
|
|
24
|
+
|
|
25
|
+
_T = typing.TypeVar("_T")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def run_workflow(*,
|
|
29
|
+
config: "Config | None" = None,
|
|
30
|
+
config_file: "StrPath | None" = None,
|
|
31
|
+
prompt: str,
|
|
32
|
+
to_type: type[_T] = str) -> _T:
|
|
33
|
+
"""
|
|
34
|
+
Wrapper to run a workflow given either a config or a config file path and a prompt, returning the result in the
|
|
35
|
+
type specified by the `to_type`.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
config : Config | None
|
|
40
|
+
The configuration object to use for the workflow. If None, config_file must be provided.
|
|
41
|
+
config_file : StrPath | None
|
|
42
|
+
The path to the configuration file. If None, config must be provided. Can be either a str or a Path object.
|
|
43
|
+
prompt : str
|
|
44
|
+
The prompt to run the workflow with.
|
|
45
|
+
to_type : type[_T]
|
|
46
|
+
The type to convert the result to. Default is str.
|
|
47
|
+
|
|
48
|
+
Returns
|
|
49
|
+
-------
|
|
50
|
+
_T
|
|
51
|
+
The result of the workflow converted to the specified type.
|
|
52
|
+
"""
|
|
53
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
54
|
+
from nat.runtime.loader import load_config
|
|
55
|
+
from nat.runtime.session import SessionManager
|
|
56
|
+
|
|
57
|
+
if config is not None and config_file is not None:
|
|
58
|
+
raise ValueError("Only one of config or config_file should be provided")
|
|
59
|
+
|
|
60
|
+
if config is None:
|
|
61
|
+
if config_file is None:
|
|
62
|
+
raise ValueError("Either config_file or config must be provided")
|
|
63
|
+
|
|
64
|
+
if not Path(config_file).exists():
|
|
65
|
+
raise ValueError(f"Config file {config_file} does not exist")
|
|
66
|
+
|
|
67
|
+
config = load_config(config_file)
|
|
68
|
+
|
|
69
|
+
async with WorkflowBuilder.from_config(config=config) as workflow_builder:
|
|
70
|
+
workflow = SessionManager(await workflow_builder.build())
|
|
71
|
+
async with workflow.run(prompt) as runner:
|
|
72
|
+
return await runner.result(to_type=to_type)
|
|
@@ -114,7 +114,7 @@ nat/control_flow/router_agent/prompt.py,sha256=fIAiNsAs1zXRAatButR76zSpHJNxSkXXK
|
|
|
114
114
|
nat/control_flow/router_agent/register.py,sha256=4RGmS9sy-QtIMmvh8mfMcR1VqxFPLpG4RckWCIExh40,4144
|
|
115
115
|
nat/data_models/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
116
116
|
nat/data_models/agent.py,sha256=IwDyb9Zc3R4Zd5rFeqt7q0EQswczAl5focxV9KozIzs,1625
|
|
117
|
-
nat/data_models/api_server.py,sha256=
|
|
117
|
+
nat/data_models/api_server.py,sha256=oQtSiP7jpkHIZ75g21A_lTiidNsQo54pq3qy2StIJcs,30652
|
|
118
118
|
nat/data_models/authentication.py,sha256=XPu9W8nh4XRSuxPv3HxO-FMQ_JtTEoK6Y02JwnzDwTg,8457
|
|
119
119
|
nat/data_models/common.py,sha256=nXXfGrjpxebzBUa55mLdmzePLt7VFHvTAc6Znj3yEv0,5875
|
|
120
120
|
nat/data_models/component.py,sha256=b_hXOA8Gm5UNvlFkAhsR6kEvf33ST50MKtr5kWf75Ao,1894
|
|
@@ -137,8 +137,8 @@ nat/data_models/logging.py,sha256=1QtVjIQ99PgMYUuzw4h1FAoPRteZY7uf3oFTqV3ONgA,94
|
|
|
137
137
|
nat/data_models/memory.py,sha256=IKwe7CflCto30j4yI5yQtq8DXfMilAJ17S5NcsSDrOQ,1052
|
|
138
138
|
nat/data_models/object_store.py,sha256=S8YY6i8ALgRPuggUI1FCG-xbvwPWuaCg1lJnZOx5scM,1515
|
|
139
139
|
nat/data_models/openai_mcp.py,sha256=UkAalZE0my8a_sq-GynjsfDoSOw2NWLNZM9hcV23TzY,1911
|
|
140
|
-
nat/data_models/optimizable.py,sha256=
|
|
141
|
-
nat/data_models/optimizer.py,sha256=
|
|
140
|
+
nat/data_models/optimizable.py,sha256=dG9YGM6MwAReLXimk31CzzOlbknGwsk0znfAiDuOeuI,8981
|
|
141
|
+
nat/data_models/optimizer.py,sha256=kKj74zYBjoTV51owkIHqfPefzXFY2OwB93OCORwJCOQ,5750
|
|
142
142
|
nat/data_models/profiler.py,sha256=z3IlEhj-veB4Yz85271bTkScSUkVwK50tR3dwlDRgcE,1781
|
|
143
143
|
nat/data_models/registry_handler.py,sha256=g1rFaz4uSydMJn7qpdX-DNHJd_rNf8tXYN49dLDYHPo,968
|
|
144
144
|
nat/data_models/retriever.py,sha256=IJAIaeEXM8zj_towrvZ30Uoxt8e4WvOXrQwqGloS1qI,1202
|
|
@@ -173,7 +173,7 @@ nat/eval/evaluator/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQ
|
|
|
173
173
|
nat/eval/evaluator/base_evaluator.py,sha256=5WaVGhCGzkynCJyQdxRv7CtqLoUpr6B4O8tilP_gb3g,3232
|
|
174
174
|
nat/eval/evaluator/evaluator_model.py,sha256=riGCcDW8YwC3Kd1yoVmbMdJE1Yf2kVmO8uhsGsKKJA4,1878
|
|
175
175
|
nat/eval/rag_evaluator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
176
|
-
nat/eval/rag_evaluator/evaluate.py,sha256=
|
|
176
|
+
nat/eval/rag_evaluator/evaluate.py,sha256=IfCpfCKBTYhReRkPPbOqyr-9H6gsPGaeFWBIcGDUynw,8639
|
|
177
177
|
nat/eval/rag_evaluator/register.py,sha256=AzT5uICDU5dEo7scvStmOWC7ac-S0Tx4UY87idGtXIs,5835
|
|
178
178
|
nat/eval/runners/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
179
179
|
nat/eval/runners/config.py,sha256=bRPai_th02OJrFepbbY6w-t7A18TBXozQUnnnH9iWIU,1403
|
|
@@ -262,9 +262,9 @@ nat/front_ends/fastapi/html_snippets/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv
|
|
|
262
262
|
nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py,sha256=BNpWwzmA58UM0GK4kZXG4PHJy_5K9ihaVHu8SgCs5JA,1131
|
|
263
263
|
nat/front_ends/mcp/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
264
264
|
nat/front_ends/mcp/introspection_token_verifier.py,sha256=s7Q4Q6rWZJ0ZVujSxxpvVI6Bnhkg1LJQ3RLkvhzFIGE,2836
|
|
265
|
-
nat/front_ends/mcp/mcp_front_end_config.py,sha256=
|
|
265
|
+
nat/front_ends/mcp/mcp_front_end_config.py,sha256=dnNsf487XZtoipU3DcmCAZ9eqtyF5a2p_1huQ_4uwPI,4919
|
|
266
266
|
nat/front_ends/mcp/mcp_front_end_plugin.py,sha256=4u_kpen_T-_Uh62V5M7dfW9KyzbqXI7tGBG4AxJXWm0,5231
|
|
267
|
-
nat/front_ends/mcp/mcp_front_end_plugin_worker.py,sha256=
|
|
267
|
+
nat/front_ends/mcp/mcp_front_end_plugin_worker.py,sha256=IuwCcrBYmN6hyVra5vnwVVEjCuj2YO9ZUs-mhJnNNSQ,11626
|
|
268
268
|
nat/front_ends/mcp/memory_profiler.py,sha256=OpcpLBAGCdQwYSFZbtAqdfncrnGYVjDcMpWydB71hjY,12811
|
|
269
269
|
nat/front_ends/mcp/register.py,sha256=3aJtgG5VaiqujoeU1-Eq7Hl5pWslIlIwGFU2ASLTXgM,1173
|
|
270
270
|
nat/front_ends/mcp/tool_converter.py,sha256=14NweQN3cPFBw7ZNiGyUHO4VhMGHrtfLGgvu4_H38oU,12426
|
|
@@ -371,9 +371,9 @@ nat/profiler/inference_optimization/experimental/prefix_span_analysis.py,sha256=
|
|
|
371
371
|
nat/profiler/parameter_optimization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
372
372
|
nat/profiler/parameter_optimization/optimizable_utils.py,sha256=93Pl8A14Zq_f3XsxSH-yFnEJ6B7W5hp7doPnPoLlRB4,3714
|
|
373
373
|
nat/profiler/parameter_optimization/optimizer_runtime.py,sha256=rXmCOq81o7ZorQOUYociVjuO3NO9CIjFBbwql2u_4H4,2715
|
|
374
|
-
nat/profiler/parameter_optimization/parameter_optimizer.py,sha256=
|
|
374
|
+
nat/profiler/parameter_optimization/parameter_optimizer.py,sha256=vxUvso4RnSwoUF5rJkJJGaIOJojebrvYcA79WA0ZP7c,7719
|
|
375
375
|
nat/profiler/parameter_optimization/parameter_selection.py,sha256=pfnNQIx1evNICgChsOJXIFQHoL1R_kmh_vNDsVMC9kg,3982
|
|
376
|
-
nat/profiler/parameter_optimization/pareto_visualizer.py,sha256=
|
|
376
|
+
nat/profiler/parameter_optimization/pareto_visualizer.py,sha256=QclLZmmsWINIAh4n0XAKmnIZOqGHTMr-iggZS0kxj-Y,17055
|
|
377
377
|
nat/profiler/parameter_optimization/prompt_optimizer.py,sha256=_AmdeB1jRamd93qR5UqRy5LweYR3bjnD7zoLxzXYE0k,17658
|
|
378
378
|
nat/profiler/parameter_optimization/update_helpers.py,sha256=NxWhrGVchbjws85pPd-jS-C14_l70QvVSvEfENndVcY,2339
|
|
379
379
|
nat/registry_handlers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -441,7 +441,7 @@ nat/tool/memory_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
|
|
|
441
441
|
nat/tool/memory_tools/add_memory_tool.py,sha256=N400PPvI37NUCMh5KcuoAL8khK8ecUQyfenahfjzbHQ,3368
|
|
442
442
|
nat/tool/memory_tools/delete_memory_tool.py,sha256=zMllkpC0of9qFPNuG9vkVOoydRblOViCQf0uSbqz0sE,2461
|
|
443
443
|
nat/tool/memory_tools/get_memory_tool.py,sha256=fcW6QE7bMZrpNK62et3sTw_QZ8cV9lXfEuDsm1-05bE,2768
|
|
444
|
-
nat/utils/__init__.py,sha256=
|
|
444
|
+
nat/utils/__init__.py,sha256=WRO5RDryn6mChA2gwDVERclaei7LGPwrEMCUSliZI7k,2653
|
|
445
445
|
nat/utils/callable_utils.py,sha256=EIao6NhHRFEoBqYRC7aWoFqhlr2LeFT0XK-ac0coF9E,2475
|
|
446
446
|
nat/utils/debugging_utils.py,sha256=6M4JhbHDNDnfmSRGmHvT5IgEeWSHBore3VngdE_PMqc,1332
|
|
447
447
|
nat/utils/decorators.py,sha256=AoMip9zmqrZm5wovZQytNvzFfIlS3PQxSYcgYeoLhxA,8240
|
|
@@ -475,10 +475,10 @@ nat/utils/reactive/base/observer_base.py,sha256=6BiQfx26EMumotJ3KoVcdmFBYR_fnAss
|
|
|
475
475
|
nat/utils/reactive/base/subject_base.py,sha256=UQOxlkZTIeeyYmG5qLtDpNf_63Y7p-doEeUA08_R8ME,2521
|
|
476
476
|
nat/utils/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
477
477
|
nat/utils/settings/global_settings.py,sha256=9JaO6pxKT_Pjw6rxJRsRlFCXdVKCl_xUKU2QHZQWWNM,7294
|
|
478
|
-
nvidia_nat-1.4.
|
|
479
|
-
nvidia_nat-1.4.
|
|
480
|
-
nvidia_nat-1.4.
|
|
481
|
-
nvidia_nat-1.4.
|
|
482
|
-
nvidia_nat-1.4.
|
|
483
|
-
nvidia_nat-1.4.
|
|
484
|
-
nvidia_nat-1.4.
|
|
478
|
+
nvidia_nat-1.4.0a20251024.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
|
|
479
|
+
nvidia_nat-1.4.0a20251024.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
480
|
+
nvidia_nat-1.4.0a20251024.dist-info/METADATA,sha256=GpP4M0PgVcHAhPCR9BR5TN7UpPUHZ9KqZbLv3TiVuMk,10248
|
|
481
|
+
nvidia_nat-1.4.0a20251024.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
482
|
+
nvidia_nat-1.4.0a20251024.dist-info/entry_points.txt,sha256=4jCqjyETMpyoWbCBf4GalZU8I_wbstpzwQNezdAVbbo,698
|
|
483
|
+
nvidia_nat-1.4.0a20251024.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
|
|
484
|
+
nvidia_nat-1.4.0a20251024.dist-info/RECORD,,
|
|
File without changes
|
{nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
{nvidia_nat-1.4.0a20251022.dist-info → nvidia_nat-1.4.0a20251024.dist-info}/licenses/LICENSE.md
RENAMED
|
File without changes
|
|
File without changes
|