nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +15 -5
- nat/agent/reasoning_agent/reasoning_agent.py +6 -1
- nat/agent/register.py +2 -0
- nat/agent/rewoo_agent/agent.py +4 -2
- nat/agent/rewoo_agent/register.py +8 -3
- nat/agent/router_agent/__init__.py +0 -0
- nat/agent/router_agent/agent.py +329 -0
- nat/agent/router_agent/prompt.py +48 -0
- nat/agent/router_agent/register.py +97 -0
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +11 -3
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/function.py +167 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +213 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +2 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +43 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +401 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import random
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from pydantic import BaseModel
|
|
25
|
+
|
|
26
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
27
|
+
from nat.data_models.config import Config
|
|
28
|
+
from nat.data_models.optimizable import SearchSpace
|
|
29
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
30
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
31
|
+
from nat.eval.evaluate import EvaluationRun
|
|
32
|
+
from nat.eval.evaluate import EvaluationRunConfig
|
|
33
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
34
|
+
from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PromptOptimizerInputSchema(BaseModel):
    """Input payload handed to the prompt-optimizer function by ``optimize_prompts`` when mutating a prompt."""
    # Prompt text that should be rewritten.
    original_prompt: str
    # Purpose of the prompt; ``optimize_prompts`` fills this from the search space's ``prompt_purpose``.
    objective: str
    # Optional feedback text; ``optimize_prompts`` in this module always passes ``None``.
    oracle_feedback: str | None = None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@experimental(feature_name="Optimizer")
