nvidia-nat 1.3.0a20250909__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. nat/agent/base.py +11 -6
  2. nat/agent/dual_node.py +2 -2
  3. nat/agent/prompt_optimizer/prompt.py +68 -0
  4. nat/agent/prompt_optimizer/register.py +149 -0
  5. nat/agent/react_agent/agent.py +1 -1
  6. nat/agent/react_agent/register.py +17 -7
  7. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  8. nat/agent/register.py +2 -0
  9. nat/agent/rewoo_agent/agent.py +6 -3
  10. nat/agent/rewoo_agent/register.py +16 -10
  11. nat/agent/router_agent/__init__.py +0 -0
  12. nat/agent/router_agent/agent.py +329 -0
  13. nat/agent/router_agent/prompt.py +48 -0
  14. nat/agent/router_agent/register.py +97 -0
  15. nat/agent/tool_calling_agent/agent.py +69 -7
  16. nat/agent/tool_calling_agent/register.py +17 -9
  17. nat/builder/builder.py +27 -4
  18. nat/builder/component_utils.py +7 -3
  19. nat/builder/function.py +167 -0
  20. nat/builder/function_info.py +1 -1
  21. nat/builder/workflow.py +5 -0
  22. nat/builder/workflow_builder.py +213 -16
  23. nat/cli/commands/optimize.py +90 -0
  24. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  25. nat/cli/commands/workflow/workflow_commands.py +5 -8
  26. nat/cli/entrypoint.py +2 -0
  27. nat/cli/register_workflow.py +38 -4
  28. nat/cli/type_registry.py +71 -0
  29. nat/data_models/api_server.py +1 -1
  30. nat/data_models/component.py +2 -0
  31. nat/data_models/component_ref.py +11 -0
  32. nat/data_models/config.py +40 -16
  33. nat/data_models/function.py +34 -0
  34. nat/data_models/function_dependencies.py +8 -0
  35. nat/data_models/optimizable.py +119 -0
  36. nat/data_models/optimizer.py +149 -0
  37. nat/data_models/temperature_mixin.py +4 -3
  38. nat/data_models/top_p_mixin.py +4 -3
  39. nat/embedder/nim_embedder.py +1 -1
  40. nat/embedder/openai_embedder.py +1 -1
  41. nat/eval/config.py +1 -1
  42. nat/eval/evaluate.py +5 -1
  43. nat/eval/register.py +4 -0
  44. nat/eval/runtime_evaluator/__init__.py +14 -0
  45. nat/eval/runtime_evaluator/evaluate.py +123 -0
  46. nat/eval/runtime_evaluator/register.py +100 -0
  47. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  48. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  49. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  50. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  51. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  52. nat/front_ends/fastapi/job_store.py +518 -99
  53. nat/front_ends/fastapi/main.py +11 -19
  54. nat/front_ends/fastapi/utils.py +57 -0
  55. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  56. nat/llm/aws_bedrock_llm.py +15 -4
  57. nat/llm/nim_llm.py +14 -3
  58. nat/llm/openai_llm.py +8 -1
  59. nat/observability/exporter/processing_exporter.py +29 -55
  60. nat/observability/mixin/redaction_config_mixin.py +5 -4
  61. nat/observability/mixin/tagging_config_mixin.py +26 -14
  62. nat/observability/mixin/type_introspection_mixin.py +401 -107
  63. nat/observability/processor/processor.py +3 -0
  64. nat/observability/processor/redaction/__init__.py +24 -0
  65. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  66. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  67. nat/observability/processor/redaction/redaction_processor.py +177 -0
  68. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  69. nat/observability/processor/span_tagging_processor.py +21 -14
  70. nat/profiler/decorators/framework_wrapper.py +9 -6
  71. nat/profiler/parameter_optimization/__init__.py +0 -0
  72. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  73. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  74. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  75. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  76. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  77. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  78. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  79. nat/profiler/utils.py +3 -1
  80. nat/tool/chat_completion.py +5 -2
  81. nat/tool/document_search.py +1 -1
  82. nat/tool/github_tools.py +450 -0
  83. nat/tool/register.py +2 -7
  84. nat/utils/callable_utils.py +70 -0
  85. nat/utils/exception_handlers/automatic_retries.py +103 -48
  86. nat/utils/type_utils.py +4 -0
  87. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  88. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +94 -74
  89. nat/observability/processor/header_redaction_processor.py +0 -123
  90. nat/observability/processor/redaction_processor.py +0 -77
  91. nat/tool/github_tools/create_github_commit.py +0 -133
  92. nat/tool/github_tools/create_github_issue.py +0 -87
  93. nat/tool/github_tools/create_github_pr.py +0 -106
  94. nat/tool/github_tools/get_github_file.py +0 -106
  95. nat/tool/github_tools/get_github_issue.py +0 -166
  96. nat/tool/github_tools/get_github_pr.py +0 -256
  97. nat/tool/github_tools/update_github_issue.py +0 -100
  98. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  99. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  100. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  101. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  102. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  103. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
nat/profiler/parameter_optimization/prompt_optimizer.py ADDED
@@ -0,0 +1,384 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ import random
20
+ from collections.abc import Sequence
21
+ from dataclasses import dataclass
22
+ from typing import Any
23
+
24
+ from pydantic import BaseModel
25
+
26
+ from nat.builder.workflow_builder import WorkflowBuilder
27
+ from nat.data_models.config import Config
28
+ from nat.data_models.optimizable import SearchSpace
29
+ from nat.data_models.optimizer import OptimizerConfig
30
+ from nat.data_models.optimizer import OptimizerRunConfig
31
+ from nat.eval.evaluate import EvaluationRun
32
+ from nat.eval.evaluate import EvaluationRunConfig
33
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
34
+ from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ class PromptOptimizerInputSchema(BaseModel):
40
+ original_prompt: str
41
+ objective: str
42
+ oracle_feedback: str | None = None
43
+
44
+
45
+ @experimental(feature_name="Optimizer")
46
+ async def optimize_prompts(
47
+ *,
48
+ base_cfg: Config,
49
+ full_space: dict[str, SearchSpace],
50
+ optimizer_config: OptimizerConfig,
51
+ opt_run_config: OptimizerRunConfig,
52
+ ) -> None:
53
+
54
+ # ------------- helpers ------------- #
55
+ @dataclass
56
+ class Individual:
57
+ prompts: dict[str, str] # param_name -> prompt text
58
+ metrics: dict[str, float] | None = None # evaluator_name -> average score
59
+ scalar_fitness: float | None = None
60
+
61
+ def _normalize_generation(
62
+ individuals: Sequence[Individual],
63
+ metric_names: Sequence[str],
64
+ directions: Sequence[str],
65
+ eps: float = 1e-12,
66
+ ) -> list[dict[str, float]]:
67
+ """Return per-individual dict of normalised scores in [0,1] where higher is better."""
68
+ # Extract arrays per metric
69
+ arrays = {m: [ind.metrics.get(m, 0.0) if ind.metrics else 0.0 for ind in individuals] for m in metric_names}
70
+ normed: list[dict[str, float]] = []
71
+ for i in range(len(individuals)):
72
+ entry: dict[str, float] = {}
73
+ for m, dirn in zip(metric_names, directions):
74
+ vals = arrays[m]
75
+ vmin = min(vals)
76
+ vmax = max(vals)
77
+ v = vals[i]
78
+ # Map to [0,1] with higher=better regardless of direction
79
+ if vmax - vmin < eps:
80
+ score01 = 0.5
81
+ else:
82
+ score01 = (v - vmin) / (vmax - vmin)
83
+ if dirn == "minimize":
84
+ score01 = 1.0 - score01
85
+ entry[m] = float(score01)
86
+ normed.append(entry)
87
+ return normed
88
+
89
+ def _scalarize(norm_scores: dict[str, float], *, mode: str, weights: Sequence[float] | None) -> float:
90
+ """Collapse normalised scores to a single scalar (higher is better)."""
91
+ vals = list(norm_scores.values())
92
+ if not vals:
93
+ return 0.0
94
+ if mode == "harmonic":
95
+ inv_sum = sum(1.0 / max(v, 1e-12) for v in vals)
96
+ return len(vals) / max(inv_sum, 1e-12)
97
+ if mode == "sum":
98
+ if weights is None:
99
+ return float(sum(vals))
100
+ if len(weights) != len(vals):
101
+ raise ValueError("weights length must equal number of objectives")
102
+ return float(sum(w * v for w, v in zip(weights, vals)))
103
+ if mode == "chebyshev":
104
+ return float(min(vals)) # maximise the worst-case score
105
+ raise ValueError(f"Unknown combination mode: {mode}")
106
+
107
+ def _apply_diversity_penalty(individuals: Sequence[Individual], diversity_lambda: float) -> list[float]:
108
+ if diversity_lambda <= 0.0:
109
+ return [0.0 for _ in individuals]
110
+ seen: dict[str, int] = {}
111
+ keys: list[str] = []
112
+ penalties: list[float] = []
113
+ for ind in individuals:
114
+ key = "\u241f".join(ind.prompts.get(k, "") for k in sorted(ind.prompts.keys()))
115
+ keys.append(key)
116
+ seen[key] = seen.get(key, 0) + 1
117
+ for key in keys:
118
+ duplicates = seen[key] - 1
119
+ penalties.append(diversity_lambda * float(duplicates))
120
+ return penalties
121
+
122
+ def _tournament_select(pop: Sequence[Individual], k: int) -> Individual:
123
+ contenders = random.sample(pop, k=min(k, len(pop)))
124
+ return max(contenders, key=lambda i: (i.scalar_fitness or 0.0))
125
+
126
+ # ------------- discover space ------------- #
127
+ prompt_space: dict[str, tuple[str, str]] = {
128
+ k: (v.prompt, v.prompt_purpose)
129
+ for k, v in full_space.items() if v.is_prompt
130
+ }
131
+
132
+ if not prompt_space:
133
+ logger.info("No prompts to optimize – skipping.")
134
+ return
135
+
136
+ metric_cfg = optimizer_config.eval_metrics
137
+ if metric_cfg is None or len(metric_cfg) == 0:
138
+ raise ValueError("optimizer_config.eval_metrics must be provided for GA prompt optimization")
139
+
140
+ directions = [v.direction for v in metric_cfg.values()] # "minimize" or "maximize"
141
+ eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
142
+ weights = [v.weight for v in metric_cfg.values()]
143
+
144
+ out_dir = optimizer_config.output_path
145
+ out_dir.mkdir(parents=True, exist_ok=True)
146
+
147
+ # ------------- builder & functions ------------- #
148
+ async with WorkflowBuilder(general_config=base_cfg.general, registry=None) as builder:
149
+ await builder.populate_builder(base_cfg)
150
+ init_fn_name = (optimizer_config.prompt.prompt_population_init_function)
151
+ if not init_fn_name:
152
+ raise ValueError(
153
+ "No prompt optimization function configured. Set optimizer.prompt_population_init_function")
154
+ init_fn = builder.get_function(init_fn_name)
155
+
156
+ recombine_fn = None
157
+ if optimizer_config.prompt.prompt_recombination_function:
158
+ recombine_fn = builder.get_function(optimizer_config.prompt.prompt_recombination_function)
159
+
160
+ logger.info(
161
+ "GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
162
+ init_fn_name,
163
+ optimizer_config.prompt.prompt_recombination_function,
164
+ )
165
+
166
+ # ------------- GA parameters ------------- #
167
+ pop_size = max(2, int(optimizer_config.prompt.ga_population_size))
168
+ generations = max(1, int(optimizer_config.prompt.ga_generations))
169
+ offspring_size = (optimizer_config.prompt.ga_offspring_size
170
+ or max(0, pop_size - optimizer_config.prompt.ga_elitism))
171
+ crossover_rate = float(optimizer_config.prompt.ga_crossover_rate)
172
+ mutation_rate = float(optimizer_config.prompt.ga_mutation_rate)
173
+ elitism = max(0, int(optimizer_config.prompt.ga_elitism))
174
+ selection_method = optimizer_config.prompt.ga_selection_method.lower()
175
+ tournament_size = max(2, int(optimizer_config.prompt.ga_tournament_size))
176
+ max_eval_concurrency = max(1, int(optimizer_config.prompt.ga_parallel_evaluations))
177
+ diversity_lambda = float(optimizer_config.prompt.ga_diversity_lambda)
178
+
179
+ # ------------- population init ------------- #
180
+ async def _mutate_prompt(original_prompt: str, purpose: str) -> str:
181
+ # Use LLM-based optimizer with no feedback
182
+ return await init_fn.acall_invoke(
183
+ PromptOptimizerInputSchema(
184
+ original_prompt=original_prompt,
185
+ objective=purpose,
186
+ oracle_feedback=None,
187
+ ))
188
+
189
+ async def _recombine_prompts(a: str, b: str, purpose: str) -> str:
190
+ if recombine_fn is None:
191
+ # Fallback: uniform choice per recombination
192
+ return random.choice([a, b])
193
+ payload = {"original_prompt": a, "objective": purpose, "oracle_feedback": None, "parent_b": b}
194
+ return await recombine_fn.acall_invoke(payload)
195
+
196
+ def _make_individual_from_prompts(prompts: dict[str, str]) -> Individual:
197
+ return Individual(prompts=dict(prompts))
198
+
199
+ async def _initial_population() -> list[Individual]:
200
+ individuals: list[Individual] = []
201
+ # Ensure first individual is the original prompts
202
+ originals = {k: prompt_space[k][0] for k in prompt_space}
203
+ individuals.append(_make_individual_from_prompts(originals))
204
+
205
+ init_sem = asyncio.Semaphore(max_eval_concurrency)
206
+
207
+ async def _create_random_individual() -> Individual:
208
+ async with init_sem:
209
+ mutated: dict[str, str] = {}
210
+ for param, (base_prompt, purpose) in prompt_space.items():
211
+ try:
212
+ new_p = await _mutate_prompt(base_prompt, purpose)
213
+ except Exception as e:
214
+ logger.warning("Mutation failed for %s: %s; using original.", param, e)
215
+ new_p = base_prompt
216
+ mutated[param] = new_p
217
+ return _make_individual_from_prompts(mutated)
218
+
219
+ needed = max(0, pop_size - 1)
220
+ tasks = [_create_random_individual() for _ in range(needed)]
221
+ individuals.extend(await asyncio.gather(*tasks))
222
+ return individuals
223
+
224
+ # ------------- evaluation ------------- #
225
+ reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))
226
+
227
+ sem = asyncio.Semaphore(max_eval_concurrency)
228
+
229
+ async def _evaluate(ind: Individual) -> Individual:
230
+ async with sem:
231
+ cfg_trial = apply_suggestions(base_cfg, ind.prompts)
232
+ eval_cfg = EvaluationRunConfig(
233
+ config_file=cfg_trial,
234
+ dataset=opt_run_config.dataset,
235
+ result_json_path=opt_run_config.result_json_path,
236
+ endpoint=opt_run_config.endpoint,
237
+ endpoint_timeout=opt_run_config.endpoint_timeout,
238
+ override=opt_run_config.override,
239
+ )
240
+ # Run reps sequentially under the same semaphore to avoid overload
241
+ all_results: list[list[tuple[str, Any]]] = []
242
+ for _ in range(reps):
243
+ res = (await EvaluationRun(config=eval_cfg).run_and_evaluate()).evaluation_results
244
+ all_results.append(res)
245
+
246
+ metrics: dict[str, float] = {}
247
+ for metric_name in eval_metrics:
248
+ scores: list[float] = []
249
+ for run_results in all_results:
250
+ for name, result in run_results:
251
+ if name == metric_name:
252
+ scores.append(result.average_score)
253
+ break
254
+ metrics[metric_name] = float(sum(scores) / len(scores)) if scores else 0.0
255
+ ind.metrics = metrics
256
+ return ind
257
+
258
+ async def _evaluate_population(pop: list[Individual]) -> list[Individual]:
259
+ # Evaluate those missing metrics
260
+ unevaluated = [ind for ind in pop if not ind.metrics]
261
+ if unevaluated:
262
+ evaluated = await asyncio.gather(*[_evaluate(ind) for ind in unevaluated])
263
+ # in-place update
264
+ for ind, ev in zip(unevaluated, evaluated):
265
+ ind.metrics = ev.metrics
266
+ # Scalarize
267
+ norm_per_ind = _normalize_generation(pop, eval_metrics, directions)
268
+ penalties = _apply_diversity_penalty(pop, diversity_lambda)
269
+ for ind, norm_scores, penalty in zip(pop, norm_per_ind, penalties):
270
+ ind.scalar_fitness = _scalarize(
271
+ norm_scores, mode=optimizer_config.multi_objective_combination_mode, weights=weights) - penalty
272
+ return pop
273
+
274
+ # ------------- reproduction ops ------------- #
275
+ async def _make_child(parent_a: Individual, parent_b: Individual) -> Individual:
276
+ child_prompts: dict[str, str] = {}
277
+ for param, (base_prompt, purpose) in prompt_space.items():
278
+ pa = parent_a.prompts.get(param, base_prompt)
279
+ pb = parent_b.prompts.get(param, base_prompt)
280
+ child = pa
281
+ # crossover
282
+ if random.random() < crossover_rate:
283
+ try:
284
+ child = await _recombine_prompts(pa, pb, purpose)
285
+ except Exception as e:
286
+ logger.warning("Recombination failed for %s: %s; falling back to parent.", param, e)
287
+ child = random.choice([pa, pb])
288
+ # mutation
289
+ if random.random() < mutation_rate:
290
+ try:
291
+ child = await _mutate_prompt(child, purpose)
292
+ except Exception as e:
293
+ logger.warning("Mutation failed for %s: %s; keeping child as-is.", param, e)
294
+ child_prompts[param] = child
295
+ return _make_individual_from_prompts(child_prompts)
296
+
297
+ # ------------- GA loop ------------- #
298
+ population = await _initial_population()
299
+ history_rows: list[dict[str, Any]] = []
300
+
301
+ for gen in range(1, generations + 1):
302
+ logger.info("[GA] Generation %d/%d: evaluating population of %d", gen, generations, len(population))
303
+ population = await _evaluate_population(population)
304
+
305
+ # Log and save checkpoint
306
+ best = max(population, key=lambda i: (i.scalar_fitness or 0.0))
307
+ checkpoint = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
308
+ checkpoint_path = out_dir / f"optimized_prompts_gen{gen}.json"
309
+ with checkpoint_path.open("w") as fh:
310
+ json.dump(checkpoint, fh, indent=2)
311
+ logger.info("[GA] Saved checkpoint: %s (fitness=%.4f)", checkpoint_path, best.scalar_fitness or 0.0)
312
+
313
+ # Append history
314
+ for idx, ind in enumerate(population):
315
+ row = {
316
+ "generation": gen,
317
+ "index": idx,
318
+ "scalar_fitness": ind.scalar_fitness,
319
+ }
320
+ if ind.metrics:
321
+ row.update({f"metric::{m}": ind.metrics[m] for m in eval_metrics})
322
+ history_rows.append(row)
323
+
324
+ # Next generation via elitism + reproduction
325
+ next_population: list[Individual] = []
326
+ if elitism > 0:
327
+ elites = sorted(population, key=lambda i: (i.scalar_fitness or 0.0), reverse=True)[:elitism]
328
+ next_population.extend([_make_individual_from_prompts(e.prompts) for e in elites])
329
+
330
+ def _select_parent(curr_pop: list[Individual]) -> Individual:
331
+ if selection_method == "tournament":
332
+ return _tournament_select(curr_pop, tournament_size)
333
+ # roulette wheel
334
+ total = sum(max(ind.scalar_fitness or 0.0, 0.0) for ind in curr_pop) or 1.0
335
+ r = random.random() * total
336
+ acc = 0.0
337
+ for ind in curr_pop:
338
+ acc += max(ind.scalar_fitness or 0.0, 0.0)
339
+ if acc >= r:
340
+ return ind
341
+ return curr_pop[-1]
342
+
343
+ # Produce offspring
344
+ needed = pop_size - len(next_population)
345
+ offspring: list[Individual] = []
348
+ while len(offspring) < needed:
349
+ p1 = _select_parent(population)
350
+ p2 = _select_parent(population)
351
+ if p2 is p1 and len(population) > 1:
352
+ p2 = random.choice([ind for ind in population if ind is not p1])
353
+ child = await _make_child(p1, p2)
354
+ offspring.append(child)
355
+
356
+ population = next_population + offspring
357
+
358
+ # Final evaluation to ensure metrics present
359
+ population = await _evaluate_population(population)
360
+ best = max(population, key=lambda i: (i.scalar_fitness or 0.0))
361
+ best_prompts = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
362
+
363
+ # Save final
364
+ final_prompts_path = out_dir / "optimized_prompts.json"
365
+ with final_prompts_path.open("w") as fh:
366
+ json.dump(best_prompts, fh, indent=2)
367
+
368
+ trials_df_path = out_dir / "ga_history_prompts.csv"
369
+ try:
370
+ # Write the GA history with the standard-library csv module
371
+ import csv # pylint: disable=import-outside-toplevel
372
+
373
+ fieldnames: list[str] = sorted({k for row in history_rows for k in row.keys()})
374
+ with trials_df_path.open("w", newline="") as f:
375
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
376
+ writer.writeheader()
377
+ for row in history_rows:
378
+ writer.writerow(row)
379
+ except Exception as e: # pragma: no cover - best effort
380
+ logger.warning("Failed to write GA history CSV: %s", e)
381
+
382
+ logger.info("Prompt GA optimization finished successfully!")
383
+ logger.info("Final prompts saved to: %s", final_prompts_path)
384
+ logger.info("History saved to: %s", trials_df_path)
nat/profiler/parameter_optimization/update_helpers.py ADDED
@@ -0,0 +1,66 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ from typing import Any
18
+
19
+ from pydantic import BaseModel
20
+
21
+
22
+ def _deep_merge_dict(target: dict[str, Any], updates: dict[str, Any]) -> None:
23
+ """In-place deep merge of nested dictionaries."""
24
+ for key, value in updates.items():
25
+ if key in target and isinstance(target[key], dict) and isinstance(value, dict):
26
+ _deep_merge_dict(target[key], value)
27
+ else:
28
+ target[key] = value
29
+
30
+
31
+ def nest_updates(flat: dict[str, Any]) -> dict[str, Any]:
32
+ """
33
+ Convert ``{'a.b.c': 1, 'd.x.y': 2}`` ➜
34
+ ``{'a': {'b': {'c': 1}}, 'd': {'x': {'y': 2}}}``.
35
+ Works even when the middle segment is a dict key.
36
+ """
37
+ root: dict[str, Any] = defaultdict(dict)
38
+
39
+ for dotted, value in flat.items():
40
+ head, *rest = dotted.split(".", 1)
41
+ if not rest: # leaf
42
+ root[head] = value
43
+ continue
44
+
45
+ tail = rest[0]
46
+ child_updates = nest_updates({tail: value})
47
+ if isinstance(root[head], dict):
48
+ _deep_merge_dict(root[head], child_updates)
49
+ else:
50
+ root[head] = child_updates
51
+ return dict(root)
52
+
53
+
54
+ def apply_suggestions(cfg: BaseModel, flat: dict[str, Any]) -> BaseModel:
55
+ """
56
+ Return a **new** config where only the dotted-path keys in *flat*
57
+ have been modified. Preserves all unrelated siblings.
58
+ """
59
+ cfg_dict = cfg.model_dump(mode="python")
60
+ for dotted, value in flat.items():
61
+ keys = dotted.split(".")
62
+ cursor = cfg_dict
63
+ for key in keys[:-1]:
64
+ cursor = cursor.setdefault(key, {})
65
+ cursor[keys[-1]] = value
66
+ return cfg.__class__.model_validate(cfg_dict)
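The helpers in update_helpers.py are self-contained, so their behaviour is easy to show. In the sketch below the import path follows the file listing (nat/profiler/parameter_optimization/update_helpers.py); LLMSettings and MyConfig are made-up stand-ins for a real workflow config, used only to illustrate that apply_suggestions rewrites just the dotted-path keys and leaves unrelated siblings untouched.

# Illustrative usage sketch only -- not part of the nvidia-nat diff above.
from pydantic import BaseModel

from nat.profiler.parameter_optimization.update_helpers import apply_suggestions, nest_updates

class LLMSettings(BaseModel):
    temperature: float = 0.7
    top_p: float = 1.0

class MyConfig(BaseModel):
    llm: LLMSettings = LLMSettings()
    workflow_name: str = "demo"

print(nest_updates({"llm.temperature": 0.2, "llm.top_p": 0.9}))
# {'llm': {'temperature': 0.2, 'top_p': 0.9}}

updated = apply_suggestions(MyConfig(), {"llm.temperature": 0.2})
print(updated.llm.temperature, updated.llm.top_p, updated.workflow_name)
# 0.2 1.0 demo  -- only the dotted-path key changed; siblings are preserved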
nat/profiler/utils.py CHANGED
@@ -22,6 +22,7 @@ from typing import Any
22
22
  import pandas as pd
23
23
 
24
24
  from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.cli.type_registry import RegisteredFunctionGroupInfo
25
26
  from nat.cli.type_registry import RegisteredFunctionInfo
26
27
  from nat.data_models.intermediate_step import IntermediateStep
27
28
  from nat.profiler.data_frame_row import DataFrameRow
@@ -32,7 +33,8 @@ _FRAMEWORK_REGEX_MAP = {t: fr'\b{t._name_}\b' for t in LLMFrameworkEnum}
32
33
  logger = logging.getLogger(__name__)
33
34
 
34
35
 
35
- def detect_llm_frameworks_in_build_fn(registration: RegisteredFunctionInfo) -> list[LLMFrameworkEnum]:
36
+ def detect_llm_frameworks_in_build_fn(
37
+ registration: RegisteredFunctionInfo | RegisteredFunctionGroupInfo) -> list[LLMFrameworkEnum]:
36
38
  """
37
39
  Analyze a function's source (the build_fn) to see which LLM frameworks it uses. Also recurses
38
40
  into any additional Python functions that the build_fn calls while passing `builder`, so that
nat/tool/chat_completion.py CHANGED
@@ -44,7 +44,7 @@ async def register_chat_completion(config: ChatCompletionConfig, builder: Builde
44
44
  """Registers a chat completion function that can handle natural language queries."""
45
45
 
46
46
  # Get the LLM from the builder context using the configured LLM reference
47
- # Use LangChain framework wrapper since we're using LangChain-based LLM
47
+ # Use LangChain/LangGraph framework wrapper since we're using LangChain/LangGraph-based LLM
48
48
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
49
49
 
50
50
  async def _chat_completion(query: str) -> str:
@@ -63,7 +63,10 @@ async def register_chat_completion(config: ChatCompletionConfig, builder: Builde
63
63
  # Generate response using the LLM
64
64
  response = await llm.ainvoke(prompt)
65
65
 
66
- return response
66
+ if isinstance(response, str):
67
+ return response
68
+
69
+ return response.text()
67
70
 
68
71
  except Exception as e:
69
72
  # Fallback response if LLM call fails
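The chat_completion.py hunks above account for ainvoke() returning either a plain string or a message object that exposes a text() method. A minimal, framework-agnostic sketch of that normalization pattern follows; the MessageLike protocol and FakeMessage class are illustrative, not actual nat or LangChain types.

# Illustrative sketch only -- not part of the nvidia-nat diff above.
from typing import Protocol

class MessageLike(Protocol):
    def text(self) -> str: ...

def to_text(response: str | MessageLike) -> str:
    # Plain strings pass through; message-style objects are asked for their text.
    if isinstance(response, str):
        return response
    return response.text()

class FakeMessage:
    def __init__(self, content: str) -> None:
        self.content = content

    def text(self) -> str:
        return self.content

print(to_text("plain string"))           # plain string
print(to_text(FakeMessage("from LLM")))  # from LLM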
nat/tool/document_search.py CHANGED
@@ -119,7 +119,7 @@ Return only the name of the predicted collection."""
119
119
  if len(results["chunks"]) == 0:
120
120
  return DocumentSearchOutput(collection_name=llm_pred.collection_name, documents="")
121
121
 
122
- # parse docs from Langchain Document object to string
122
+ # parse docs from LangChain/LangGraph Document object to string
123
123
  parsed_docs = []
124
124
 
125
125
  # iterate over results and store parsed content
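The document_search.py hunk only rewords a comment; the parsing code itself sits outside the shown context. For orientation, turning retrieved LangChain Document objects into a single string typically looks like the sketch below. Document.page_content and Document.metadata are real langchain_core attributes, but the joining format and sample data are just an example, not the function's actual logic.

# Illustrative sketch only -- not part of the nvidia-nat diff above.
from langchain_core.documents import Document

def docs_to_text(docs: list[Document]) -> str:
    parsed_docs = []
    for doc in docs:
        source = doc.metadata.get("source", "unknown")
        parsed_docs.append(f"[{source}]\n{doc.page_content}")
    return "\n\n".join(parsed_docs)

chunks = [
    Document(page_content="NIM supports OpenAI-compatible endpoints.", metadata={"source": "nim.md"}),
    Document(page_content="Milvus stores embedded document chunks.", metadata={"source": "milvus.md"}),
]
print(docs_to_text(chunks))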