nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import random
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from pydantic import BaseModel
|
|
25
|
+
|
|
26
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
27
|
+
from nat.data_models.config import Config
|
|
28
|
+
from nat.data_models.optimizable import SearchSpace
|
|
29
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
30
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
31
|
+
from nat.eval.evaluate import EvaluationRun
|
|
32
|
+
from nat.eval.evaluate import EvaluationRunConfig
|
|
33
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
34
|
+
from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PromptOptimizerInputSchema(BaseModel):
    """Input payload passed to the prompt population-init function.

    ``optimize_prompts`` builds one of these for every LLM-driven prompt
    mutation and hands it to the configured init function's ``acall_invoke``.
    """

    original_prompt: str  # prompt text that should be rewritten/mutated
    objective: str  # stated purpose of the prompt (the search space's prompt_purpose in optimize_prompts)
    oracle_feedback: str | None = None  # optional feedback text; optimize_prompts always passes None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@experimental(feature_name="Optimizer")
async def optimize_prompts(
    *,
    base_cfg: Config,
    full_space: dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> None:
    """Optimize the prompt parameters of a workflow with a genetic algorithm (GA).

    Only entries of *full_space* whose ``is_prompt`` flag is set take part. The
    initial population is the original prompts plus LLM-mutated variants produced
    by ``optimizer_config.prompt.prompt_population_init_function``. Each generation
    is evaluated through ``EvaluationRun``, each evaluator score is normalised to
    [0, 1] (higher is better), the scores are collapsed into one scalar fitness, and
    the next generation is produced via elitism, parent selection, crossover
    (optionally through ``prompt_recombination_function``) and mutation.

    Files written to ``optimizer_config.output_path``:
      * ``optimized_prompts_gen{N}.json`` - best prompts after generation N,
      * ``optimized_prompts.json`` - best prompts after the final evaluation,
      * ``ga_history_prompts.csv`` - per-individual fitness/metric history.

    Args:
        base_cfg: Workflow configuration the candidate prompts are applied to.
        full_space: Search space keyed by dotted config path; only prompt entries are used.
        optimizer_config: Optimizer settings (eval metrics, GA parameters, output path).
        opt_run_config: Dataset / endpoint / override settings used for each evaluation run.

    Raises:
        ValueError: If no eval metrics or no population-init function is configured,
            or if ``multi_objective_combination_mode`` is not recognised.
    """

    # ------------- helpers ------------- #
    @dataclass
    class Individual:
        prompts: dict[str, str]  # param_name -> prompt text
        metrics: dict[str, float] | None = None  # evaluator_name -> average score
        scalar_fitness: float | None = None

    def _fitness(ind: Individual) -> float:
        """Scalar fitness of *ind*, treating a not-yet-scored individual as 0.0."""
        return ind.scalar_fitness or 0.0

    def _normalize_generation(
        individuals: Sequence[Individual],
        metric_names: Sequence[str],
        directions: Sequence[str],
        eps: float = 1e-12,
    ) -> list[dict[str, float]]:
        """Return per-individual dict of normalised scores in [0,1] where higher is better."""
        if not individuals:
            return []
        # Raw value column per metric; a missing metric counts as 0.0.
        columns = {m: [ind.metrics.get(m, 0.0) if ind.metrics else 0.0 for ind in individuals] for m in metric_names}
        # Per-metric bounds computed once instead of once per individual.
        bounds = {m: (min(vals), max(vals)) for m, vals in columns.items()}
        normed: list[dict[str, float]] = []
        for i in range(len(individuals)):
            entry: dict[str, float] = {}
            for m, dirn in zip(metric_names, directions):
                vmin, vmax = bounds[m]
                span = vmax - vmin
                # Map to [0,1] with higher=better regardless of direction;
                # a degenerate (constant) column maps everyone to the midpoint.
                if span < eps:
                    score01 = 0.5
                else:
                    score01 = (columns[m][i] - vmin) / span
                    if dirn == "minimize":
                        score01 = 1.0 - score01
                entry[m] = float(score01)
            normed.append(entry)
        return normed

    def _scalarize(norm_scores: dict[str, float], *, mode: str, weights: Sequence[float] | None) -> float:
        """Collapse normalised scores to a single scalar (higher is better)."""
        vals = list(norm_scores.values())
        if not vals:
            return 0.0
        if mode == "harmonic":
            inv_sum = sum(1.0 / max(v, 1e-12) for v in vals)
            return len(vals) / max(inv_sum, 1e-12)
        if mode == "sum":
            if weights is None:
                return float(sum(vals))
            if len(weights) != len(vals):
                raise ValueError("weights length must equal number of objectives")
            return float(sum(w * v for w, v in zip(weights, vals)))
        if mode == "chebyshev":
            return float(min(vals))  # maximise the worst-case score
        raise ValueError(f"Unknown combination mode: {mode}")

    def _apply_diversity_penalty(individuals: Sequence[Individual], diversity_lambda: float) -> list[float]:
        """Penalty per individual: ``diversity_lambda`` times its number of exact prompt-set duplicates."""
        if diversity_lambda <= 0.0:
            return [0.0 for _ in individuals]
        seen: dict[str, int] = {}
        keys: list[str] = []
        for ind in individuals:
            # Join prompts in sorted-key order with an unlikely separator (U+241F) as an identity key.
            key = "\u241f".join(ind.prompts.get(k, "") for k in sorted(ind.prompts.keys()))
            keys.append(key)
            seen[key] = seen.get(key, 0) + 1
        return [diversity_lambda * float(seen[key] - 1) for key in keys]

    def _tournament_select(pop: Sequence[Individual], k: int) -> Individual:
        """Pick the fittest of up to *k* randomly sampled individuals."""
        contenders = random.sample(pop, k=min(k, len(pop)))
        return max(contenders, key=_fitness)

    def _roulette_select(pop: Sequence[Individual]) -> Individual:
        """Fitness-proportionate selection over non-negative fitness.

        When no individual has positive fitness the choice is uniform; otherwise
        every draw would land on the last individual.
        """
        shares = [max(_fitness(ind), 0.0) for ind in pop]
        total = sum(shares)
        if total <= 0.0:
            return random.choice(pop)
        r = random.random() * total
        acc = 0.0
        for ind, share in zip(pop, shares):
            acc += share
            if r < acc:
                return ind
        return pop[-1]  # guard against floating-point rounding at the upper edge

    # ------------- discover space ------------- #
    prompt_space: dict[str, tuple[str, str]] = {
        k: (v.prompt, v.prompt_purpose)
        for k, v in full_space.items() if v.is_prompt
    }

    if not prompt_space:
        logger.info("No prompts to optimize – skipping.")
        return

    metric_cfg = optimizer_config.eval_metrics
    if metric_cfg is None or len(metric_cfg) == 0:
        raise ValueError("optimizer_config.eval_metrics must be provided for GA prompt optimization")

    directions = [v.direction for v in metric_cfg.values()]  # "minimize" or "maximize"
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    # ------------- builder & functions ------------- #
    async with WorkflowBuilder(general_config=base_cfg.general, registry=None) as builder:
        await builder.populate_builder(base_cfg)
        init_fn_name = optimizer_config.prompt.prompt_population_init_function
        if not init_fn_name:
            raise ValueError(
                "No prompt optimization function configured. Set optimizer.prompt.prompt_population_init_function")
        init_fn = await builder.get_function(init_fn_name)

        recombine_fn = None
        if optimizer_config.prompt.prompt_recombination_function:
            recombine_fn = await builder.get_function(optimizer_config.prompt.prompt_recombination_function)

        logger.info(
            "GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
            init_fn_name,
            optimizer_config.prompt.prompt_recombination_function,
        )

        # ------------- GA parameters ------------- #
        # NOTE(review): prompt.ga_offspring_size is not consulted; each new
        # generation is always refilled to pop_size (elites + offspring).
        pop_size = max(2, int(optimizer_config.prompt.ga_population_size))
        generations = max(1, int(optimizer_config.prompt.ga_generations))
        crossover_rate = float(optimizer_config.prompt.ga_crossover_rate)
        mutation_rate = float(optimizer_config.prompt.ga_mutation_rate)
        elitism = max(0, int(optimizer_config.prompt.ga_elitism))
        selection_method = optimizer_config.prompt.ga_selection_method.lower()
        tournament_size = max(2, int(optimizer_config.prompt.ga_tournament_size))
        max_eval_concurrency = max(1, int(optimizer_config.prompt.ga_parallel_evaluations))
        diversity_lambda = float(optimizer_config.prompt.ga_diversity_lambda)

        # ------------- population init ------------- #
        async def _mutate_prompt(original_prompt: str, purpose: str) -> str:
            # Use LLM-based optimizer with no feedback
            return await init_fn.acall_invoke(
                PromptOptimizerInputSchema(
                    original_prompt=original_prompt,
                    objective=purpose,
                    oracle_feedback=None,
                ))

        async def _recombine_prompts(a: str, b: str, purpose: str) -> str:
            if recombine_fn is None:
                # Fallback: uniform choice per recombination
                return random.choice([a, b])
            payload = {"original_prompt": a, "objective": purpose, "oracle_feedback": None, "parent_b": b}
            return await recombine_fn.acall_invoke(payload)

        def _make_individual_from_prompts(prompts: dict[str, str]) -> Individual:
            # Copy so later edits to the source mapping do not leak into the individual.
            return Individual(prompts=dict(prompts))

        async def _initial_population() -> list[Individual]:
            individuals: list[Individual] = []
            # Ensure first individual is the original prompts
            originals = {k: prompt_space[k][0] for k in prompt_space}
            individuals.append(_make_individual_from_prompts(originals))

            init_sem = asyncio.Semaphore(max_eval_concurrency)

            async def _create_random_individual() -> Individual:
                async with init_sem:
                    mutated: dict[str, str] = {}
                    for param, (base_prompt, purpose) in prompt_space.items():
                        try:
                            new_p = await _mutate_prompt(base_prompt, purpose)
                        except Exception as e:
                            logger.warning("Mutation failed for %s: %s; using original.", param, e)
                            new_p = base_prompt
                        mutated[param] = new_p
                    return _make_individual_from_prompts(mutated)

            needed = max(0, pop_size - 1)
            tasks = [_create_random_individual() for _ in range(needed)]
            individuals.extend(await asyncio.gather(*tasks))
            return individuals

        # ------------- evaluation ------------- #
        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))

        sem = asyncio.Semaphore(max_eval_concurrency)

        async def _evaluate(ind: Individual) -> Individual:
            async with sem:
                cfg_trial = apply_suggestions(base_cfg, ind.prompts)
                eval_cfg = EvaluationRunConfig(
                    config_file=cfg_trial,
                    dataset=opt_run_config.dataset,
                    result_json_path=opt_run_config.result_json_path,
                    endpoint=opt_run_config.endpoint,
                    endpoint_timeout=opt_run_config.endpoint_timeout,
                    override=opt_run_config.override,
                )
                # Run reps sequentially under the same semaphore to avoid overload
                all_results: list[list[tuple[str, Any]]] = []
                for _ in range(reps):
                    res = (await EvaluationRun(config=eval_cfg).run_and_evaluate()).evaluation_results
                    all_results.append(res)

            # Average each requested evaluator's score over the repetitions.
            metrics: dict[str, float] = {}
            for metric_name in eval_metrics:
                scores: list[float] = []
                for run_results in all_results:
                    for name, result in run_results:
                        if name == metric_name:
                            scores.append(result.average_score)
                            break
                metrics[metric_name] = float(sum(scores) / len(scores)) if scores else 0.0
            ind.metrics = metrics
            return ind

        async def _evaluate_population(pop: list[Individual]) -> list[Individual]:
            # Evaluate those missing metrics (in place)
            unevaluated = [ind for ind in pop if not ind.metrics]
            if unevaluated:
                evaluated = await asyncio.gather(*[_evaluate(ind) for ind in unevaluated])
                for ind, ev in zip(unevaluated, evaluated):
                    ind.metrics = ev.metrics
            # Scalarize relative to this generation, then subtract the duplicate penalty.
            norm_per_ind = _normalize_generation(pop, eval_metrics, directions)
            penalties = _apply_diversity_penalty(pop, diversity_lambda)
            for ind, norm_scores, penalty in zip(pop, norm_per_ind, penalties):
                ind.scalar_fitness = _scalarize(
                    norm_scores, mode=optimizer_config.multi_objective_combination_mode, weights=weights) - penalty
            return pop

        # ------------- reproduction ops ------------- #
        async def _make_child(parent_a: Individual, parent_b: Individual) -> Individual:
            child_prompts: dict[str, str] = {}
            for param, (base_prompt, purpose) in prompt_space.items():
                pa = parent_a.prompts.get(param, base_prompt)
                pb = parent_b.prompts.get(param, base_prompt)
                child = pa
                # crossover
                if random.random() < crossover_rate:
                    try:
                        child = await _recombine_prompts(pa, pb, purpose)
                    except Exception as e:
                        logger.warning("Recombination failed for %s: %s; falling back to parent.", param, e)
                        child = random.choice([pa, pb])
                # mutation
                if random.random() < mutation_rate:
                    try:
                        child = await _mutate_prompt(child, purpose)
                    except Exception as e:
                        logger.warning("Mutation failed for %s: %s; keeping child as-is.", param, e)
                child_prompts[param] = child
            return _make_individual_from_prompts(child_prompts)

        def _select_parent(curr_pop: list[Individual]) -> Individual:
            # Any method other than "tournament" uses roulette-wheel selection.
            if selection_method == "tournament":
                return _tournament_select(curr_pop, tournament_size)
            return _roulette_select(curr_pop)

        # ------------- GA loop ------------- #
        population = await _initial_population()
        history_rows: list[dict[str, Any]] = []

        for gen in range(1, generations + 1):
            logger.info("[GA] Generation %d/%d: evaluating population of %d", gen, generations, len(population))
            population = await _evaluate_population(population)

            # Log and save checkpoint
            best = max(population, key=_fitness)
            checkpoint = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}
            checkpoint_path = out_dir / f"optimized_prompts_gen{gen}.json"
            with checkpoint_path.open("w", encoding="utf-8") as fh:
                json.dump(checkpoint, fh, indent=2)
            logger.info("[GA] Saved checkpoint: %s (fitness=%.4f)", checkpoint_path, _fitness(best))

            # Append history
            for idx, ind in enumerate(population):
                row = {
                    "generation": gen,
                    "index": idx,
                    "scalar_fitness": ind.scalar_fitness,
                }
                if ind.metrics:
                    row.update({f"metric::{m}": ind.metrics[m] for m in eval_metrics})
                history_rows.append(row)

            # Next generation via elitism + reproduction. Elites are copied
            # without their metrics, so they are re-evaluated next generation.
            next_population: list[Individual] = []
            if elitism > 0:
                elites = sorted(population, key=_fitness, reverse=True)[:elitism]
                next_population.extend([_make_individual_from_prompts(e.prompts) for e in elites])

            # Produce offspring until the population is back to pop_size.
            needed = pop_size - len(next_population)
            offspring: list[Individual] = []
            while len(offspring) < needed:
                p1 = _select_parent(population)
                p2 = _select_parent(population)
                if p2 is p1 and len(population) > 1:
                    p2 = random.choice([ind for ind in population if ind is not p1])
                child = await _make_child(p1, p2)
                offspring.append(child)

            population = next_population + offspring

        # Final evaluation to ensure metrics present
        population = await _evaluate_population(population)
        best = max(population, key=_fitness)
        best_prompts = {k: (best.prompts[k], prompt_space[k][1]) for k in prompt_space}

        # Save final
        final_prompts_path = out_dir / "optimized_prompts.json"
        with final_prompts_path.open("w", encoding="utf-8") as fh:
            json.dump(best_prompts, fh, indent=2)

        trials_df_path = out_dir / "ga_history_prompts.csv"
        try:
            # Plain stdlib CSV; the header is the union of all row keys.
            import csv  # pylint: disable=import-outside-toplevel

            fieldnames: list[str] = sorted({k for row in history_rows for k in row.keys()})
            with trials_df_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                for row in history_rows:
                    writer.writerow(row)
        except Exception as e:  # pragma: no cover - best effort
            logger.warning("Failed to write GA history CSV: %s", e)

        logger.info("Prompt GA optimization finished successfully!")
        logger.info("Final prompts saved to: %s", final_prompts_path)
        logger.info("History saved to: %s", trials_df_path)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _deep_merge_dict(target: dict[str, Any], updates: dict[str, Any]) -> None:
|
|
23
|
+
"""In-place deep merge of nested dictionaries."""
|
|
24
|
+
for key, value in updates.items():
|
|
25
|
+
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
|
26
|
+
_deep_merge_dict(target[key], value)
|
|
27
|
+
else:
|
|
28
|
+
target[key] = value
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def nest_updates(flat: dict[str, Any]) -> dict[str, Any]:
    """
    Expand a flat mapping of dotted paths into nested dictionaries.

    ``{'a.b.c': 1, 'd.x.y': 2}`` becomes
    ``{'a': {'b': {'c': 1}}, 'd': {'x': {'y': 2}}}``.

    Paths that share a leading segment are combined via ``_deep_merge_dict``,
    so ``{'a.b': 1, 'a.c': 2}`` yields ``{'a': {'b': 1, 'c': 2}}``. A key without
    any dot is stored as-is and replaces whatever was already under that key.
    Works even when the middle segment is a dict key.

    NOTE(review): if a dict value was stored earlier under the same leading
    segment, later dotted paths are merged into that dict object in place.
    """
    nested: dict[str, Any] = {}
    for dotted, value in flat.items():
        head, dot, tail = dotted.partition(".")
        if not dot:
            # No separator: a top-level leaf overwrites any previous entry.
            nested[head] = value
            continue

        # Wrap the value in one dict per remaining segment, innermost first.
        branch: Any = value
        for segment in reversed(tail.split(".")):
            branch = {segment: branch}

        slot = nested.setdefault(head, {})
        if isinstance(slot, dict):
            _deep_merge_dict(slot, branch)
        else:
            # A non-dict leaf already lives here; the nested branch replaces it.
            nested[head] = branch
    return nested
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def apply_suggestions(cfg: BaseModel, flat: dict[str, Any]) -> BaseModel:
    """
    Return a **new** config where only the dotted-path keys in *flat*
    have been modified. Preserves all unrelated siblings.

    The config is dumped to a plain Python dict. Each ``'a.b.c'`` path in *flat*
    is then written into that dict, creating intermediate dicts where a segment
    is missing or currently ``None`` (e.g. an unset optional sub-config). The
    result is re-validated with the config's own class. *cfg* itself is not
    modified.

    Args:
        cfg: The model instance to start from.
        flat: Mapping of dotted paths to the values that should be set.

    Returns:
        A fresh instance of ``cfg.__class__`` built from the updated dict.

    Raises:
        TypeError: If an intermediate path segment refers to an existing value
            that is neither a dict nor ``None`` (for example a string or a list),
            so the path cannot be descended into.
        Any error raised by ``model_validate`` when the updated dict no longer
        validates against the config's schema.
    """
    cfg_dict = cfg.model_dump(mode="python")
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(".")
        cursor = cfg_dict
        for depth, key in enumerate(parents):
            child = cursor.get(key)
            if child is None:
                # Missing key, or an optional sub-config dumped as None:
                # start a fresh mapping so the path can be populated. The old
                # ``setdefault`` kept the None and then crashed on it.
                child = {}
                cursor[key] = child
            elif not isinstance(child, dict):
                prefix = ".".join(parents[:depth + 1])
                raise TypeError(f"Cannot apply suggestion '{dotted}': '{prefix}' is a "
                                f"{type(child).__name__}, not a mapping")
            cursor = child
        cursor[leaf] = value
    return cfg.__class__.model_validate(cfg_dict)
