nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/profiler/parameter_optimization/prompt_optimizer.py ADDED
@@ -0,0 +1,384 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import json
+ import logging
+ import random
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from typing import Any
+
+ from pydantic import BaseModel
+
+ from nat.builder.workflow_builder import WorkflowBuilder
+ from nat.data_models.config import Config
+ from nat.data_models.optimizable import SearchSpace
+ from nat.data_models.optimizer import OptimizerConfig
+ from nat.data_models.optimizer import OptimizerRunConfig
+ from nat.eval.evaluate import EvaluationRun
+ from nat.eval.evaluate import EvaluationRunConfig
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
+ from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
+
+ logger = logging.getLogger(__name__)
+
+
+ class PromptOptimizerInputSchema(BaseModel):
+     original_prompt: str
+     objective: str
+     oracle_feedback: str | None = None
+
+
+ @experimental(feature_name="Optimizer")
+ async def optimize_prompts(
+     *,
+     base_cfg: Config,
+     full_space: dict[str, SearchSpace],
+     optimizer_config: OptimizerConfig,
+     opt_run_config: OptimizerRunConfig,
+ ) -> None:
+
+     # ------------- helpers ------------- #
+     @dataclass
+     class Individual:
+         prompts: dict[str, str]  # param_name -> prompt text
+         metrics: dict[str, float] | None = None  # evaluator_name -> average score
+         scalar_fitness: float | None = None
+
+     def _normalize_generation(
+         individuals: Sequence[Individual],
+         metric_names: Sequence[str],
+         directions: Sequence[str],
+         eps: float = 1e-12,
+     ) -> list[dict[str, float]]:
+         """Return per-individual dict of normalised scores in [0,1] where higher is better."""
+         # Extract arrays per metric
+         arrays = {m: [ind.metrics.get(m, 0.0) if ind.metrics else 0.0 for ind in individuals] for m in metric_names}
+         normed: list[dict[str, float]] = []
+         for i in range(len(individuals)):
+             entry: dict[str, float] = {}
+             for m, dirn in zip(metric_names, directions):
+                 vals = arrays[m]
+                 vmin = min(vals)
+                 vmax = max(vals)
+                 v = vals[i]
+                 # Map to [0,1] with higher=better regardless of direction
+                 if vmax - vmin < eps:
+                     score01 = 0.5
+                 else:
+                     score01 = (v - vmin) / (vmax - vmin)
+                 if dirn == "minimize":
+                     score01 = 1.0 - score01
+                 entry[m] = float(score01)
+             normed.append(entry)
+         return normed
+
+     def _scalarize(norm_scores: dict[str, float], *, mode: str, weights: Sequence[float] | None) -> float:
+         """Collapse normalised scores to a single scalar (higher is better)."""
+         vals = list(norm_scores.values())
+         if not vals:
+             return 0.0
+         if mode == "harmonic":
+             inv_sum = sum(1.0 / max(v, 1e-12) for v in vals)
+             return len(vals) / max(inv_sum, 1e-12)
+         if mode == "sum":
+             if weights is None:
+                 return float(sum(vals))
+             if len(weights) != len(vals):
+                 raise ValueError("weights length must equal number of objectives")
+             return float(sum(w * v for w, v in zip(weights, vals)))
+         if mode == "chebyshev":
+             return float(min(vals))  # maximise the worst-case score
+         raise ValueError(f"Unknown combination mode: {mode}")
+
+     def _apply_diversity_penalty(individuals: Sequence[Individual], diversity_lambda: float) -> list[float]:
+         if diversity_lambda <= 0.0:
+             return [0.0 for _ in individuals]
+         seen: dict[str, int] = {}
+         keys: list[str] = []
+         penalties: list[float] = []
+         for ind in individuals:
+             key = "\u241f".join(ind.prompts.get(k, "") for k in sorted(ind.prompts.keys()))
+             keys.append(key)
+             seen[key] = seen.get(key, 0) + 1
+         for key in keys:
+             duplicates = seen[key] - 1
+             penalties.append(diversity_lambda * float(duplicates))
+         return penalties
+
+     def _tournament_select(pop: Sequence[Individual], k: int) -> Individual:
+         contenders = random.sample(pop, k=min(k, len(pop)))
+         return max(contenders, key=lambda i: (i.scalar_fitness or 0.0))
+
+     # ------------- discover space ------------- #
+     prompt_space: dict[str, tuple[str, str]] = {
+         k: (v.prompt, v.prompt_purpose)
+         for k, v in full_space.items() if v.is_prompt
+     }
+
+     if not prompt_space:
+         logger.info("No prompts to optimize – skipping.")
+         return
+
+     metric_cfg = optimizer_config.eval_metrics
+     if metric_cfg is None or len(metric_cfg) == 0:
+         raise ValueError("optimizer_config.eval_metrics must be provided for GA prompt optimization")
+
+     directions = [v.direction for v in metric_cfg.values()]  # "minimize" or "maximize"
+     eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
+     weights = [v.weight for v in metric_cfg.values()]
+
+     out_dir = optimizer_config.output_path
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     # ------------- builder & functions ------------- #
+     async with WorkflowBuilder(general_config=base_cfg.general, registry=None) as builder:
+         await builder.populate_builder(base_cfg)
+         init_fn_name = (optimizer_config.prompt.prompt_population_init_function)
+         if not init_fn_name:
+             raise ValueError(
+                 "No prompt optimization function configured. Set optimizer.prompt_population_init_function")
+         init_fn = await builder.get_function(init_fn_name)
+
+         recombine_fn = None
+         if optimizer_config.prompt.prompt_recombination_function:
+             recombine_fn = await builder.get_function(optimizer_config.prompt.prompt_recombination_function)
+
+         logger.info(
+             "GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
+             init_fn_name,
+             optimizer_config.prompt.prompt_recombination_function,
+         )
+
+         # ------------- GA parameters ------------- #
+         pop_size = max(2, int(optimizer_config.prompt.ga_population_size))
+         generations = max(1, int(optimizer_config.prompt.ga_generations))
+         offspring_size = (optimizer_config.prompt.ga_offspring_size
+                           or max(0, pop_size - optimizer_config.prompt.ga_elitism))
+         crossover_rate = float(optimizer_config.prompt.ga_crossover_rate)
+         mutation_rate = float(optimizer_config.prompt.ga_mutation_rate)
+         elitism = max(0, int(optimizer_config.prompt.ga_elitism))
+         selection_method = optimizer_config.prompt.ga_selection_method.lower()
+         tournament_size = max(2, int(optimizer_config.prompt.ga_tournament_size))
+         max_eval_concurrency = max(1, int(optimizer_config.prompt.ga_parallel_evaluations))
+         diversity_lambda = float(optimizer_config.prompt.ga_diversity_lambda)
+
+         # ------------- population init ------------- #
+         async def _mutate_prompt(original_prompt: str, purpose: str) -> str:
+             # Use LLM-based optimizer with no feedback
+             return await init_fn.acall_invoke(
+                 PromptOptimizerInputSchema(
+                     original_prompt=original_prompt,
+                     objective=purpose,
+                     oracle_feedback=None,
+                 ))
+
+         async def _recombine_prompts(a: str, b: str, purpose: str) -> str:
+             if recombine_fn is None:
+                 # Fallback: uniform choice per recombination
+                 return random.choice([a, b])
+             payload = {"original_prompt": a, "objective": purpose, "oracle_feedback": None, "parent_b": b}
+             return await recombine_fn.acall_invoke(payload)
+
+         def _make_individual_from_prompts(prompts: dict[str, str]) -> Individual:
+             return Individual(prompts=dict(prompts))
+
+         async def _initial_population() -> list[Individual]:
+             individuals: list[Individual] = []
+             # Ensure first individual is the original prompts
+             originals = {k: prompt_space[k][0] for k in prompt_space}
+             individuals.append(_make_individual_from_prompts(originals))
+
+             init_sem = asyncio.Semaphore(max_eval_concurrency)
+
+             async def _create_random_individual() -> Individual:
+                 async with init_sem:
+                     mutated: dict[str, str] = {}
+                     for param, (base_prompt, purpose) in prompt_space.items():
+                         try:
+                             new_p = await _mutate_prompt(base_prompt, purpose)
+                         except Exception as e:
+                             logger.warning("Mutation failed for %s: %s; using original.", param, e)
+                             new_p = base_prompt
+                         mutated[param] = new_p
+                     return _make_individual_from_prompts(mutated)
+
+             needed = max(0, pop_size - 1)
+             tasks = [_create_random_individual() for _ in range(needed)]
+             individuals.extend(await asyncio.gather(*tasks))
+             return individuals
+
+         # ------------- evaluation ------------- #
+         reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))
+
+         sem = asyncio.Semaphore(max_eval_concurrency)
+
+         async def _evaluate(ind: Individual) -> Individual:
+             async with sem:
+                 cfg_trial = apply_suggestions(base_cfg, ind.prompts)
+                 eval_cfg = EvaluationRunConfig(
+                     config_file=cfg_trial,
+                     dataset=opt_run_config.dataset,
+                     result_json_path=opt_run_config.result_json_path,
+                     endpoint=opt_run_config.endpoint,
+                     endpoint_timeout=opt_run_config.endpoint_timeout,
+                     override=opt_run_config.override,
+                 )
+                 # Run reps sequentially under the same semaphore to avoid overload
+                 all_results: list[list[tuple[str, Any]]] = []
+                 for _ in range(reps):
+                     res = (await EvaluationRun(config=eval_cfg).run_and_evaluate()).evaluation_results
+                     all_results.append(res)
+
+                 metrics: dict[str, float] = {}
+                 for metric_name in eval_metrics:
+                     scores: list[float] = []
+                     for run_results in all_results:
+                         for name, result in run_results:
+                             if name == metric_name:
+                                 scores.append(result.average_score)
+                                 break
+                     metrics[metric_name] = float(sum(scores) / len(scores)) if scores else 0.0
+                 ind.metrics = metrics
+                 return ind
+
+         async def _evaluate_population(pop: list[Individual]) -> list[Individual]:
+             # Evaluate those missing metrics
+             unevaluated = [ind for ind in pop if not ind.metrics]
+             if unevaluated:
+                 evaluated = await asyncio.gather(*[_evaluate(ind) for ind in unevaluated])
+                 # in-place update
+                 for ind, ev in zip(unevaluated, evaluated):
+                     ind.metrics = ev.metrics
+             # Scalarize
+             norm_per_ind = _normalize_generation(pop, eval_metrics, directions)
+             penalties = _apply_diversity_penalty(pop, diversity_lambda)
+             for ind, norm_scores, penalty in zip(pop, norm_per_ind, penalties):
+                 ind.scalar_fitness = _scalarize(
+                     norm_scores, mode=optimizer_config.multi_objective_combination_mode, weights=weights) - penalty
+             return pop
+
+         # ------------- reproduction ops ------------- #
+         async def _make_child(parent_a: Individual, parent_b: Individual) -> Individual:
+             child_prompts: dict[str, str] = {}
+             for param, (base_prompt, purpose) in prompt_space.items():
+                 pa = parent_a.prompts.get(param, base_prompt)
+                 pb = parent_b.prompts.get(param, base_prompt)
+                 child = pa
+                 # crossover
+                 if random.random() < crossover_rate:
+                     try:
+                         child = await _recombine_prompts(pa, pb, purpose)
+                     except Exception as e:
+                         logger.warning("Recombination failed for %s: %s; falling back to parent.", param, e)
+                         child = random.choice([pa, pb])
+                 # mutation
+                 if random.random() < mutation_rate:
+                     try:
+                         child = await _mutate_prompt(child, purpose)
+                     except Exception as e:
+                         logger.warning("Mutation failed for %s: %s; keeping child as-is.", param, e)
+                 child_prompts[param] = child
+             return _make_individual_from_prompts(child_prompts)
+
+         # ------------- GA loop ------------- #
+         population = await _initial_population()
+         history_rows: list[dict[str, Any]] = []
+
+         for gen in range(1, generations + 1):
+             logger.info("[GA] Generation %d/%d: evaluating population of %d", gen, generations, len(population))
+             population = await _evaluate_population(population)
+
+             # Log and save checkpoint
+             best = max(population, key=lambda i: (i.scalar_fitness or 0.0))
+             checkpoint = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
+             checkpoint_path = out_dir / f"optimized_prompts_gen{gen}.json"
+             with checkpoint_path.open("w") as fh:
+                 json.dump(checkpoint, fh, indent=2)
+             logger.info("[GA] Saved checkpoint: %s (fitness=%.4f)", checkpoint_path, best.scalar_fitness or 0.0)
+
+             # Append history
+             for idx, ind in enumerate(population):
+                 row = {
+                     "generation": gen,
+                     "index": idx,
+                     "scalar_fitness": ind.scalar_fitness,
+                 }
+                 if ind.metrics:
+                     row.update({f"metric::{m}": ind.metrics[m] for m in eval_metrics})
+                 history_rows.append(row)
+
+             # Next generation via elitism + reproduction
+             next_population: list[Individual] = []
+             if elitism > 0:
+                 elites = sorted(population, key=lambda i: (i.scalar_fitness or 0.0), reverse=True)[:elitism]
+                 next_population.extend([_make_individual_from_prompts(e.prompts) for e in elites])
+
+             def _select_parent(curr_pop: list[Individual]) -> Individual:
+                 if selection_method == "tournament":
+                     return _tournament_select(curr_pop, tournament_size)
+                 # roulette wheel
+                 total = sum(max(ind.scalar_fitness or 0.0, 0.0) for ind in curr_pop) or 1.0
+                 r = random.random() * total
+                 acc = 0.0
+                 for ind in curr_pop:
+                     acc += max(ind.scalar_fitness or 0.0, 0.0)
+                     if acc >= r:
+                         return ind
+                 return curr_pop[-1]
+
+             # Produce offspring
+             needed = pop_size - len(next_population)
+             offspring: list[Individual] = []
+             for _ in range(max(0, offspring_size), needed):
+                 pass  # no-op: offspring are produced by the while loop below
+             while len(offspring) < needed:
+                 p1 = _select_parent(population)
+                 p2 = _select_parent(population)
+                 if p2 is p1 and len(population) > 1:
+                     p2 = random.choice([ind for ind in population if ind is not p1])
+                 child = await _make_child(p1, p2)
+                 offspring.append(child)
+
+             population = next_population + offspring
+
+         # Final evaluation to ensure metrics present
+         population = await _evaluate_population(population)
+         best = max(population, key=lambda i: (i.scalar_fitness or 0.0))
+         best_prompts = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
+
+         # Save final
+         final_prompts_path = out_dir / "optimized_prompts.json"
+         with final_prompts_path.open("w") as fh:
+             json.dump(best_prompts, fh, indent=2)
+
+         trials_df_path = out_dir / "ga_history_prompts.csv"
+         try:
+             # Write the CSV manually with the standard library (no pandas dependency)
+             import csv  # pylint: disable=import-outside-toplevel
+
+             fieldnames: list[str] = sorted({k for row in history_rows for k in row.keys()})
+             with trials_df_path.open("w", newline="") as f:
+                 writer = csv.DictWriter(f, fieldnames=fieldnames)
+                 writer.writeheader()
+                 for row in history_rows:
+                     writer.writerow(row)
+         except Exception as e:  # pragma: no cover - best effort
+             logger.warning("Failed to write GA history CSV: %s", e)
+
+         logger.info("Prompt GA optimization finished successfully!")
+         logger.info("Final prompts saved to: %s", final_prompts_path)
+         logger.info("History saved to: %s", trials_df_path)
nat/profiler/parameter_optimization/update_helpers.py ADDED
@@ -0,0 +1,66 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+ from typing import Any
+
+ from pydantic import BaseModel
+
+
+ def _deep_merge_dict(target: dict[str, Any], updates: dict[str, Any]) -> None:
+     """In-place deep merge of nested dictionaries."""
+     for key, value in updates.items():
+         if key in target and isinstance(target[key], dict) and isinstance(value, dict):
+             _deep_merge_dict(target[key], value)
+         else:
+             target[key] = value
+
+
+ def nest_updates(flat: dict[str, Any]) -> dict[str, Any]:
+     """
+     Convert ``{'a.b.c': 1, 'd.x.y': 2}`` ➜
+     ``{'a': {'b': {'c': 1}}, 'd': {'x': {'y': 2}}}``.
+     Works even when the middle segment is a dict key.
+     """
+     root: dict[str, Any] = defaultdict(dict)
+
+     for dotted, value in flat.items():
+         head, *rest = dotted.split(".", 1)
+         if not rest:  # leaf
+             root[head] = value
+             continue
+
+         tail = rest[0]
+         child_updates = nest_updates({tail: value})
+         if isinstance(root[head], dict):
+             _deep_merge_dict(root[head], child_updates)
+         else:
+             root[head] = child_updates
+     return dict(root)
+
+
+ def apply_suggestions(cfg: BaseModel, flat: dict[str, Any]) -> BaseModel:
+     """
+     Return a **new** config where only the dotted-path keys in *flat*
+     have been modified. Preserves all unrelated siblings.
+     """
+     cfg_dict = cfg.model_dump(mode="python")
+     for dotted, value in flat.items():
+         keys = dotted.split(".")
+         cursor = cfg_dict
+         for key in keys[:-1]:
+             cursor = cursor.setdefault(key, {})
+         cursor[keys[-1]] = value
+     return cfg.__class__.model_validate(cfg_dict)
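A quick usage sketch for the two helpers above (`DemoConfig` and `LLMSettings` are hypothetical stand-ins, not models from the package):

from pydantic import BaseModel

class LLMSettings(BaseModel):
    model: str = "llama-3.1-8b-instruct"
    temperature: float = 0.7

class DemoConfig(BaseModel):
    llm: LLMSettings = LLMSettings()

cfg = DemoConfig()
# apply_suggestions round-trips through model_dump()/model_validate(),
# so the original config object is left untouched.
new_cfg = apply_suggestions(cfg, {"llm.temperature": 0.2})
assert cfg.llm.temperature == 0.7 and new_cfg.llm.temperature == 0.2

# nest_updates only reshapes a flat dotted-path dict into nested dicts.
assert nest_updates({"a.b.c": 1, "a.b.d": 2}) == {"a": {"b": {"c": 1, "d": 2}}}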
nat/profiler/profile_runner.py CHANGED
@@ -88,14 +88,19 @@ class ProfilerRunner:
          writes out combined requests JSON, then computes and saves additional metrics,
          and optionally fits a forecasting model.
          """
-         from nat.profiler.inference_optimization.bottleneck_analysis.nested_stack_analysis import \
-             multi_example_call_profiling
-         from nat.profiler.inference_optimization.bottleneck_analysis.simple_stack_analysis import \
-             profile_workflow_bottlenecks
-         from nat.profiler.inference_optimization.experimental.concurrency_spike_analysis import \
-             concurrency_spike_analysis
-         from nat.profiler.inference_optimization.experimental.prefix_span_analysis import \
-             prefixspan_subworkflow_with_text
+         # Yapf and ruff disagree on how to format long imports; disable yapf and go with ruff
+         from nat.profiler.inference_optimization.bottleneck_analysis.nested_stack_analysis import (
+             multi_example_call_profiling,
+         )  # yapf: disable
+         from nat.profiler.inference_optimization.bottleneck_analysis.simple_stack_analysis import (
+             profile_workflow_bottlenecks,
+         )  # yapf: disable
+         from nat.profiler.inference_optimization.experimental.concurrency_spike_analysis import (
+             concurrency_spike_analysis,
+         )  # yapf: disable
+         from nat.profiler.inference_optimization.experimental.prefix_span_analysis import (
+             prefixspan_subworkflow_with_text,
+         )  # yapf: disable
          from nat.profiler.inference_optimization.llm_metrics import LLMMetrics
          from nat.profiler.inference_optimization.prompt_caching import get_common_prefixes
          from nat.profiler.inference_optimization.token_uniqueness import compute_inter_query_token_uniqueness_by_llm
@@ -277,7 +282,7 @@ class ProfilerRunner:
              fitted_model = model_trainer.train(all_steps)
              logger.info("Fitted model for forecasting.")
          except Exception as e:
-             logger.exception("Fitting model failed. %s", e, exc_info=True)
+             logger.exception("Fitting model failed. %s", e)
              return ProfilerResults()

          if self.write_output:
nat/profiler/utils.py CHANGED
@@ -22,6 +22,7 @@ from typing import Any
  import pandas as pd

  from nat.builder.framework_enum import LLMFrameworkEnum
+ from nat.cli.type_registry import RegisteredFunctionGroupInfo
  from nat.cli.type_registry import RegisteredFunctionInfo
  from nat.data_models.intermediate_step import IntermediateStep
  from nat.profiler.data_frame_row import DataFrameRow
@@ -32,7 +33,8 @@ _FRAMEWORK_REGEX_MAP = {t: fr'\b{t._name_}\b' for t in LLMFrameworkEnum}
  logger = logging.getLogger(__name__)


- def detect_llm_frameworks_in_build_fn(registration: RegisteredFunctionInfo) -> list[LLMFrameworkEnum]:
+ def detect_llm_frameworks_in_build_fn(
+         registration: RegisteredFunctionInfo | RegisteredFunctionGroupInfo) -> list[LLMFrameworkEnum]:
      """
      Analyze a function's source (the build_fn) to see which LLM frameworks it uses. Also recurses
      into any additional Python functions that the build_fn calls while passing `builder`, so that
@@ -175,7 +177,7 @@ def create_standardized_dataframe(requests_data: list[list[IntermediateStep]]) -
                  event_type=step.event_type).model_dump(), )

          except Exception as e:
-             logger.exception("Error creating standardized DataFrame: %s", e, exc_info=True)
+             logger.exception("Error creating standardized DataFrame: %s", e)
              return pd.DataFrame()

      if not all_rows:
nat/registry_handlers/local/local_handler.py CHANGED
@@ -133,7 +133,7 @@ class LocalRegistryHandler(AbstractRegistryHandler):
              "message": msg,
              "action": ActionEnum.SEARCH
          })
-         logger.exception(validated_search_response.status.message, exc_info=True)
+         logger.exception(validated_search_response.status.message)

          yield validated_search_response

@@ -168,7 +168,7 @@ class LocalRegistryHandler(AbstractRegistryHandler):
          validated_remove_response = RemoveResponse(status={
              "status": StatusEnum.ERROR, "message": msg, "action": ActionEnum.REMOVE
          })  # type: ignore
-         logger.exception(validated_remove_response.status.message, exc_info=True)
+         logger.exception(validated_remove_response.status.message)

          yield validated_remove_response

nat/registry_handlers/package_utils.py CHANGED
@@ -29,7 +29,6 @@ from nat.registry_handlers.schemas.publish import Artifact
  from nat.runtime.loader import PluginTypes
  from nat.runtime.loader import discover_entrypoints

- # pylint: disable=redefined-outer-name
  logger = logging.getLogger(__name__)


@@ -397,7 +396,7 @@ def get_transitive_dependencies(distribution_names: list[str]) -> dict[str, set[
          except importlib.metadata.PackageNotFoundError:
              pass

-         logger.error("Distribution %s not found (tried common variations)", dist_name)
+         logger.error("Distribution %s not found (tried common variations)", dist_name, exc_info=True)
          result[dist_name] = set()

  return result
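The `exc_info=True` added here points the opposite way from the `logger.exception` cleanups in the hunks above, and both changes match stdlib behaviour: `logging.exception` attaches the active traceback automatically (passing `exc_info=True` to it is redundant), while `logging.error` only attaches one when asked, and only while an exception is being handled. A minimal illustration:

import logging

log = logging.getLogger("demo")

try:
    1 / 0
except ZeroDivisionError:
    log.exception("failed")             # traceback included automatically
    log.error("failed", exc_info=True)  # equivalent output
    log.error("failed")                 # message only, no traceback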
nat/registry_handlers/pypi/pypi_handler.py CHANGED
@@ -44,13 +44,12 @@ class PypiRegistryHandler(AbstractRegistryHandler):
      https://github.com/pypiserver/pypiserver
      """

-     def __init__(  # pylint: disable=R0917
-             self,
-             endpoint: str,
-             token: str | None = None,
-             publish_route: str = "",
-             pull_route: str = "",
-             search_route: str = ""):
+     def __init__(self,
+                  endpoint: str,
+                  token: str | None = None,
+                  publish_route: str = "",
+                  pull_route: str = "",
+                  search_route: str = ""):
          super().__init__()
          self._endpoint = endpoint.rstrip("/")
          self._token = token
@@ -86,7 +85,7 @@ class PypiRegistryHandler(AbstractRegistryHandler):
          validated_publish_response = PublishResponse(status={
              "status": StatusEnum.ERROR, "message": msg, "action": ActionEnum.PUBLISH
          })
-         logger.exception(validated_publish_response.status.message, exc_info=True)
+         logger.exception(validated_publish_response.status.message)

          yield validated_publish_response

@@ -126,17 +125,16 @@ class PypiRegistryHandler(AbstractRegistryHandler):

          versioned_packages_str = " ".join(versioned_packages)

-         result = subprocess.run(
-             [
-                 "uv",
-                 "pip",
-                 "install",
-                 "--prerelease=allow",
-                 "--index-url",
-                 f"{self._endpoint}/{self._pull_route}/",
-                 versioned_packages_str
-             ],  # pylint: disable=W0631
-             check=True)
+         result = subprocess.run([
+             "uv",
+             "pip",
+             "install",
+             "--prerelease=allow",
+             "--index-url",
+             f"{self._endpoint}/{self._pull_route}/",
+             versioned_packages_str
+         ],
+                                 check=True)

          result.check_returncode()

@@ -151,7 +149,7 @@ class PypiRegistryHandler(AbstractRegistryHandler):
          validated_pull_response = PullResponse(status={
              "status": StatusEnum.ERROR, "message": msg, "action": ActionEnum.PULL
          })
-         logger.exception(validated_pull_response.status.message, exc_info=True)
+         logger.exception(validated_pull_response.status.message)

          yield validated_pull_response

@@ -171,11 +169,10 @@ class PypiRegistryHandler(AbstractRegistryHandler):
          """

          try:
-             completed_process = subprocess.run(
-                 ["pip", "search", "--index", f"{self._endpoint}", query.query],  # pylint: disable=W0631
-                 text=True,
-                 capture_output=True,
-                 check=True)
+             completed_process = subprocess.run(["pip", "search", "--index", f"{self._endpoint}", query.query],
+                                                text=True,
+                                                capture_output=True,
+                                                check=True)
              search_response_list = []
              search_results = completed_process.stdout
              package_results = search_results.split("\n")
@@ -215,7 +212,7 @@ class PypiRegistryHandler(AbstractRegistryHandler):

          except Exception as e:
              msg = f"Error searching for artifacts: {e}"
-             logger.exception(msg, exc_info=True)
+             logger.exception(msg)
              validated_search_response = SearchResponse(params=query,
                                                         status={
                                                             "status": StatusEnum.ERROR,
nat/registry_handlers/register.py CHANGED
@@ -13,9 +13,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- # pylint: disable=unused-import
  # flake8: noqa

- from .local import register_local  # pylint: disable=E0611
- from .pypi import register_pypi  # pylint: disable=E0611
- from .rest import register_rest  # pylint: disable=E0611
+ from .local import register_local
+ from .pypi import register_pypi
+ from .rest import register_rest