async def optimize_prompts(
    *,
    base_cfg: Config,
    full_space: dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> None:
    """
    Optimize the prompt parameters of a workflow configuration with a genetic algorithm (GA).

    Every entry of ``full_space`` flagged ``is_prompt`` becomes one gene. The population is seeded with the
    original prompts plus mutated variants, evaluated through ``EvaluationRun``, and evolved for
    ``ga_generations`` generations using elitism, parent selection (tournament or roulette wheel),
    recombination and mutation.

    Files written under ``optimizer_config.output_path``:

    - ``optimized_prompts_gen<N>.json``: best prompts after generation ``N``
    - ``optimized_prompts.json``: best prompts of the final population
    - ``ga_history_prompts.csv``: per-individual fitness and metrics for every generation (best effort)

    Args:
        base_cfg: Workflow configuration the candidate prompts are applied to.
        full_space: Mapping of dotted parameter path to search space; only prompt entries are used.
        optimizer_config: Optimizer settings (evaluation metrics, output path, GA hyper-parameters).
        opt_run_config: Run settings forwarded to every evaluation (dataset, endpoint, overrides, ...).

    Returns:
        None. If ``full_space`` contains no prompt entries the function logs and returns immediately.

    Raises:
        ValueError: If ``optimizer_config.eval_metrics`` is missing/empty, if no population-init function is
            configured, if the combination mode is unknown, or if the number of weights does not match the
            number of objectives in ``sum`` mode.
    """

    # ------------- helpers ------------- #
    @dataclass
    class Individual:
        """One candidate solution: a prompt text for every optimizable prompt parameter."""
        prompts: dict[str, str]  # param_name -> prompt text
        metrics: dict[str, float] | None = None  # evaluator_name -> average score
        scalar_fitness: float | None = None  # combined score used for selection (higher is better)

    def _fitness(ind: Individual) -> float:
        """Fitness used for ranking; individuals that were not scored yet count as 0.0."""
        return ind.scalar_fitness or 0.0

    def _normalize_generation(
        individuals: Sequence[Individual],
        metric_names: Sequence[str],
        directions: Sequence[str],
        eps: float = 1e-12,
    ) -> list[dict[str, float]]:
        """Return per-individual dict of normalised scores in [0,1] where higher is better."""
        if not individuals:
            return []

        # Extract arrays per metric (a missing metric counts as 0.0)
        arrays = {m: [ind.metrics.get(m, 0.0) if ind.metrics else 0.0 for ind in individuals] for m in metric_names}
        # min/max only depend on the metric, so compute them once instead of once per individual
        bounds = {m: (min(vals), max(vals)) for m, vals in arrays.items()}

        normed: list[dict[str, float]] = []
        for i in range(len(individuals)):
            entry: dict[str, float] = {}
            for m, dirn in zip(metric_names, directions):
                vmin, vmax = bounds[m]
                # Map to [0,1] with higher=better regardless of direction
                if vmax - vmin < eps:
                    # All individuals are (numerically) tied on this metric
                    score01 = 0.5
                else:
                    score01 = (arrays[m][i] - vmin) / (vmax - vmin)
                    if dirn == "minimize":
                        score01 = 1.0 - score01
                entry[m] = float(score01)
            normed.append(entry)
        return normed

    def _scalarize(norm_scores: dict[str, float], *, mode: str, weights: Sequence[float] | None) -> float:
        """Collapse normalised scores to a single scalar (higher is better)."""
        vals = list(norm_scores.values())
        if not vals:
            return 0.0
        if mode == "harmonic":
            # Clamp to avoid division by zero for scores of exactly 0
            inv_sum = sum(1.0 / max(v, 1e-12) for v in vals)
            return len(vals) / max(inv_sum, 1e-12)
        if mode == "sum":
            if weights is None:
                return float(sum(vals))
            if len(weights) != len(vals):
                raise ValueError("weights length must equal number of objectives")
            return float(sum(w * v for w, v in zip(weights, vals)))
        if mode == "chebyshev":
            return float(min(vals))  # maximise the worst-case score
        raise ValueError(f"Unknown combination mode: {mode}")

    def _apply_diversity_penalty(individuals: Sequence[Individual], diversity_lambda: float) -> list[float]:
        """Return one penalty per individual: ``diversity_lambda`` times the number of exact duplicates of it."""
        if diversity_lambda <= 0.0:
            return [0.0 for _ in individuals]
        seen: dict[str, int] = {}
        keys: list[str] = []
        for ind in individuals:
            # Identity of an individual = its prompt texts joined in parameter-name order
            key = "\u241f".join(ind.prompts[k] for k in sorted(ind.prompts))
            keys.append(key)
            seen[key] = seen.get(key, 0) + 1
        return [diversity_lambda * float(seen[key] - 1) for key in keys]

    def _tournament_select(pop: Sequence[Individual], k: int) -> Individual:
        """Pick the fittest of ``k`` randomly sampled individuals (fewer if the population is smaller)."""
        contenders = random.sample(pop, k=min(k, len(pop)))
        return max(contenders, key=_fitness)

    # ------------- discover space ------------- #
    prompt_space: dict[str, tuple[str, str]] = {
        k: (v.prompt, v.prompt_purpose)
        for k, v in full_space.items() if v.is_prompt
    }

    if not prompt_space:
        logger.info("No prompts to optimize – skipping.")
        return

    metric_cfg = optimizer_config.eval_metrics
    if metric_cfg is None or len(metric_cfg) == 0:
        raise ValueError("optimizer_config.eval_metrics must be provided for GA prompt optimization")

    directions = [v.direction for v in metric_cfg.values()]  # "minimize" or "maximize"
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    # ------------- builder & functions ------------- #
    async with WorkflowBuilder(general_config=base_cfg.general, registry=None) as builder:
        await builder.populate_builder(base_cfg)
        init_fn_name = optimizer_config.prompt.prompt_population_init_function
        if not init_fn_name:
            raise ValueError(
                "No prompt optimization function configured. Set optimizer.prompt_population_init_function")
        init_fn = builder.get_function(init_fn_name)

        recombine_fn = None
        if optimizer_config.prompt.prompt_recombination_function:
            recombine_fn = builder.get_function(optimizer_config.prompt.prompt_recombination_function)

        logger.info(
            "GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
            init_fn_name,
            optimizer_config.prompt.prompt_recombination_function,
        )

        # ------------- GA parameters ------------- #
        # NOTE: ``ga_offspring_size`` is not applied here; every generation is refilled to ``pop_size``.
        pop_size = max(2, int(optimizer_config.prompt.ga_population_size))
        generations = max(1, int(optimizer_config.prompt.ga_generations))
        crossover_rate = float(optimizer_config.prompt.ga_crossover_rate)
        mutation_rate = float(optimizer_config.prompt.ga_mutation_rate)
        elitism = max(0, int(optimizer_config.prompt.ga_elitism))
        selection_method = optimizer_config.prompt.ga_selection_method.lower()
        tournament_size = max(2, int(optimizer_config.prompt.ga_tournament_size))
        max_eval_concurrency = max(1, int(optimizer_config.prompt.ga_parallel_evaluations))
        diversity_lambda = float(optimizer_config.prompt.ga_diversity_lambda)

        # ------------- population init ------------- #
        async def _mutate_prompt(original_prompt: str, purpose: str) -> str:
            """Ask the configured init function for a rewritten version of ``original_prompt``."""
            # Use LLM-based optimizer with no feedback
            return await init_fn.acall_invoke(
                PromptOptimizerInputSchema(
                    original_prompt=original_prompt,
                    objective=purpose,
                    oracle_feedback=None,
                ))

        async def _recombine_prompts(a: str, b: str, purpose: str) -> str:
            """Combine two parent prompts via the recombination function, or pick one parent at random."""
            if recombine_fn is None:
                # Fallback: uniform choice per recombination
                return random.choice([a, b])
            payload = {"original_prompt": a, "objective": purpose, "oracle_feedback": None, "parent_b": b}
            return await recombine_fn.acall_invoke(payload)

        def _make_individual_from_prompts(prompts: dict[str, str]) -> Individual:
            """Create an unevaluated individual owning its own copy of ``prompts``."""
            return Individual(prompts=dict(prompts))

        async def _initial_population() -> list[Individual]:
            """Build the first generation: the original prompts plus ``pop_size - 1`` mutated variants."""
            individuals: list[Individual] = []
            # Ensure first individual is the original prompts
            originals = {k: prompt_space[k][0] for k in prompt_space}
            individuals.append(_make_individual_from_prompts(originals))

            init_sem = asyncio.Semaphore(max_eval_concurrency)

            async def _create_random_individual() -> Individual:
                async with init_sem:
                    mutated: dict[str, str] = {}
                    for param, (base_prompt, purpose) in prompt_space.items():
                        try:
                            new_p = await _mutate_prompt(base_prompt, purpose)
                        except Exception as e:
                            # Best effort: a failed mutation must not abort the whole optimization
                            logger.warning("Mutation failed for %s: %s; using original.", param, e)
                            new_p = base_prompt
                        mutated[param] = new_p
                    return _make_individual_from_prompts(mutated)

            needed = max(0, pop_size - 1)
            tasks = [_create_random_individual() for _ in range(needed)]
            individuals.extend(await asyncio.gather(*tasks))
            return individuals

        # ------------- evaluation ------------- #
        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))

        sem = asyncio.Semaphore(max_eval_concurrency)

        async def _evaluate(ind: Individual) -> Individual:
            """Run the evaluation ``reps`` times for ``ind`` and store the mean score per metric on it."""
            async with sem:
                cfg_trial = apply_suggestions(base_cfg, ind.prompts)
                eval_cfg = EvaluationRunConfig(
                    config_file=cfg_trial,
                    dataset=opt_run_config.dataset,
                    result_json_path=opt_run_config.result_json_path,
                    endpoint=opt_run_config.endpoint,
                    endpoint_timeout=opt_run_config.endpoint_timeout,
                    override=opt_run_config.override,
                )
                # Run reps sequentially under the same semaphore to avoid overload
                all_results: list[list[tuple[str, Any]]] = []
                for _ in range(reps):
                    res = (await EvaluationRun(config=eval_cfg).run_and_evaluate()).evaluation_results
                    all_results.append(res)

                metrics: dict[str, float] = {}
                for metric_name in eval_metrics:
                    scores: list[float] = []
                    for run_results in all_results:
                        # Only the first result with a matching evaluator name counts per run
                        for name, result in run_results:
                            if name == metric_name:
                                scores.append(result.average_score)
                                break
                    metrics[metric_name] = float(sum(scores) / len(scores)) if scores else 0.0
                ind.metrics = metrics
                return ind

        async def _evaluate_population(pop: list[Individual]) -> list[Individual]:
            """Evaluate individuals without metrics, then (re)compute ``scalar_fitness`` for the whole population."""
            # Evaluate those missing metrics; ``_evaluate`` stores the metrics on the individual itself
            unevaluated = [ind for ind in pop if not ind.metrics]
            if unevaluated:
                await asyncio.gather(*[_evaluate(ind) for ind in unevaluated])
            # Scalarize
            norm_per_ind = _normalize_generation(pop, eval_metrics, directions)
            penalties = _apply_diversity_penalty(pop, diversity_lambda)
            for ind, norm_scores, penalty in zip(pop, norm_per_ind, penalties):
                ind.scalar_fitness = _scalarize(
                    norm_scores, mode=optimizer_config.multi_objective_combination_mode, weights=weights) - penalty
            return pop

        # ------------- reproduction ops ------------- #
        async def _make_child(parent_a: Individual, parent_b: Individual) -> Individual:
            """Create a child by applying crossover and mutation independently to every prompt parameter."""
            child_prompts: dict[str, str] = {}
            for param, (base_prompt, purpose) in prompt_space.items():
                pa = parent_a.prompts.get(param, base_prompt)
                pb = parent_b.prompts.get(param, base_prompt)
                child = pa
                # crossover
                if random.random() < crossover_rate:
                    try:
                        child = await _recombine_prompts(pa, pb, purpose)
                    except Exception as e:
                        logger.warning("Recombination failed for %s: %s; falling back to parent.", param, e)
                        child = random.choice([pa, pb])
                # mutation
                if random.random() < mutation_rate:
                    try:
                        child = await _mutate_prompt(child, purpose)
                    except Exception as e:
                        logger.warning("Mutation failed for %s: %s; keeping child as-is.", param, e)
                child_prompts[param] = child
            return _make_individual_from_prompts(child_prompts)

        def _select_parent(curr_pop: list[Individual]) -> Individual:
            """Select one parent using tournament selection or, for any other method, a roulette wheel."""
            if selection_method == "tournament":
                return _tournament_select(curr_pop, tournament_size)
            # roulette wheel over non-negative fitness
            shares = [max(_fitness(ind), 0.0) for ind in curr_pop]
            total = sum(shares)
            if total <= 0.0:
                # No individual has positive fitness: spinning the wheel would always land on the
                # last individual, so select uniformly instead.
                return random.choice(curr_pop)
            r = random.random() * total
            acc = 0.0
            for ind, share in zip(curr_pop, shares):
                acc += share
                if acc >= r:
                    return ind
            return curr_pop[-1]  # guard against floating-point round-off

        # ------------- GA loop ------------- #
        population = await _initial_population()
        history_rows: list[dict[str, Any]] = []

        for gen in range(1, generations + 1):
            logger.info("[GA] Generation %d/%d: evaluating population of %d", gen, generations, len(population))
            population = await _evaluate_population(population)

            # Log and save checkpoint
            best = max(population, key=_fitness)
            checkpoint = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
            checkpoint_path = out_dir / f"optimized_prompts_gen{gen}.json"
            with checkpoint_path.open("w", encoding="utf-8") as fh:
                json.dump(checkpoint, fh, indent=2)
            logger.info("[GA] Saved checkpoint: %s (fitness=%.4f)", checkpoint_path, best.scalar_fitness or 0.0)

            # Append history
            for idx, ind in enumerate(population):
                row = {
                    "generation": gen,
                    "index": idx,
                    "scalar_fitness": ind.scalar_fitness,
                }
                if ind.metrics:
                    row.update({f"metric::{m}": ind.metrics[m] for m in eval_metrics})
                history_rows.append(row)

            # Next generation via elitism + reproduction
            next_population: list[Individual] = []
            if elitism > 0:
                elites = sorted(population, key=_fitness, reverse=True)[:elitism]
                # Elites are carried over as fresh copies without metrics, so they are evaluated again
                next_population.extend([_make_individual_from_prompts(e.prompts) for e in elites])

            # Produce offspring until the population is back at pop_size
            needed = pop_size - len(next_population)
            offspring: list[Individual] = []
            while len(offspring) < needed:
                p1 = _select_parent(population)
                p2 = _select_parent(population)
                if p2 is p1 and len(population) > 1:
                    # Avoid mating an individual with itself when an alternative exists
                    p2 = random.choice([ind for ind in population if ind is not p1])
                child = await _make_child(p1, p2)
                offspring.append(child)

            population = next_population + offspring

        # Final evaluation to ensure metrics present
        population = await _evaluate_population(population)
        best = max(population, key=_fitness)
        best_prompts = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}

        # Save final
        final_prompts_path = out_dir / "optimized_prompts.json"
        with final_prompts_path.open("w", encoding="utf-8") as fh:
            json.dump(best_prompts, fh, indent=2)

        trials_df_path = out_dir / "ga_history_prompts.csv"
        try:
            # History is written with the standard-library csv module (imported lazily)
            import csv  # pylint: disable=import-outside-toplevel

            fieldnames: list[str] = sorted({k for row in history_rows for k in row.keys()})
            with trials_df_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                for row in history_rows:
                    writer.writerow(row)
        except Exception as e:  # pragma: no cover - best effort
            logger.warning("Failed to write GA history CSV: %s", e)

        logger.info("Prompt GA optimization finished successfully!")
        logger.info("Final prompts saved to: %s", final_prompts_path)
        logger.info("History saved to: %s", trials_df_path)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _deep_merge_dict(target: dict[str, Any], updates: dict[str, Any]) -> None:
|
|
23
|
+
"""In-place deep merge of nested dictionaries."""
|
|
24
|
+
for key, value in updates.items():
|
|
25
|
+
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
|
26
|
+
_deep_merge_dict(target[key], value)
|
|
27
|
+
else:
|
|
28
|
+
target[key] = value
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def nest_updates(flat: dict[str, Any]) -> dict[str, Any]:
    """
    Expand dotted keys into nested dictionaries.

    ``{'a.b.c': 1, 'd.x.y': 2}`` becomes ``{'a': {'b': {'c': 1}}, 'd': {'x': {'y': 2}}}``.
    Keys sharing a prefix are merged into the same sub-dictionary; a key without a dot is
    stored as-is at the top level.
    """
    nested: dict[str, Any] = {}

    for dotted, value in flat.items():
        head, dot, tail = dotted.partition(".")
        if not dot:
            # No further segments: plain top-level assignment
            nested[head] = value
            continue

        # Expand the remainder of the path first, then graft it under ``head``
        subtree = nest_updates({tail: value})
        slot = nested.setdefault(head, {})
        if isinstance(slot, dict):
            _deep_merge_dict(slot, subtree)
        else:
            # A non-dict value stored earlier under ``head`` is replaced by the subtree
            nested[head] = subtree
    return nested
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def apply_suggestions(cfg: BaseModel, flat: dict[str, Any]) -> BaseModel:
    """
    Build a **new** config from *cfg* with the dotted-path entries of *flat* applied.

    The config is dumped to a plain dict, each dotted path is walked (creating missing
    intermediate dicts) and its final key is overwritten, then the result is validated
    back into the same model class. Keys not named in *flat* are left as dumped.
    """
    data = cfg.model_dump(mode="python")
    for path, new_value in flat.items():
        *parents, leaf = path.split(".")
        node = data
        for segment in parents:
            # Descend one level, creating the container if the path does not exist yet
            node = node.setdefault(segment, {})
        node[leaf] = new_value
    return cfg.__class__.model_validate(data)
