nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,384 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ import random
20
+ from collections.abc import Sequence
21
+ from dataclasses import dataclass
22
+ from typing import Any
23
+
24
+ from pydantic import BaseModel
25
+
26
+ from nat.builder.workflow_builder import WorkflowBuilder
27
+ from nat.data_models.config import Config
28
+ from nat.data_models.optimizable import SearchSpace
29
+ from nat.data_models.optimizer import OptimizerConfig
30
+ from nat.data_models.optimizer import OptimizerRunConfig
31
+ from nat.eval.evaluate import EvaluationRun
32
+ from nat.eval.evaluate import EvaluationRunConfig
33
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
34
+ from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
class PromptOptimizerInputSchema(BaseModel):
    """Input payload for an LLM-backed prompt optimization (mutation) function.

    ``optimize_prompts`` sends one instance per prompt it wants rewritten.
    """

    # Prompt text to be rewritten.
    original_prompt: str
    # Stated purpose of the prompt; optimize_prompts fills this from SearchSpace.prompt_purpose.
    objective: str
    # Optional feedback to steer the rewrite; optimize_prompts currently always passes None.
    oracle_feedback: str | None = None
43
+
44
+
45
@experimental(feature_name="Optimizer")
async def optimize_prompts(
    *,
    base_cfg: Config,
    full_space: dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> None:
    """Optimize the prompt parameters of a workflow with an LLM-driven genetic algorithm (GA).

    Every entry of ``full_space`` whose ``is_prompt`` flag is set becomes one gene.

    How it works:
        - The initial population is the original prompts plus LLM-mutated variants.
        - Each generation is evaluated and ranked by a scalarised multi-objective fitness.
        - The next generation is bred through elitism, parent selection (tournament or
          roulette), optional LLM recombination and LLM mutation.

    Args:
        base_cfg: Workflow configuration that prompt suggestions are applied on top of.
        full_space: Mapping of dotted parameter path -> search space; only prompt entries are used.
        optimizer_config: GA settings (``optimizer_config.prompt``), evaluation metrics,
            the multi-objective combination mode and the output directory.
        opt_run_config: Dataset / endpoint settings forwarded to every evaluation run.

    Side effects:
        Writes these files into ``optimizer_config.output_path`` (created if missing):
        ``optimized_prompts_gen<N>.json`` after each generation, ``optimized_prompts.json``
        for the final best individual, and ``ga_history_prompts.csv``.

    Raises:
        ValueError: If no evaluation metrics or no population-init function are configured,
            or if the combination mode or metric weights are invalid.
    """

    # ------------- helpers ------------- #
    @dataclass
    class Individual:
        prompts: dict[str, str]  # param_name -> prompt text
        metrics: dict[str, float] | None = None  # evaluator_name -> average score
        scalar_fitness: float | None = None  # higher is better; None until scalarised

    def _fitness(ind: Individual) -> float:
        """Ranking key: the scalar fitness, with a missing value treated as 0.0."""
        return ind.scalar_fitness or 0.0

    def _normalize_generation(
        individuals: Sequence[Individual],
        metric_names: Sequence[str],
        directions: Sequence[str],
        eps: float = 1e-12,
    ) -> list[dict[str, float]]:
        """Return per-individual dict of normalised scores in [0,1] where higher is better.

        Scores are min-max normalised within the generation. A metric with no spread maps to
        0.5 for everyone, and "minimize" metrics are flipped so that 1.0 is always best.
        """
        if not individuals:
            return []
        # Raw score per metric, one entry per individual (missing metrics count as 0.0).
        arrays = {m: [ind.metrics.get(m, 0.0) if ind.metrics else 0.0 for ind in individuals] for m in metric_names}
        # Range of each metric, computed once rather than once per individual.
        bounds = {m: (min(vals), max(vals)) for m, vals in arrays.items()}
        normed: list[dict[str, float]] = []
        for i in range(len(individuals)):
            entry: dict[str, float] = {}
            for m, dirn in zip(metric_names, directions):
                vmin, vmax = bounds[m]
                if vmax - vmin < eps:
                    score01 = 0.5
                else:
                    score01 = (arrays[m][i] - vmin) / (vmax - vmin)
                    if dirn == "minimize":
                        score01 = 1.0 - score01
                entry[m] = float(score01)
            normed.append(entry)
        return normed

    def _scalarize(norm_scores: dict[str, float], *, mode: str, weights: Sequence[float] | None) -> float:
        """Collapse normalised scores to a single scalar (higher is better).

        Supported modes are "harmonic" (harmonic mean), "sum" (optionally weighted sum) and
        "chebyshev" (worst-case score).

        Raises:
            ValueError: On an unknown mode, or when ``weights`` does not match the score count.
        """
        vals = list(norm_scores.values())
        if not vals:
            return 0.0
        if mode == "harmonic":
            # Clamp to avoid division by zero for a zero score.
            inv_sum = sum(1.0 / max(v, 1e-12) for v in vals)
            return len(vals) / max(inv_sum, 1e-12)
        if mode == "sum":
            if weights is None:
                return float(sum(vals))
            if len(weights) != len(vals):
                raise ValueError("weights length must equal number of objectives")
            return float(sum(w * v for w, v in zip(weights, vals)))
        if mode == "chebyshev":
            return float(min(vals))  # maximise the worst-case score
        raise ValueError(f"Unknown combination mode: {mode}")

    def _apply_diversity_penalty(individuals: Sequence[Individual], diversity_lambda: float) -> list[float]:
        """Return a penalty per individual: ``diversity_lambda`` times its number of exact duplicates."""
        if diversity_lambda <= 0.0:
            return [0.0 for _ in individuals]
        # Identity key: all prompt texts, in sorted parameter order, joined by U+241F.
        keys = ["\u241f".join(ind.prompts.get(k, "") for k in sorted(ind.prompts.keys())) for ind in individuals]
        seen: dict[str, int] = {}
        for key in keys:
            seen[key] = seen.get(key, 0) + 1
        return [diversity_lambda * float(seen[key] - 1) for key in keys]

    def _tournament_select(pop: Sequence[Individual], k: int) -> Individual:
        """Pick ``k`` random contenders (or all of ``pop`` if smaller) and return the fittest."""
        contenders = random.sample(pop, k=min(k, len(pop)))
        return max(contenders, key=_fitness)

    # ------------- discover space ------------- #
    # param name -> (current prompt text, prompt purpose)
    prompt_space: dict[str, tuple[str, str]] = {
        k: (v.prompt, v.prompt_purpose)
        for k, v in full_space.items() if v.is_prompt
    }

    if not prompt_space:
        logger.info("No prompts to optimize – skipping.")
        return

    metric_cfg = optimizer_config.eval_metrics
    if metric_cfg is None or len(metric_cfg) == 0:
        raise ValueError("optimizer_config.eval_metrics must be provided for GA prompt optimization")

    directions = [v.direction for v in metric_cfg.values()]  # "minimize" or "maximize"
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    # ------------- builder & functions ------------- #
    async with WorkflowBuilder(general_config=base_cfg.general, registry=None) as builder:
        await builder.populate_builder(base_cfg)
        init_fn_name = optimizer_config.prompt.prompt_population_init_function
        if not init_fn_name:
            raise ValueError(
                "No prompt optimization function configured. Set optimizer.prompt_population_init_function")
        init_fn = await builder.get_function(init_fn_name)

        recombine_fn = None
        if optimizer_config.prompt.prompt_recombination_function:
            recombine_fn = await builder.get_function(optimizer_config.prompt.prompt_recombination_function)

        logger.info(
            "GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
            init_fn_name,
            optimizer_config.prompt.prompt_recombination_function,
        )

        # ------------- GA parameters ------------- #
        # The offspring count per generation is pop_size minus the elites carried over,
        # which keeps the population size constant across generations.
        pop_size = max(2, int(optimizer_config.prompt.ga_population_size))
        generations = max(1, int(optimizer_config.prompt.ga_generations))
        crossover_rate = float(optimizer_config.prompt.ga_crossover_rate)
        mutation_rate = float(optimizer_config.prompt.ga_mutation_rate)
        elitism = max(0, int(optimizer_config.prompt.ga_elitism))
        selection_method = optimizer_config.prompt.ga_selection_method.lower()
        tournament_size = max(2, int(optimizer_config.prompt.ga_tournament_size))
        max_eval_concurrency = max(1, int(optimizer_config.prompt.ga_parallel_evaluations))
        diversity_lambda = float(optimizer_config.prompt.ga_diversity_lambda)

        # ------------- population init ------------- #
        async def _mutate_prompt(original_prompt: str, purpose: str) -> str:
            """Ask the configured init function for a rewritten prompt (no oracle feedback)."""
            return await init_fn.acall_invoke(
                PromptOptimizerInputSchema(
                    original_prompt=original_prompt,
                    objective=purpose,
                    oracle_feedback=None,
                ))

        async def _recombine_prompts(a: str, b: str, purpose: str) -> str:
            """Combine two parent prompts, via the recombination function when one is configured."""
            if recombine_fn is None:
                # Fallback: uniform choice per recombination
                return random.choice([a, b])
            payload = {"original_prompt": a, "objective": purpose, "oracle_feedback": None, "parent_b": b}
            return await recombine_fn.acall_invoke(payload)

        def _make_individual_from_prompts(prompts: dict[str, str]) -> Individual:
            """Create an unevaluated individual holding a copy of ``prompts``."""
            return Individual(prompts=dict(prompts))

        async def _initial_population() -> list[Individual]:
            """Seed with the original prompts plus ``pop_size - 1`` mutated variants."""
            originals = {k: prompt_space[k][0] for k in prompt_space}
            individuals: list[Individual] = [_make_individual_from_prompts(originals)]

            init_sem = asyncio.Semaphore(max_eval_concurrency)

            async def _create_random_individual() -> Individual:
                async with init_sem:
                    mutated: dict[str, str] = {}
                    for param, (base_prompt, purpose) in prompt_space.items():
                        try:
                            new_p = await _mutate_prompt(base_prompt, purpose)
                        except Exception as e:
                            logger.warning("Mutation failed for %s: %s; using original.", param, e)
                            new_p = base_prompt
                        mutated[param] = new_p
                    return _make_individual_from_prompts(mutated)

            needed = max(0, pop_size - 1)
            individuals.extend(await asyncio.gather(*[_create_random_individual() for _ in range(needed)]))
            return individuals

        # ------------- evaluation ------------- #
        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))

        sem = asyncio.Semaphore(max_eval_concurrency)

        async def _evaluate(ind: Individual) -> Individual:
            """Run ``reps`` evaluations of ``ind``'s prompts and store the mean score per metric on it."""
            async with sem:
                cfg_trial = apply_suggestions(base_cfg, ind.prompts)
                eval_cfg = EvaluationRunConfig(
                    config_file=cfg_trial,
                    dataset=opt_run_config.dataset,
                    result_json_path=opt_run_config.result_json_path,
                    endpoint=opt_run_config.endpoint,
                    endpoint_timeout=opt_run_config.endpoint_timeout,
                    override=opt_run_config.override,
                )
                # Run reps sequentially under the same semaphore to avoid overload
                all_results: list[list[tuple[str, Any]]] = []
                for _ in range(reps):
                    res = (await EvaluationRun(config=eval_cfg).run_and_evaluate()).evaluation_results
                    all_results.append(res)

            metrics: dict[str, float] = {}
            for metric_name in eval_metrics:
                scores: list[float] = []
                for run_results in all_results:
                    for name, result in run_results:
                        if name == metric_name:
                            scores.append(result.average_score)
                            break
                # A metric absent from every run scores 0.0.
                metrics[metric_name] = float(sum(scores) / len(scores)) if scores else 0.0
            ind.metrics = metrics
            return ind

        async def _evaluate_population(pop: list[Individual]) -> list[Individual]:
            """Evaluate individuals lacking metrics, then recompute everyone's scalar fitness."""
            unevaluated = [ind for ind in pop if not ind.metrics]
            if unevaluated:
                # _evaluate stores metrics on each individual in place.
                await asyncio.gather(*[_evaluate(ind) for ind in unevaluated])
            # Normalisation is relative to the whole generation, so every fitness is recomputed.
            norm_per_ind = _normalize_generation(pop, eval_metrics, directions)
            penalties = _apply_diversity_penalty(pop, diversity_lambda)
            for ind, norm_scores, penalty in zip(pop, norm_per_ind, penalties):
                ind.scalar_fitness = _scalarize(
                    norm_scores, mode=optimizer_config.multi_objective_combination_mode, weights=weights) - penalty
            return pop

        # ------------- reproduction ops ------------- #
        async def _make_child(parent_a: Individual, parent_b: Individual) -> Individual:
            """Build a child per parameter: optional crossover, then optional mutation."""
            child_prompts: dict[str, str] = {}
            for param, (base_prompt, purpose) in prompt_space.items():
                pa = parent_a.prompts.get(param, base_prompt)
                pb = parent_b.prompts.get(param, base_prompt)
                child = pa
                # crossover
                if random.random() < crossover_rate:
                    try:
                        child = await _recombine_prompts(pa, pb, purpose)
                    except Exception as e:
                        logger.warning("Recombination failed for %s: %s; falling back to parent.", param, e)
                        child = random.choice([pa, pb])
                # mutation
                if random.random() < mutation_rate:
                    try:
                        child = await _mutate_prompt(child, purpose)
                    except Exception as e:
                        logger.warning("Mutation failed for %s: %s; keeping child as-is.", param, e)
                child_prompts[param] = child
            return _make_individual_from_prompts(child_prompts)

        def _select_parent(curr_pop: list[Individual]) -> Individual:
            """Select one parent by tournament or by fitness-proportional (roulette) selection."""
            if selection_method == "tournament":
                return _tournament_select(curr_pop, tournament_size)
            # Roulette wheel over fitness clamped to be non-negative.
            wheel = [max(_fitness(ind), 0.0) for ind in curr_pop]
            if sum(wheel) <= 0.0:
                # No positive fitness to weight by: pick uniformly instead of always
                # returning the last individual.
                return random.choice(curr_pop)
            return random.choices(curr_pop, weights=wheel, k=1)[0]

        # ------------- GA loop ------------- #
        population = await _initial_population()
        history_rows: list[dict[str, Any]] = []

        for gen in range(1, generations + 1):
            logger.info("[GA] Generation %d/%d: evaluating population of %d", gen, generations, len(population))
            population = await _evaluate_population(population)

            # Log and save checkpoint
            best = max(population, key=_fitness)
            checkpoint = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
            checkpoint_path = out_dir / f"optimized_prompts_gen{gen}.json"
            with checkpoint_path.open("w", encoding="utf-8") as fh:
                json.dump(checkpoint, fh, indent=2)
            logger.info("[GA] Saved checkpoint: %s (fitness=%.4f)", checkpoint_path, _fitness(best))

            # Append history
            for idx, ind in enumerate(population):
                row = {
                    "generation": gen,
                    "index": idx,
                    "scalar_fitness": ind.scalar_fitness,
                }
                if ind.metrics:
                    row.update({f"metric::{m}": ind.metrics[m] for m in eval_metrics})
                history_rows.append(row)

            # Next generation via elitism + reproduction. Elites are copied without their
            # metrics, so they are re-evaluated next generation.
            next_population: list[Individual] = []
            if elitism > 0:
                elites = sorted(population, key=_fitness, reverse=True)[:elitism]
                next_population.extend(_make_individual_from_prompts(e.prompts) for e in elites)

            # Produce offspring to fill the population back up to pop_size.
            needed = pop_size - len(next_population)
            offspring: list[Individual] = []
            while len(offspring) < needed:
                p1 = _select_parent(population)
                p2 = _select_parent(population)
                if p2 is p1 and len(population) > 1:
                    p2 = random.choice([ind for ind in population if ind is not p1])
                offspring.append(await _make_child(p1, p2))

            population = next_population + offspring

        # Final evaluation to ensure metrics present
        population = await _evaluate_population(population)
        best = max(population, key=_fitness)
        best_prompts = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}

        # Save final
        final_prompts_path = out_dir / "optimized_prompts.json"
        with final_prompts_path.open("w", encoding="utf-8") as fh:
            json.dump(best_prompts, fh, indent=2)

        trials_df_path = out_dir / "ga_history_prompts.csv"
        try:
            # Best-effort history dump using the stdlib csv module.
            import csv  # pylint: disable=import-outside-toplevel

            fieldnames: list[str] = sorted({k for row in history_rows for k in row.keys()})
            with trials_df_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(history_rows)
        except Exception as e:  # pragma: no cover - best effort
            logger.warning("Failed to write GA history CSV: %s", e)

        logger.info("Prompt GA optimization finished successfully!")
        logger.info("Final prompts saved to: %s", final_prompts_path)
        logger.info("History saved to: %s", trials_df_path)
