mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/apo/apo.py
@@ -0,0 +1,889 @@

```python
# Copyright (c) Microsoft. All rights reserved.

"""
APO with textual gradients that read rollout spans and outputs to modify the prompt.

- algo: beam search with span-aware textual gradients -> apply_edit via LLM
- rollout: same pattern as your example, but task is a dict (T_task)
"""

from __future__ import annotations

import asyncio
import logging
import random
import time
from dataclasses import dataclass
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Counter,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypedDict,
    TypeVar,
    cast,
)

import poml
from openai import AsyncOpenAI

from mantisdk.adapter.messages import TraceToMessages
from mantisdk.algorithm.base import Algorithm
from mantisdk.algorithm.utils import batch_iter_over_dataset, with_llm_proxy, with_store
from mantisdk.reward import find_final_reward
from mantisdk.types import Dataset, NamedResources, PromptTemplate, Rollout, RolloutMode, RolloutStatus

if TYPE_CHECKING:
    from mantisdk.llm_proxy import LLMProxy
    from mantisdk.store.base import LightningStore

logger = logging.getLogger(__name__)

T_task = TypeVar("T_task")


class RolloutResultForAPO(TypedDict):
    """This must be all JSON serializable to be processable by POML."""

    status: RolloutStatus
    final_reward: Optional[float]
    spans: List[Dict[str, Any]]
    messages: List[Any]


@dataclass
class VersionedPromptTemplate:
    version: str
    prompt_template: PromptTemplate
    score: Optional[float] = None


GRADIENT_PROMPT_FILES = [
    Path(__file__).parent / "prompts" / "text_gradient_variant01.poml",
    Path(__file__).parent / "prompts" / "text_gradient_variant02.poml",
    Path(__file__).parent / "prompts" / "text_gradient_variant03.poml",
]

APPLY_EDIT_PROMPT_FILES = [
    Path(__file__).parent / "prompts" / "apply_edit_variant01.poml",
    Path(__file__).parent / "prompts" / "apply_edit_variant02.poml",
]
```
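For orientation, a hand-built `RolloutResultForAPO` might look like the sketch below; every value is made up, since real entries come from store spans and the configured adapter, and the exact `RolloutStatus` literals are not shown in this file.

```python
# Illustrative only: real values come from LightningStore spans and the adapter.
example_result = {
    "status": "succeeded",   # assumed RolloutStatus value; the actual literal may differ
    "final_reward": 0.75,
    "spans": [{"name": "llm_call", "attributes": {"model": "gpt-4.1-mini"}}],
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
}
```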
```python
class APO(Algorithm, Generic[T_task]):
    """Automatic Prompt Optimization (APO) algorithm using textual gradients and beam search.

    APO is an iterative prompt optimization algorithm that uses LLM-generated textual gradients
    to improve prompts through a beam search process. It evaluates prompts on rollouts,
    computes critiques based on the results, and applies edits to generate improved prompts.

    The algorithm operates in rounds, where each round:

    1. Samples parent prompts from the current beam
    2. Generates new prompts by computing textual gradients and applying edits
    3. Evaluates all candidates on a validation set
    4. Selects the top-k prompts for the next round

    Based on the ideas from:

    - [ProTeGi](https://aclanthology.org/2023.emnlp-main.494.pdf)
    - [TextGrad](https://github.com/zou-group/textgrad)
    """

    def __init__(
        self,
        async_openai_client: AsyncOpenAI,
        *,
        gradient_model: str = "gpt-5-mini",
        apply_edit_model: str = "gpt-4.1-mini",
        diversity_temperature: float = 1.0,
        gradient_batch_size: int = 4,
        val_batch_size: int = 16,
        beam_width: int = 4,
        branch_factor: int = 4,
        beam_rounds: int = 3,
        rollout_batch_timeout: float = 3600.0,
        run_initial_validation: bool = True,
        # Internal flags for debugging
        _poml_trace: bool = False,
    ):
        """
        Initialize the APO algorithm with configuration parameters.

        Args:
            async_openai_client: AsyncOpenAI client for making LLM API calls.
            gradient_model: Model name for computing textual gradients (critiques).
            apply_edit_model: Model name for applying edits based on critiques.
            diversity_temperature: Temperature parameter for LLM calls to control diversity.
            gradient_batch_size: Number of rollout results to sample for gradient computation.
            val_batch_size: Number of validation examples to use for evaluation.
            beam_width: Number of top-scoring prompts to keep in the beam at each round.
            branch_factor: Number of new prompt candidates to generate from each parent prompt
                by applying textual gradient edits. This controls the expansion of the search tree.
            beam_rounds: Number of beam search rounds to perform.
            rollout_batch_timeout: Maximum time in seconds to wait for rollout batch completion.
            run_initial_validation: If True, runs validation on the seed prompt before starting
                optimization to establish a baseline score. Defaults to True.
        """
        self.async_openai_client = async_openai_client
        self.gradient_model = gradient_model
        self.apply_edit_model = apply_edit_model
        self.diversity_temperature = diversity_temperature
        self.gradient_batch_size = gradient_batch_size
        self.val_batch_size = val_batch_size
        self.beam_width = beam_width
        self.branch_factor = branch_factor
        self.beam_rounds = beam_rounds
        self.rollout_batch_timeout = rollout_batch_timeout
        self.run_initial_validation = run_initial_validation

        self._history_best_prompt: Optional[PromptTemplate] = None
        self._history_best_score: float = float("-inf")
        self._history_best_version: Optional[str] = None

        self._version_counter: int = 0

        self._poml_trace = _poml_trace

    def _create_versioned_prompt(
        self,
        prompt_template: PromptTemplate,
        *,
        score: Optional[float] = None,
    ) -> VersionedPromptTemplate:
        """
        Wrap a prompt template with a new monotonically increasing version identifier.
        """
        version = f"v{self._version_counter}"
        self._version_counter += 1
        return VersionedPromptTemplate(version=version, prompt_template=prompt_template, score=score)

    def _format_log_prefix(
        self,
        *,
        round_num: Optional[int] = None,
        beam_idx: Optional[int] = None,
        branch_idx: Optional[int] = None,
        prompt_version: Optional[str] = None,
    ) -> str:
        """
        Construct the standardized log prefix.
        """
        parts: List[str] = []
        if round_num is not None:
            parts.append(f"Round {round_num:02d}")
        if beam_idx is not None:
            parts.append(f"Beam {beam_idx:02d}")
        if branch_idx is not None:
            parts.append(f"Branch {branch_idx:02d}")
        if prompt_version is not None:
            parts.append(f"Prompt {prompt_version}")
        if not parts:
            return ""
        return f"[{' | '.join(parts)}]"

    def _log(self, level: int, message: str, *, prefix: Optional[str] = None) -> None:
        """
        Log a message with an optional standardized prefix.
        """
        effective_prefix = prefix
        if effective_prefix:
            logger.log(level, f"{effective_prefix} {message}")
        else:
            logger.log(level, message)

    def get_seed_prompt_template(self) -> Tuple[str, PromptTemplate]:
        """
        Extract the initial prompt template from the algorithm's resources.

        Returns:
            A tuple of (resource_name, prompt_template) representing the seed prompt.

        Raises:
            ValueError: If initial_resources is not set or no PromptTemplate is found.
        """
        initial_resources = self.get_initial_resources()
        if initial_resources is None:
            raise ValueError(
                "initial_resources are not set for APO algorithm. "
                "Use algorithm.set_initial_resources() to set initial resources or set it in Trainer()"
            )
        for name, resource in initial_resources.items():
            if isinstance(resource, PromptTemplate):
                return name, resource
        raise ValueError("No prompt template resource found in initial_resources")

    def get_adapter(self) -> TraceToMessages:
        """
        Get the adapter for converting spans to messages.

        Returns:
            The TraceToMessages instance for this algorithm.

        Raises:
            ValueError: If the adapter is not a TraceToMessages.
        """
        adapter = super().get_adapter()
        if not isinstance(adapter, TraceToMessages):
            raise ValueError("Adapter must be a TraceToMessages for APO algorithm")
        return adapter

    def get_best_prompt(self) -> PromptTemplate:
        """
        Retrieve the best prompt discovered during optimization.

        Returns:
            The prompt template with the highest validation score found so far.

        Raises:
            ValueError: If no best prompt has been found yet (run() not called).
        """
        if self._history_best_prompt is None:
            raise ValueError("No best prompt found")
        return self._history_best_prompt
```
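To make the constructor arguments above concrete, here is a minimal instantiation sketch; the endpoint and key are placeholders, and any OpenAI-compatible `AsyncOpenAI` client should work:

```python
from openai import AsyncOpenAI

# Placeholder endpoint/key; point this at whatever OpenAI-compatible service you use.
client = AsyncOpenAI(base_url="http://localhost:4000/v1", api_key="EMPTY")

algo = APO(
    client,
    gradient_model="gpt-5-mini",      # writes the critiques ("textual gradients")
    apply_edit_model="gpt-4.1-mini",  # rewrites the prompt from each critique
    beam_width=4,      # survivors kept per round
    branch_factor=4,   # children generated per parent
    beam_rounds=3,     # number of beam search rounds
)
```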
```python
    async def compute_textual_gradient(
        self,
        current_prompt: VersionedPromptTemplate,
        rollout_results: List[RolloutResultForAPO],
        *,
        prefix: Optional[str] = None,
    ) -> Optional[str]:
        """
        Compute a textual gradient (critique) for the current prompt based on rollout results.

        This method samples rollout results, sends them to an LLM along with the current prompt,
        and generates a critique describing how the prompt could be improved.

        Args:
            current_prompt: The prompt template to critique.
            rollout_results: List of rollout results containing spans, messages, and rewards.

        Returns:
            A textual critique generated by the LLM, or None if generation fails.
        """
        tg_template = random.choice(GRADIENT_PROMPT_FILES)

        if len(rollout_results) < self.gradient_batch_size:
            self._log(
                logging.WARNING,
                f"Only {len(rollout_results)} rollouts available, but {self.gradient_batch_size} are needed. Using all rollouts.",
                prefix=prefix,
            )
            sampled_rollout_results = rollout_results
        else:
            sampled_rollout_results = random.sample(rollout_results, self.gradient_batch_size)

        self._log(
            logging.INFO,
            f"Gradient will be computed with {self.gradient_model} for {len(sampled_rollout_results)} rollouts with template: {tg_template.name}",
            prefix=prefix,
        )

        tg_msg = poml.poml(  # type: ignore
            tg_template,
            context={
                "experiments": sampled_rollout_results,
                "prompt_template": current_prompt.prompt_template.template,
            },
            format="openai_chat",
        )
        self._log(
            logging.DEBUG,
            f"Gradient computed with {self.gradient_model} prompt: {tg_msg}",
            prefix=prefix,
        )
        critique_response = await self.async_openai_client.chat.completions.create(
            model=self.gradient_model,
            messages=tg_msg["messages"],  # type: ignore
            temperature=self.diversity_temperature,
        )
        critique_text = critique_response.choices[0].message.content
        self._log(
            logging.INFO,
            f"Gradient computed with {self.gradient_model} has result: {critique_text}",
            prefix=prefix,
        )

        return critique_text
```
```python
    async def textual_gradient_and_apply_edit(
        self,
        current_prompt: VersionedPromptTemplate,
        rollout: List[RolloutResultForAPO],
        *,
        prefix: Optional[str] = None,
    ) -> Optional[str]:
        """
        Generate an improved prompt by computing a textual gradient and applying an edit.

        This is the main optimization step that:

        1. Computes a critique (textual gradient) based on rollout performance
        2. Uses another LLM to apply the critique and generate an improved prompt

        Args:
            current_prompt: The current prompt template to improve.
            rollout: List of rollout results to base the critique on.

        Returns:
            The improved prompt text, or the original prompt if gradient computation fails.
        """
        # 1) Critique
        critique_text = await self.compute_textual_gradient(
            current_prompt,
            rollout,
            prefix=prefix,
        )
        if not critique_text:
            self._log(
                logging.ERROR,
                "Failed to compute critique for prompt.",
                prefix=prefix,
            )
            return current_prompt.prompt_template.template

        # 2) Apply edit
        ae_template = random.choice(APPLY_EDIT_PROMPT_FILES)
        self._log(
            logging.INFO,
            f"Edit will be generated by {self.apply_edit_model} with template: {ae_template.name}",
            prefix=prefix,
        )
        ae_msg = poml.poml(  # type: ignore
            ae_template,
            context={
                "prompt_template": current_prompt.prompt_template.template,
                "critique": critique_text,
            },
            format="openai_chat",
        )

        ae_response = await self.async_openai_client.chat.completions.create(
            model=self.apply_edit_model,
            messages=ae_msg["messages"],  # type: ignore
            temperature=self.diversity_temperature,
        )
        new_prompt = ae_response.choices[0].message.content
        if new_prompt:
            self._log(
                logging.INFO,
                f"Edit generated by {self.apply_edit_model}: {new_prompt[:50]}...",
                prefix=prefix,
            )
        return new_prompt
```
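One subtlety worth calling out: the two failure modes above are handled asymmetrically. A failed critique falls back to the unedited parent template, while a failed edit yields None and the caller drops that branch. A stub restatement of that contract:

```python
from typing import Optional

def resolve_edit(original: str, critique: Optional[str], edited: Optional[str]) -> Optional[str]:
    """Stand-in for the failure handling in textual_gradient_and_apply_edit."""
    if not critique:
        return original    # no gradient: keep the parent prompt text
    return edited or None  # no edit: caller logs an error and skips the candidate

assert resolve_edit("seed", None, None) == "seed"
assert resolve_edit("seed", "be terser", "Terser seed") == "Terser seed"
assert resolve_edit("seed", "be terser", None) is None
```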
```python
    @with_store
    async def get_rollout_results(
        self,
        store: LightningStore,
        rollout: List[Rollout],
        *,
        prefix: Optional[str] = None,
    ) -> List[RolloutResultForAPO]:
        """
        Convert completed rollouts to APO-compatible result format.

        Fetches spans for each rollout, adapts them to messages, and packages them
        with rewards and status information for gradient computation.

        Args:
            rollout: List of completed rollout metadata.

        Returns:
            List of rollout results formatted for APO processing.
        """
        rollout_results: List[RolloutResultForAPO] = []
        adapter = self.get_adapter()
        for r in rollout:
            spans = await store.query_spans(r.rollout_id)
            messages = adapter.adapt(spans)
            rollout_result = RolloutResultForAPO(
                status=r.status,
                final_reward=find_final_reward(spans),
                spans=[span.model_dump() for span in spans],
                messages=messages,
            )
            self._log(
                logging.DEBUG,
                f"Rollout result for {r.rollout_id}: status {rollout_result['status']} with final reward {rollout_result['final_reward']}. "
                f"{len(rollout_result['spans'])} spans and {len(rollout_result['messages'])} messages.",
                prefix=prefix,
            )
            rollout_results.append(rollout_result)
        return rollout_results
```
```python
    async def evaluate_prompt_on_batch(
        self,
        prompt: VersionedPromptTemplate,
        resource_name: str,
        dataset: Sequence[T_task],
        mode: RolloutMode,
        *,
        prefix: Optional[str] = None,
    ) -> Tuple[List[RolloutResultForAPO], float]:
        """
        Evaluate a prompt on a batch of tasks by running rollouts and computing average reward.

        This method:

        1. Adds the prompt as a named resource to the store
        2. Enqueues rollouts for each task in the dataset
        3. Waits for rollouts to complete (with timeout)
        4. Computes and returns the average reward

        Args:
            prompt: The versioned prompt template to evaluate.
            resource_name: The name to register the prompt under in the store.
            dataset: Sequence of tasks to evaluate the prompt on.
            mode: Rollout mode ("train" or "val") for logging/tracking.

        Returns:
            A tuple of (rollout_results, average_reward) where rollout_results contains
            detailed information for each rollout and average_reward is the mean final reward.
        """
        store = self.get_store()
        preview = prompt.prompt_template.template[:50]
        self._log(
            logging.INFO,
            f'Evaluating prompt "{preview}..." on {len(dataset)} tasks in {mode} mode',
            prefix=prefix,
        )

        # Install prompt as named resource
        resources: NamedResources = {resource_name: prompt.prompt_template}
        resource_update = await store.update_resources(prompt.version, resources)

        rollout_ids: List[str] = []
        for t in dataset:
            r = await store.enqueue_rollout(input=t, mode=mode, resources_id=resource_update.resources_id)
            rollout_ids.append(r.rollout_id)

        deadline = time.time() + self.rollout_batch_timeout
        finished: List[Rollout] = []
        while time.time() < deadline:
            finished = await store.wait_for_rollouts(rollout_ids=rollout_ids, timeout=0.0)
            if len(finished) >= len(rollout_ids):
                self._log(
                    logging.INFO,
                    f"All {len(rollout_ids)} rollouts finished within timeout.",
                    prefix=prefix,
                )
                break
            else:
                self._log(
                    logging.DEBUG,
                    f"Only {len(finished)} rollouts finished within timeout. Waiting for remaining {len(rollout_ids) - len(finished)} rollouts.",
                    prefix=prefix,
                )
                # Sleep to avoid busy-waiting
                await asyncio.sleep(2.0)

        rollout_results = await self.get_rollout_results(
            finished,
            prefix=prefix,
        )
        final_rewards = [rr["final_reward"] for rr in rollout_results]

        avg = float(sum([r or 0.0 for r in final_rewards]) / max(1, len(final_rewards)))
        status_counter = Counter([rr["status"] for rr in rollout_results])

        self._log(
            logging.INFO,
            f"Evaluated {len(rollout_results)} rollouts. Statuses: {status_counter}. Rewards: {final_rewards}, average is {avg}",
            prefix=prefix,
        )
        return rollout_results, avg
```
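The reward averaging above treats a missing `final_reward` as 0.0 and guards against an empty batch; a standalone check of that fold:

```python
final_rewards = [0.5, None, 1.0]  # one rollout produced no reward span
avg = float(sum(r or 0.0 for r in final_rewards) / max(1, len(final_rewards)))
assert abs(avg - 0.5) < 1e-9      # (0.5 + 0.0 + 1.0) / 3
```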
```python
    def _initialize_beam(
        self,
        train_dataset: Optional[Dataset[T_task]],
        val_dataset: Optional[Dataset[T_task]],
    ) -> Tuple[str, PromptTemplate, Iterator[Sequence[T_task]], Iterator[Sequence[T_task]]]:
        """
        Initialize the beam search with seed prompt and dataset iterators.

        Args:
            train_dataset: Dataset for computing gradients.
            val_dataset: Dataset for evaluating prompts.

        Returns:
            Tuple of (resource_name, seed_prompt, grad_iterator, val_iterator).

        Raises:
            ValueError: If either dataset is None.
        """
        resource_name, seed_prompt = self.get_seed_prompt_template()

        if train_dataset is None:
            raise ValueError("train_dataset is required for APO algorithm")
        if val_dataset is None:
            raise ValueError("val_dataset is required for APO algorithm")

        grad_dataset_iterator = batch_iter_over_dataset(train_dataset, self.gradient_batch_size)
        val_dataset_iterator = batch_iter_over_dataset(val_dataset, self.val_batch_size)

        # Initialize history tracking
        self._history_best_prompt = seed_prompt
        self._history_best_score = float("-inf")

        return resource_name, seed_prompt, grad_dataset_iterator, val_dataset_iterator
```
```python
    def _sample_parent_prompts(
        self,
        beam: List[VersionedPromptTemplate],
        round_num: int,
    ) -> List[Tuple[int, VersionedPromptTemplate]]:
        """
        Sample parent prompts from the current beam for generating new candidates.

        If the beam has fewer prompts than beam_width, replicates existing prompts.
        Otherwise, randomly samples beam_width prompts.

        Args:
            beam: Current list of prompt templates in the beam.
            round_num: Current round number (for logging, 0-indexed).

        Returns:
            List of parent prompts to generate children from.
        """
        display_round = round_num + 1
        if len(beam) < self.beam_width:
            prefix = self._format_log_prefix(round_num=display_round)
            self._log(
                logging.WARNING,
                f"Beam width is currently {self.beam_width}, but only {len(beam)} prompts in beam. Replicating all prompts.",
                prefix=prefix,
            )
            return [(i % len(beam), beam[i % len(beam)]) for i in range(self.beam_width)]

        selected_indices = random.sample(range(len(beam)), self.beam_width)
        return [(idx, beam[idx]) for idx in selected_indices]
```
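In the first round the beam holds only the seed, so the replication branch above fires; a standalone check of the round-robin expression:

```python
beam = ["v0"]  # seed-only beam, as in round one
beam_width = 4
parents = [(i % len(beam), beam[i % len(beam)]) for i in range(beam_width)]
assert parents == [(0, "v0")] * 4  # the seed is replicated beam_width times
```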
```python
    async def _generate_candidate_prompts(
        self,
        parent_prompts: List[Tuple[int, VersionedPromptTemplate]],
        resource_name: str,
        grad_dataset_iterator: Iterator[Sequence[T_task]],
        round_num: int,
    ) -> List[VersionedPromptTemplate]:
        """
        Generate new candidate prompts from parents using textual gradients.

        For each parent prompt, generates branch_factor new candidates by:

        1. Evaluating the parent on a training batch
        2. Computing textual gradient
        3. Applying edit to generate improved prompt

        Args:
            parent_prompts: List of parent prompts to generate children from.
            resource_name: Name to register prompts under in the store.
            grad_dataset_iterator: Iterator over training data batches.
            round_num: Current round number (for logging, 0-indexed).

        Returns:
            List of newly generated prompt templates.
        """
        display_round = round_num + 1
        round_prefix = self._format_log_prefix(round_num=display_round)
        self._log(
            logging.INFO,
            f"Applying {self.branch_factor} edits to each of the {len(parent_prompts)} parents based on "
            "gradients computed on training dataset",
            prefix=round_prefix,
        )

        parent_prompts_str = [
            f"{p.version}:{p.score:.3f}" if p.score is not None else p.version for _, p in parent_prompts
        ]
        self._log(
            logging.INFO,
            f"Parent prompts: {', '.join(parent_prompts_str)}",
            prefix=round_prefix,
        )

        candidates: List[VersionedPromptTemplate] = []
        used_beam_indices: Set[int] = set()
        for real_beam_idx, (beam_idx, prompt) in enumerate(parent_prompts):
            if beam_idx in used_beam_indices:
                beam_prefix = self._format_log_prefix(
                    round_num=display_round,
                    beam_idx=beam_idx + 1,
                    prompt_version=prompt.version,
                )
                self._log(
                    logging.WARNING,
                    "Duplicated beam index found. Might be caused by beam_width too high. "
                    + f"The real index of this beam is {real_beam_idx + 1}.",
                    prefix=beam_prefix,
                )
            else:
                used_beam_indices.add(beam_idx)
            for branch_idx in range(self.branch_factor):
                parent_prefix = self._format_log_prefix(
                    round_num=display_round,
                    beam_idx=beam_idx + 1,
                    branch_idx=branch_idx + 1,
                    prompt_version=prompt.version,
                )
                baseline_score = f"{prompt.score:.3f}" if prompt.score is not None else "N/A"
                self._log(
                    logging.INFO,
                    f"Use parent prompt {prompt.version} as a baseline to generate a new prompt. Baseline score: {baseline_score}",
                    prefix=parent_prefix,
                )
                grad_samples = next(grad_dataset_iterator)
                rollout_results, _ = await self.evaluate_prompt_on_batch(
                    prompt,
                    resource_name,
                    grad_samples,
                    mode="train",
                    prefix=parent_prefix,
                )
                new_prompt = await self.textual_gradient_and_apply_edit(
                    prompt,
                    rollout_results,
                    prefix=parent_prefix,
                )
                if not new_prompt:
                    self._log(
                        logging.ERROR,
                        f"Failed to compute edit for prompt: {prompt.prompt_template.template}",
                        prefix=parent_prefix,
                    )
                    continue
                new_prompt_template = PromptTemplate(template=new_prompt, engine="f-string")
                versioned_candidate = self._create_versioned_prompt(new_prompt_template)
                self._log(
                    logging.INFO,
                    f"New prompt template created from parent {prompt.version}: {versioned_candidate.version}",
                    prefix=parent_prefix,
                )
                candidate_prefix = self._format_log_prefix(
                    round_num=display_round, prompt_version=versioned_candidate.version
                )
                self._log(
                    logging.INFO,
                    f"New prompt template created from parent {prompt.version}:\n```\n{new_prompt}\n```",
                    prefix=candidate_prefix,
                )
                candidates.append(versioned_candidate)

        return candidates
```
```python
    async def _evaluate_and_select_beam(
        self,
        candidates: List[VersionedPromptTemplate],
        resource_name: str,
        val_dataset_iterator: Iterator[Sequence[T_task]],
        round_num: int,
    ) -> List[VersionedPromptTemplate]:
        """
        Evaluate all candidate prompts on validation data and select top-k for the beam.

        Args:
            candidates: List of candidate prompts to evaluate.
            resource_name: Name to register prompts under in the store.
            val_dataset_iterator: Iterator over validation data batches.
            round_num: Current round number (for logging, 0-indexed).

        Returns:
            List of top beam_width prompts sorted by validation score (best first).

        Raises:
            ValueError: If no candidates remain after evaluation.
        """
        display_round = round_num + 1
        round_prefix = self._format_log_prefix(round_num=display_round)
        self._log(
            logging.INFO,
            f"Evaluating {len(candidates)} candidates on validation dataset",
            prefix=round_prefix,
        )

        val_batch = next(val_dataset_iterator)

        for prompt in candidates:
            candidate_prefix = self._format_log_prefix(
                round_num=display_round,
                prompt_version=prompt.version,
            )
            _, score = await self.evaluate_prompt_on_batch(
                prompt,
                resource_name,
                val_batch,
                mode="val",
                prefix=candidate_prefix,
            )
            prompt.score = score
            self._log(
                logging.INFO,
                f"Candidate score: {score:.3f}",
                prefix=candidate_prefix,
            )

        # Sort by score (descending) and select top beam_width
        sorted_prompts = [p for p in sorted(candidates, key=lambda x: cast(float, x.score), reverse=True)]
        selected_prompts = sorted_prompts[: self.beam_width]
        selected_versions = [
            f"{prompt.version}:{prompt.score:.3f}" if prompt.score is not None else prompt.version
            for prompt in selected_prompts
        ]
        self._log(
            logging.INFO,
            f"Top {len(selected_prompts)} candidates on validation dataset: {selected_versions}",
            prefix=round_prefix,
        )

        if len(selected_prompts) == 0:
            raise ValueError("No beam candidates remaining")

        return selected_prompts
```
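The selection step reduces to a sort-and-truncate over (version, score) pairs; a standalone check with made-up scores and beam_width = 2:

```python
candidates = [("v1", 0.42), ("v2", 0.61), ("v3", 0.55)]  # hypothetical round results
beam = sorted(candidates, key=lambda c: c[1], reverse=True)[:2]
assert beam == [("v2", 0.61), ("v3", 0.55)]  # best first, survivors capped at beam_width
```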
```python
    async def _update_best_prompt(
        self,
        beam: List[VersionedPromptTemplate],
        resource_name: str,
        val_dataset: Dataset[T_task],
        round_num: int,
    ) -> None:
        """
        Evaluate the best prompt in the beam on the full validation set and update history.

        Args:
            beam: Current beam of prompts (sorted, best first).
            resource_name: Name to register prompts under in the store.
            val_dataset: Full validation dataset.
            round_num: Current round number (for logging, 0-indexed).
        """
        display_round = round_num + 1
        best_prompt = beam[0]
        prefix = self._format_log_prefix(round_num=display_round, prompt_version=best_prompt.version)
        _, best_score = await self.evaluate_prompt_on_batch(
            best_prompt,
            resource_name,
            cast(Sequence[T_task], val_dataset),
            mode="val",
            prefix=prefix,
        )
        self._log(
            logging.INFO,
            f"Beam leader score: {best_score:.3f}",
            prefix=prefix,
        )

        if best_score > self._history_best_score:
            prev = self._history_best_score
            self._log(
                logging.INFO,
                f"Best prompt updated. New best score: {best_score:.3f} (prev: {prev:.3f})",
                prefix=prefix,
            )
            self._history_best_prompt = best_prompt.prompt_template
            self._history_best_score = best_score
            self._history_best_version = best_prompt.version
        else:
            self._log(
                logging.WARNING,
                f"Best prompt not updated. Current score: {best_score:.3f} vs. history best: {self._history_best_score:.3f}",
                prefix=prefix,
            )
```
```python
    @with_llm_proxy()
    @with_store
    async def run(
        self,
        store: LightningStore,  # Injected by decorator - callers should not provide this parameter
        llm_proxy: Optional[LLMProxy],  # Injected by decorator - callers should not provide this parameter
        train_dataset: Optional[Dataset[T_task]] = None,
        val_dataset: Optional[Dataset[T_task]] = None,
    ) -> None:
        """
        Execute the APO algorithm to optimize prompts through beam search with textual gradients.

        The algorithm performs iterative prompt optimization over multiple rounds:

        - Each round: samples parent prompts, generates new candidates via textual gradients,
          evaluates all candidates on validation data, and keeps the top performers
        - Tracks the historically best prompt across all rounds
        - Uses different training data samples for each gradient computation to ensure diversity

        Args:
            train_dataset: Dataset of tasks for computing textual gradients. Required.
            val_dataset: Dataset of tasks for evaluating and selecting prompts. Required.

        Raises:
            ValueError: If train_dataset or val_dataset is None, or if resources are not set.
        """
        # Initialize beam search
        resource_name, seed_prompt, grad_iterator, val_iterator = self._initialize_beam(train_dataset, val_dataset)

        if self._poml_trace:
            poml.set_trace(trace_dir="pomltrace")

        # Validation datasets are guaranteed to be non-None after initialization
        assert val_dataset is not None

        # Start with seed prompt in the beam
        seed_versioned = self._create_versioned_prompt(seed_prompt)
        beam: List[VersionedPromptTemplate] = [seed_versioned]
        self._history_best_prompt = seed_prompt
        self._history_best_version = seed_versioned.version

        # Optionally evaluate seed prompt on validation set to establish baseline
        if self.run_initial_validation:
            seed_prefix = self._format_log_prefix(round_num=0, prompt_version=seed_versioned.version)
            self._log(
                logging.INFO,
                "Evaluating seed prompt on validation dataset before optimization...",
                prefix=seed_prefix,
            )
            _, seed_score = await self.evaluate_prompt_on_batch(
                seed_versioned,
                resource_name,
                cast(Sequence[T_task], val_dataset),
                mode="val",
                prefix=seed_prefix,
            )
            self._log(
                logging.INFO,
                f"Seed prompt baseline score: {seed_score:.3f}",
                prefix=seed_prefix,
            )
            self._history_best_prompt = seed_prompt
            self._history_best_score = seed_score
            self._history_best_version = seed_versioned.version

        # Run beam search for specified number of rounds
        for rnd in range(self.beam_rounds):
            display_round = rnd + 1
            round_prefix = self._format_log_prefix(round_num=display_round)
            self._log(
                logging.INFO,
                f"Round {display_round}/{self.beam_rounds}...",
                prefix=round_prefix,
            )

            # Sample parent prompts from current beam
            parent_prompts = self._sample_parent_prompts(beam, rnd)

            # Generate new candidate prompts from parents
            new_candidates = await self._generate_candidate_prompts(parent_prompts, resource_name, grad_iterator, rnd)

            # Combine existing beam with new candidates
            all_candidates = [*beam, *new_candidates]

            # Evaluate and select top-k prompts for next beam
            beam = await self._evaluate_and_select_beam(all_candidates, resource_name, val_iterator, rnd)

            # Update historically best prompt if improved
            await self._update_best_prompt(beam, resource_name, val_dataset, rnd)
```