|
nat/profiler/utils.py
CHANGED
|
@@ -22,6 +22,7 @@ from typing import Any
|
|
|
22
22
|
import pandas as pd
|
|
23
23
|
|
|
24
24
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
25
|
+
from nat.cli.type_registry import RegisteredFunctionGroupInfo
|
|
25
26
|
from nat.cli.type_registry import RegisteredFunctionInfo
|
|
26
27
|
from nat.data_models.intermediate_step import IntermediateStep
|
|
27
28
|
from nat.profiler.data_frame_row import DataFrameRow
|
|
@@ -32,7 +33,8 @@ _FRAMEWORK_REGEX_MAP = {t: fr'\b{t._name_}\b' for t in LLMFrameworkEnum}
|
|
|
32
33
|
logger = logging.getLogger(__name__)
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
def detect_llm_frameworks_in_build_fn(
|
|
36
|
+
def detect_llm_frameworks_in_build_fn(
|
|
37
|
+
registration: RegisteredFunctionInfo | RegisteredFunctionGroupInfo) -> list[LLMFrameworkEnum]:
|
|
36
38
|
"""
|
|
37
39
|
Analyze a function's source (the build_fn) to see which LLM frameworks it uses. Also recurses
|
|
38
40
|
into any additional Python functions that the build_fn calls while passing `builder`, so that
|
nat/tool/chat_completion.py
CHANGED
|
@@ -63,7 +63,10 @@ async def register_chat_completion(config: ChatCompletionConfig, builder: Builde
|
|
|
63
63
|
# Generate response using the LLM
|
|
64
64
|
response = await llm.ainvoke(prompt)
|
|
65
65
|
|
|
66
|
-
|
|
66
|
+
if isinstance(response, str):
|
|
67
|
+
return response
|
|
68
|
+
|
|
69
|
+
return response.text()
|
|
67
70
|
|
|
68
71
|
except Exception as e:
|
|
69
72
|
# Fallback response if LLM call fails
|