@@ -0,0 +1,66 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ from typing import Any
18
+
19
+ from pydantic import BaseModel
20
+
21
+
22
+ def _deep_merge_dict(target: dict[str, Any], updates: dict[str, Any]) -> None:
23
+ """In-place deep merge of nested dictionaries."""
24
+ for key, value in updates.items():
25
+ if key in target and isinstance(target[key], dict) and isinstance(value, dict):
26
+ _deep_merge_dict(target[key], value)
27
+ else:
28
+ target[key] = value
29
+
30
+
31
def nest_updates(flat: dict[str, Any]) -> dict[str, Any]:
    """
    Convert ``{'a.b.c': 1, 'd.x.y': 2}`` ➜
    ``{'a': {'b': {'c': 1}}, 'd': {'x': {'y': 2}}}``.
    Works even when the middle segment is a dict key.

    Keys are processed in insertion order, and nested paths are deep-merged into what
    earlier keys produced. Dict-valued leaves are deep-copied, so merging a later dotted
    key into them never mutates the caller's ``flat`` values.
    """
    import copy  # pylint: disable=import-outside-toplevel

    root: dict[str, Any] = defaultdict(dict)

    for dotted, value in flat.items():
        head, *rest = dotted.split(".", 1)
        if not rest:  # leaf
            # A later key such as "head.x" deep-merges into root[head] in place. Without
            # the copy, that merge would write into the caller's own dict.
            root[head] = copy.deepcopy(value) if isinstance(value, dict) else value
            continue

        tail = rest[0]
        child_updates = nest_updates({tail: value})
        if isinstance(root[head], dict):
            _deep_merge_dict(root[head], child_updates)
        else:
            # A non-dict value already sits at this level; the nested path replaces it.
            root[head] = child_updates
    return dict(root)