|
nat/profiler/profile_runner.py
CHANGED
|
@@ -88,14 +88,19 @@ class ProfilerRunner:
|
|
|
88
88
|
writes out combined requests JSON, then computes and saves additional metrics,
|
|
89
89
|
and optionally fits a forecasting model.
|
|
90
90
|
"""
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
from nat.profiler.inference_optimization.
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
91
|
+
# Yapf and ruff disagree on how to format long imports, disable yapf go with ruff
|
|
92
|
+
from nat.profiler.inference_optimization.bottleneck_analysis.nested_stack_analysis import (
|
|
93
|
+
multi_example_call_profiling,
|
|
94
|
+
) # yapf: disable
|
|
95
|
+
from nat.profiler.inference_optimization.bottleneck_analysis.simple_stack_analysis import (
|
|
96
|
+
profile_workflow_bottlenecks,
|
|
97
|
+
) # yapf: disable
|
|
98
|
+
from nat.profiler.inference_optimization.experimental.concurrency_spike_analysis import (
|
|
99
|
+
concurrency_spike_analysis,
|
|
100
|
+
) # yapf: disable
|
|
101
|
+
from nat.profiler.inference_optimization.experimental.prefix_span_analysis import (
|
|
102
|
+
prefixspan_subworkflow_with_text,
|
|
103
|
+
) # yapf: disable
|
|
99
104
|
from nat.profiler.inference_optimization.llm_metrics import LLMMetrics
|
|
100
105
|
from nat.profiler.inference_optimization.prompt_caching import get_common_prefixes
|
|
101
106
|
from nat.profiler.inference_optimization.token_uniqueness import compute_inter_query_token_uniqueness_by_llm
|
|
@@ -277,7 +282,7 @@ class ProfilerRunner:
|
|
|
277
282
|
fitted_model = model_trainer.train(all_steps)
|
|
278
283
|
logger.info("Fitted model for forecasting.")
|
|
279
284
|
except Exception as e:
|
|
280
|
-
logger.exception("Fitting model failed. %s", e
|
|
285
|
+
logger.exception("Fitting model failed. %s", e)
|
|
281
286
|
return ProfilerResults()
|
|
282
287
|
|
|
283
288
|
if self.write_output:
|
nat/profiler/utils.py
CHANGED
|
@@ -22,6 +22,7 @@ from typing import Any
|
|
|
22
22
|
import pandas as pd
|
|
23
23
|
|
|
24
24
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
25
|
+
from nat.cli.type_registry import RegisteredFunctionGroupInfo
|
|
25
26
|
from nat.cli.type_registry import RegisteredFunctionInfo
|
|
26
27
|
from nat.data_models.intermediate_step import IntermediateStep
|
|
27
28
|
from nat.profiler.data_frame_row import DataFrameRow
|
|
@@ -32,7 +33,8 @@ _FRAMEWORK_REGEX_MAP = {t: fr'\b{t._name_}\b' for t in LLMFrameworkEnum}
|
|
|
32
33
|
logger = logging.getLogger(__name__)
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
def detect_llm_frameworks_in_build_fn(
|
|
36
|
+
def detect_llm_frameworks_in_build_fn(
|
|
37
|
+
registration: RegisteredFunctionInfo | RegisteredFunctionGroupInfo) -> list[LLMFrameworkEnum]:
|
|
36
38
|
"""
|
|
37
39
|
Analyze a function's source (the build_fn) to see which LLM frameworks it uses. Also recurses
|
|
38
40
|
into any additional Python functions that the build_fn calls while passing `builder`, so that
|
|
@@ -175,7 +177,7 @@ def create_standardized_dataframe(requests_data: list[list[IntermediateStep]]) -
|
|
|
175
177
|
event_type=step.event_type).model_dump(), )
|
|
176
178
|
|
|
177
179
|
except Exception as e:
|
|
178
|
-
logger.exception("Error creating standardized DataFrame: %s", e
|
|
180
|
+
logger.exception("Error creating standardized DataFrame: %s", e)
|
|
179
181
|
return pd.DataFrame()
|
|
180
182
|
|
|
181
183
|
if not all_rows:
|
|
@@ -133,7 +133,7 @@ class LocalRegistryHandler(AbstractRegistryHandler):
|
|
|
133
133
|
"message": msg,
|
|
134
134
|
"action": ActionEnum.SEARCH
|
|
135
135
|
})
|
|
136
|
-
logger.exception(validated_search_response.status.message
|
|
136
|
+
logger.exception(validated_search_response.status.message)
|
|
137
137
|
|
|
138
138
|
yield validated_search_response
|
|
139
139
|
|
|
@@ -168,7 +168,7 @@ class LocalRegistryHandler(AbstractRegistryHandler):
|
|
|
168
168
|
validated_remove_response = RemoveResponse(status={
|
|
169
169
|
"status": StatusEnum.ERROR, "message": msg, "action": ActionEnum.REMOVE
|
|
170
170
|
}) # type: ignore
|
|
171
|
-
logger.exception(validated_remove_response.status.message
|
|
171
|
+
logger.exception(validated_remove_response.status.message)
|
|
172
172
|
|
|
173
173
|
yield validated_remove_response
|
|
174
174
|
|
|
@@ -29,7 +29,6 @@ from nat.registry_handlers.schemas.publish import Artifact
|
|
|
29
29
|
from nat.runtime.loader import PluginTypes
|
|
30
30
|
from nat.runtime.loader import discover_entrypoints
|
|
31
31
|
|
|
32
|
-
# pylint: disable=redefined-outer-name
|
|
33
32
|
logger = logging.getLogger(__name__)
|
|
34
33
|
|
|
35
34
|
|
|
@@ -397,7 +396,7 @@ def get_transitive_dependencies(distribution_names: list[str]) -> dict[str, set[
|
|
|
397
396
|
except importlib.metadata.PackageNotFoundError:
|
|
398
397
|
pass
|
|
399
398
|
|
|
400
|
-
logger.error("Distribution %s not found (tried common variations)", dist_name)
|
|
399
|
+
logger.error("Distribution %s not found (tried common variations)", dist_name, exc_info=True)
|
|
401
400
|
result[dist_name] = set()
|
|
402
401
|
|
|
403
402
|
return result
|
|
@@ -44,13 +44,12 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
44
44
|
https://github.com/pypiserver/pypiserver
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
def __init__(
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
search_route: str = ""):
|
|
47
|
+
def __init__(self,
|
|
48
|
+
endpoint: str,
|
|
49
|
+
token: str | None = None,
|
|
50
|
+
publish_route: str = "",
|
|
51
|
+
pull_route: str = "",
|
|
52
|
+
search_route: str = ""):
|
|
54
53
|
super().__init__()
|
|
55
54
|
self._endpoint = endpoint.rstrip("/")
|
|
56
55
|
self._token = token
|
|
@@ -86,7 +85,7 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
86
85
|
validated_publish_response = PublishResponse(status={
|
|
87
86
|
"status": StatusEnum.ERROR, "message": msg, "action": ActionEnum.PUBLISH
|
|
88
87
|
})
|
|
89
|
-
logger.exception(validated_publish_response.status.message
|
|
88
|
+
logger.exception(validated_publish_response.status.message)
|
|
90
89
|
|
|
91
90
|
yield validated_publish_response
|
|
92
91
|
|
|
@@ -126,17 +125,16 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
126
125
|
|
|
127
126
|
versioned_packages_str = " ".join(versioned_packages)
|
|
128
127
|
|
|
129
|
-
result = subprocess.run(
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
check=True)
|
|
128
|
+
result = subprocess.run([
|
|
129
|
+
"uv",
|
|
130
|
+
"pip",
|
|
131
|
+
"install",
|
|
132
|
+
"--prerelease=allow",
|
|
133
|
+
"--index-url",
|
|
134
|
+
f"{self._endpoint}/{self._pull_route}/",
|
|
135
|
+
versioned_packages_str
|
|
136
|
+
],
|
|
137
|
+
check=True)
|
|
140
138
|
|
|
141
139
|
result.check_returncode()
|
|
142
140
|
|
|
@@ -151,7 +149,7 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
151
149
|
validated_pull_response = PullResponse(status={
|
|
152
150
|
"status": StatusEnum.ERROR, "message": msg, "action": ActionEnum.PULL
|
|
153
151
|
})
|
|
154
|
-
logger.exception(validated_pull_response.status.message
|
|
152
|
+
logger.exception(validated_pull_response.status.message)
|
|
155
153
|
|
|
156
154
|
yield validated_pull_response
|
|
157
155
|
|
|
@@ -171,11 +169,10 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
171
169
|
"""
|
|
172
170
|
|
|
173
171
|
try:
|
|
174
|
-
completed_process = subprocess.run(
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
check=True)
|
|
172
|
+
completed_process = subprocess.run(["pip", "search", "--index", f"{self._endpoint}", query.query],
|
|
173
|
+
text=True,
|
|
174
|
+
capture_output=True,
|
|
175
|
+
check=True)
|
|
179
176
|
search_response_list = []
|
|
180
177
|
search_results = completed_process.stdout
|
|
181
178
|
package_results = search_results.split("\n")
|
|
@@ -215,7 +212,7 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
215
212
|
|
|
216
213
|
except Exception as e:
|
|
217
214
|
msg = f"Error searching for artifacts: {e}"
|
|
218
|
-
logger.exception(msg
|
|
215
|
+
logger.exception(msg)
|
|
219
216
|
validated_search_response = SearchResponse(params=query,
|
|
220
217
|
status={
|
|
221
218
|
"status": StatusEnum.ERROR,
|
|
@@ -13,9 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=unused-import
|
|
17
16
|
# flake8: noqa
|
|
18
17
|
|
|
19
|
-
from .local import register_local
|
|
20
|
-
from .pypi import register_pypi
|
|
21
|
-
from .rest import register_rest
|
|
18
|
+
from .local import register_local
|
|
19
|
+
from .pypi import register_pypi
|
|
20
|
+
from .rest import register_rest
|