52
+
53
+
54
def apply_suggestions(cfg: BaseModel, flat: dict[str, Any]) -> BaseModel:
    """
    Return a **new** config where only the dotted-path keys in *flat*
    have been modified. Preserves all unrelated siblings.

    *cfg* is dumped to a plain dict, each dotted path is written into it (missing
    intermediate keys are created as empty dicts), and the result is re-validated
    as ``cfg.__class__``.

    Raises:
        TypeError: If an intermediate path segment already holds a non-dict value
            (e.g. ``None``), since the suggestion cannot be nested beneath it.
    """
    cfg_dict = cfg.model_dump(mode="python")
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(".")
        cursor = cfg_dict
        walked: list[str] = []
        for key in parents:
            cursor = cursor.setdefault(key, {})
            walked.append(key)
            if not isinstance(cursor, dict):
                # Fail with the offending path instead of an opaque Attribute/TypeError.
                raise TypeError(f"Cannot apply suggestion '{dotted}': '{'.'.join(walked)}' holds a "
                                f"{type(cursor).__name__}, not a mapping")
        cursor[leaf] = value
    return cfg.__class__.model_validate(cfg_dict)
nat/profiler/utils.py CHANGED
@@ -22,6 +22,7 @@ from typing import Any
22
22
  import pandas as pd
23
23
 
24
24
  from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.cli.type_registry import RegisteredFunctionGroupInfo
25
26
  from nat.cli.type_registry import RegisteredFunctionInfo
26
27
  from nat.data_models.intermediate_step import IntermediateStep
27
28
  from nat.profiler.data_frame_row import DataFrameRow
@@ -32,7 +33,8 @@ _FRAMEWORK_REGEX_MAP = {t: fr'\b{t._name_}\b' for t in LLMFrameworkEnum}
32
33
  logger = logging.getLogger(__name__)
33
34
 
34
35
 
35
- def detect_llm_frameworks_in_build_fn(registration: RegisteredFunctionInfo) -> list[LLMFrameworkEnum]:
36
+ def detect_llm_frameworks_in_build_fn(
37
+ registration: RegisteredFunctionInfo | RegisteredFunctionGroupInfo) -> list[LLMFrameworkEnum]:
36
38
  """
37
39
  Analyze a function's source (the build_fn) to see which LLM frameworks it uses. Also recurses
38
40
  into any additional Python functions that the build_fn calls while passing `builder`, so that
@@ -16,6 +16,8 @@
16
16
  from pydantic import Field
17
17
 
18
18
  from nat.cli.register_workflow import register_registry_handler
19
+ from nat.data_models.common import OptionalSecretStr
20
+ from nat.data_models.common import get_secret_value
19
21
  from nat.data_models.registry_handler import RegistryHandlerBaseConfig
20
22
 
21
23
 
@@ -23,8 +25,8 @@ class PypiRegistryHandlerConfig(RegistryHandlerBaseConfig, name="pypi"):
23
25
  """Registry handler for interacting with a remote PyPI registry index."""
24
26
 
25
27
  endpoint: str = Field(description="A string representing the remote endpoint.")
26
- token: str | None = Field(default=None,
27
- description="The authentication token to use when interacting with the registry.")
28
+ token: OptionalSecretStr = Field(default=None,
29
+ description="The authentication token to use when interacting with the registry.")
28
30
  publish_route: str = Field(description="The route to the NAT publish service.")
29
31
  pull_route: str = Field(description="The route to the NAT pull service.")
30
32
  search_route: str = Field(default="simple", description="The route to the NAT search service.")
@@ -35,6 +37,6 @@ async def pypi_publish_registry_handler(config: PypiRegistryHandlerConfig):
35
37
 
36
38
  from nat.registry_handlers.pypi.pypi_handler import PypiRegistryHandler
37
39
 
38
- registry_handler = PypiRegistryHandler(endpoint=config.endpoint, token=config.token)
40
+ registry_handler = PypiRegistryHandler(endpoint=config.endpoint, token=get_secret_value(config.token))
39
41
 
40
42
  yield registry_handler
@@ -18,6 +18,8 @@ import os
18
18
  from pydantic import Field
19
19
 
20
20
  from nat.cli.register_workflow import register_registry_handler
21
+ from nat.data_models.common import OptionalSecretStr
22
+ from nat.data_models.common import get_secret_value
21
23
  from nat.data_models.registry_handler import RegistryHandlerBaseConfig
22
24
 
23
25
 
@@ -25,8 +27,8 @@ class RestRegistryHandlerConfig(RegistryHandlerBaseConfig, name="rest"):
25
27
  """Registry handler for interacting with a remote REST registry."""
26
28
 
27
29
  endpoint: str = Field(description="A string representing the remote endpoint.")
28
- token: str | None = Field(default=None,
29
- description="The authentication token to use when interacting with the registry.")
30
+ token: OptionalSecretStr = Field(default=None,
31
+ description="The authentication token to use when interacting with the registry.")
30
32
  publish_route: str = Field(default="", description="The route to the NAT publish service.")
31
33
  pull_route: str = Field(default="", description="The route to the NAT pull service.")
32
34
  search_route: str = Field(default="", description="The route to the NAT search service")
@@ -44,7 +46,7 @@ async def rest_search_handler(config: RestRegistryHandlerConfig):
44
46
  if (registry_token is None):
45
47
  raise ValueError("Please supply registry token.")
46
48
  else:
47
- registry_token = config.token
49
+ registry_token = get_secret_value(config.token)
48
50
 
49
51
  registry_handler = RestRegistryHandler(token=registry_token,
50
52
  endpoint=config.endpoint,
@@ -214,7 +214,7 @@ def _wrap_milvus_results(res: list[Hit], content_field: str):
214
214
 
215
215
 
216
216
  def _wrap_milvus_single_results(res: Hit | dict, content_field: str) -> Document:
217
- if not isinstance(res, (Hit, dict)):
217
+ if not isinstance(res, Hit | dict):
218
218
  raise ValueError(f"Milvus search returned object of type {type(res)}. Expected 'Hit' or 'dict'.")
219
219
 
220
220
  if isinstance(res, Hit):
@@ -20,6 +20,7 @@ from nat.builder.builder import Builder
20
20
  from nat.builder.retriever import RetrieverProviderInfo
21
21
  from nat.cli.register_workflow import register_retriever_client
22
22
  from nat.cli.register_workflow import register_retriever_provider
23
+ from nat.data_models.common import OptionalSecretStr
23
24
  from nat.data_models.retriever import RetrieverBaseConfig
24
25
 
25
26
 
@@ -34,7 +35,7 @@ class NemoRetrieverConfig(RetrieverBaseConfig, name="nemo_retriever"):
34
35
  default=None,
35
36
  description="A list of fields to return from the datastore. If 'None', all fields but the vector are returned.")
36
37
  timeout: int = Field(default=60, description="Maximum time to wait for results to be returned from the service.")
37
- nvidia_api_key: str | None = Field(
38
+ nvidia_api_key: OptionalSecretStr = Field(
38
39
  description="API key used to authenticate with the service. If 'None', will use ENV Variable 'NVIDIA_API_KEY'",
39
40
  default=None,
40
41
  )
nat/runtime/loader.py CHANGED
@@ -114,7 +114,7 @@ async def load_workflow(config_file: StrPath, max_concurrency: int = -1):
114
114
  # Must yield the workflow function otherwise it cleans up
115
115
  async with WorkflowBuilder.from_config(config=config) as workflow:
116
116
 
117
- yield SessionManager(workflow.build(), max_concurrency=max_concurrency)
117
+ yield SessionManager(await workflow.build(), max_concurrency=max_concurrency)
118
118
 
119
119
 
120
120
  @lru_cache