mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/apo/apo.py
@@ -0,0 +1,889 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ """
+ APO with textual gradients that read rollout spans and outputs to modify the prompt.
+
+ - algo: beam search with span-aware textual gradients -> apply_edit via LLM
+ - rollout: follows the standard rollout pattern; the task type is generic (T_task)
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ import random
+ import time
+ from collections import Counter
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Dict,
+     Generic,
+     Iterator,
+     List,
+     Optional,
+     Sequence,
+     Set,
+     Tuple,
+     TypedDict,
+     TypeVar,
+     cast,
+ )
+
+ import poml
+ from openai import AsyncOpenAI
+
+ from mantisdk.adapter.messages import TraceToMessages
+ from mantisdk.algorithm.base import Algorithm
+ from mantisdk.algorithm.utils import batch_iter_over_dataset, with_llm_proxy, with_store
+ from mantisdk.reward import find_final_reward
+ from mantisdk.types import Dataset, NamedResources, PromptTemplate, Rollout, RolloutMode, RolloutStatus
+
+ if TYPE_CHECKING:
+     from mantisdk.llm_proxy import LLMProxy
+     from mantisdk.store.base import LightningStore
+
+ logger = logging.getLogger(__name__)
+
+ T_task = TypeVar("T_task")
+
+
+ class RolloutResultForAPO(TypedDict):
+     """All fields must be JSON-serializable so the result can be processed by POML."""
+
+     status: RolloutStatus
+     final_reward: Optional[float]
+     spans: List[Dict[str, Any]]
+     messages: List[Any]
+
+
+ @dataclass
+ class VersionedPromptTemplate:
+     version: str
+     prompt_template: PromptTemplate
+     score: Optional[float] = None
+
+
+ GRADIENT_PROMPT_FILES = [
+     Path(__file__).parent / "prompts" / "text_gradient_variant01.poml",
+     Path(__file__).parent / "prompts" / "text_gradient_variant02.poml",
+     Path(__file__).parent / "prompts" / "text_gradient_variant03.poml",
+ ]
+
+ APPLY_EDIT_PROMPT_FILES = [
+     Path(__file__).parent / "prompts" / "apply_edit_variant01.poml",
+     Path(__file__).parent / "prompts" / "apply_edit_variant02.poml",
+ ]
+
+
+ class APO(Algorithm, Generic[T_task]):
+     """Automatic Prompt Optimization (APO) algorithm using textual gradients and beam search.
+
+     APO is an iterative prompt optimization algorithm that uses LLM-generated textual gradients
+     to improve prompts through a beam search process. It evaluates prompts on rollouts,
+     computes critiques based on the results, and applies edits to generate improved prompts.
+
+     The algorithm operates in rounds, where each round:
+
+     1. Samples parent prompts from the current beam
+     2. Generates new prompts by computing textual gradients and applying edits
+     3. Evaluates all candidates on a validation set
+     4. Selects the top-k prompts for the next round
+
+     Based on the ideas from:
+
+     - [ProTeGi](https://aclanthology.org/2023.emnlp-main.494.pdf)
+     - [TextGrad](https://github.com/zou-group/textgrad)
+     """
+
+     def __init__(
+         self,
+         async_openai_client: AsyncOpenAI,
+         *,
+         gradient_model: str = "gpt-5-mini",
+         apply_edit_model: str = "gpt-4.1-mini",
+         diversity_temperature: float = 1.0,
+         gradient_batch_size: int = 4,
+         val_batch_size: int = 16,
+         beam_width: int = 4,
+         branch_factor: int = 4,
+         beam_rounds: int = 3,
+         rollout_batch_timeout: float = 3600.0,
+         run_initial_validation: bool = True,
+         # Internal flags for debugging
+         _poml_trace: bool = False,
+     ):
+         """
+         Initialize the APO algorithm with configuration parameters.
+
+         Args:
+             async_openai_client: AsyncOpenAI client for making LLM API calls.
+             gradient_model: Model name for computing textual gradients (critiques).
+             apply_edit_model: Model name for applying edits based on critiques.
+             diversity_temperature: Temperature parameter for LLM calls to control diversity.
+             gradient_batch_size: Number of rollout results to sample for gradient computation.
+             val_batch_size: Number of validation examples to use for evaluation.
+             beam_width: Number of top-scoring prompts to keep in the beam at each round.
+             branch_factor: Number of new prompt candidates to generate from each parent prompt
+                 by applying textual gradient edits. This controls the expansion of the search tree.
+             beam_rounds: Number of beam search rounds to perform.
+             rollout_batch_timeout: Maximum time in seconds to wait for rollout batch completion.
+             run_initial_validation: If True, runs validation on the seed prompt before starting
+                 optimization to establish a baseline score. Defaults to True.
+         """
+         self.async_openai_client = async_openai_client
+         self.gradient_model = gradient_model
+         self.apply_edit_model = apply_edit_model
+         self.diversity_temperature = diversity_temperature
+         self.gradient_batch_size = gradient_batch_size
+         self.val_batch_size = val_batch_size
+         self.beam_width = beam_width
+         self.branch_factor = branch_factor
+         self.beam_rounds = beam_rounds
+         self.rollout_batch_timeout = rollout_batch_timeout
+         self.run_initial_validation = run_initial_validation
+
+         self._history_best_prompt: Optional[PromptTemplate] = None
+         self._history_best_score: float = float("-inf")
+         self._history_best_version: Optional[str] = None
+
+         self._version_counter: int = 0
+
+         self._poml_trace = _poml_trace
+
+     def _create_versioned_prompt(
+         self,
+         prompt_template: PromptTemplate,
+         *,
+         score: Optional[float] = None,
+     ) -> VersionedPromptTemplate:
+         """
+         Wrap a prompt template with a new monotonically increasing version identifier.
+         """
+         version = f"v{self._version_counter}"
+         self._version_counter += 1
+         return VersionedPromptTemplate(version=version, prompt_template=prompt_template, score=score)
+
+     def _format_log_prefix(
+         self,
+         *,
+         round_num: Optional[int] = None,
+         beam_idx: Optional[int] = None,
+         branch_idx: Optional[int] = None,
+         prompt_version: Optional[str] = None,
+     ) -> str:
+         """
+         Construct the standardized log prefix.
+         """
+         parts: List[str] = []
+         if round_num is not None:
+             parts.append(f"Round {round_num:02d}")
+         if beam_idx is not None:
+             parts.append(f"Beam {beam_idx:02d}")
+         if branch_idx is not None:
+             parts.append(f"Branch {branch_idx:02d}")
+         if prompt_version is not None:
+             parts.append(f"Prompt {prompt_version}")
+         if not parts:
+             return ""
+         return f"[{' | '.join(parts)}]"
+
+     def _log(self, level: int, message: str, *, prefix: Optional[str] = None) -> None:
+         """
+         Log a message with an optional standardized prefix.
+         """
+         effective_prefix = prefix
+         if effective_prefix:
+             logger.log(level, f"{effective_prefix} {message}")
+         else:
+             logger.log(level, message)
+
+     def get_seed_prompt_template(self) -> Tuple[str, PromptTemplate]:
+         """
+         Extract the initial prompt template from the algorithm's resources.
+
+         Returns:
+             A tuple of (resource_name, prompt_template) representing the seed prompt.
+
+         Raises:
+             ValueError: If initial_resources is not set or no PromptTemplate is found.
+         """
+         initial_resources = self.get_initial_resources()
+         if initial_resources is None:
+             raise ValueError(
+                 "initial_resources are not set for APO algorithm. "
+                 "Use algorithm.set_initial_resources() to set initial resources or set it in Trainer()"
+             )
+         for name, resource in initial_resources.items():
+             if isinstance(resource, PromptTemplate):
+                 return name, resource
+         raise ValueError("No prompt template resource found in initial_resources")
+
+     def get_adapter(self) -> TraceToMessages:
+         """
+         Get the adapter for converting spans to messages.
+
+         Returns:
+             The TraceToMessages instance for this algorithm.
+
+         Raises:
+             ValueError: If the adapter is not a TraceToMessages.
+         """
+         adapter = super().get_adapter()
+         if not isinstance(adapter, TraceToMessages):
+             raise ValueError("Adapter must be a TraceToMessages for APO algorithm")
+         return adapter
+
+     def get_best_prompt(self) -> PromptTemplate:
+         """
+         Retrieve the best prompt discovered during optimization.
+
+         Returns:
+             The prompt template with the highest validation score found so far.
+
+         Raises:
+             ValueError: If no best prompt has been found yet (run() not called).
+         """
+         if self._history_best_prompt is None:
+             raise ValueError("No best prompt found")
+         return self._history_best_prompt
+
+     async def compute_textual_gradient(
+         self,
+         current_prompt: VersionedPromptTemplate,
+         rollout_results: List[RolloutResultForAPO],
+         *,
+         prefix: Optional[str] = None,
+     ) -> Optional[str]:
+         """
+         Compute a textual gradient (critique) for the current prompt based on rollout results.
+
+         This method samples rollout results, sends them to an LLM along with the current prompt,
+         and generates a critique describing how the prompt could be improved.
+
+         Args:
+             current_prompt: The prompt template to critique.
+             rollout_results: List of rollout results containing spans, messages, and rewards.
+
+         Returns:
+             A textual critique generated by the LLM, or None if generation fails.
+         """
+         tg_template = random.choice(GRADIENT_PROMPT_FILES)
+
+         if len(rollout_results) < self.gradient_batch_size:
+             self._log(
+                 logging.WARNING,
+                 f"Only {len(rollout_results)} rollouts available, but {self.gradient_batch_size} are needed. Using all rollouts.",
+                 prefix=prefix,
+             )
+             sampled_rollout_results = rollout_results
+         else:
+             sampled_rollout_results = random.sample(rollout_results, self.gradient_batch_size)
+
+         self._log(
+             logging.INFO,
+             f"Gradient will be computed with {self.gradient_model} for {len(sampled_rollout_results)} rollouts with template: {tg_template.name}",
+             prefix=prefix,
+         )
+
+         tg_msg = poml.poml(  # type: ignore
+             tg_template,
+             context={
+                 "experiments": sampled_rollout_results,
+                 "prompt_template": current_prompt.prompt_template.template,
+             },
+             format="openai_chat",
+         )
+         self._log(
+             logging.DEBUG,
+             f"Gradient prompt for {self.gradient_model}: {tg_msg}",
+             prefix=prefix,
+         )
+         critique_response = await self.async_openai_client.chat.completions.create(
+             model=self.gradient_model,
+             messages=tg_msg["messages"],  # type: ignore
+             temperature=self.diversity_temperature,
+         )
+         critique_text = critique_response.choices[0].message.content
+         self._log(
+             logging.INFO,
+             f"Gradient computed with {self.gradient_model}: {critique_text}",
+             prefix=prefix,
+         )
+
+         return critique_text
+
+     async def textual_gradient_and_apply_edit(
+         self,
+         current_prompt: VersionedPromptTemplate,
+         rollout: List[RolloutResultForAPO],
+         *,
+         prefix: Optional[str] = None,
+     ) -> Optional[str]:
+         """
+         Generate an improved prompt by computing a textual gradient and applying an edit.
+
+         This is the main optimization step that:
+
+         1. Computes a critique (textual gradient) based on rollout performance
+         2. Uses another LLM to apply the critique and generate an improved prompt
+
+         Args:
+             current_prompt: The current prompt template to improve.
+             rollout: List of rollout results to base the critique on.
+
+         Returns:
+             The improved prompt text; the original prompt if gradient computation fails;
+             or None if the edit step produces no content.
+         """
+         # 1) Critique
+         critique_text = await self.compute_textual_gradient(
+             current_prompt,
+             rollout,
+             prefix=prefix,
+         )
+         if not critique_text:
+             self._log(
+                 logging.ERROR,
+                 "Failed to compute critique for prompt.",
+                 prefix=prefix,
+             )
+             return current_prompt.prompt_template.template
+
+         # 2) Apply edit
+         ae_template = random.choice(APPLY_EDIT_PROMPT_FILES)
+         self._log(
+             logging.INFO,
+             f"Edit will be generated by {self.apply_edit_model} with template: {ae_template.name}",
+             prefix=prefix,
+         )
+         ae_msg = poml.poml(  # type: ignore
+             ae_template,
+             context={
+                 "prompt_template": current_prompt.prompt_template.template,
+                 "critique": critique_text,
+             },
+             format="openai_chat",
+         )
+
+         ae_response = await self.async_openai_client.chat.completions.create(
+             model=self.apply_edit_model,
+             messages=ae_msg["messages"],  # type: ignore
+             temperature=self.diversity_temperature,
+         )
+         new_prompt = ae_response.choices[0].message.content
+         if new_prompt:
+             self._log(
+                 logging.INFO,
+                 f"Edit generated by {self.apply_edit_model}: {new_prompt[:50]}...",
+                 prefix=prefix,
+             )
+         return new_prompt
+
+     @with_store
+     async def get_rollout_results(
+         self,
+         store: LightningStore,
+         rollout: List[Rollout],
+         *,
+         prefix: Optional[str] = None,
+     ) -> List[RolloutResultForAPO]:
+         """
+         Convert completed rollouts to APO-compatible result format.
+
+         Fetches spans for each rollout, adapts them to messages, and packages them
+         with rewards and status information for gradient computation.
+
+         Args:
+             rollout: List of completed rollout metadata.
+
+         Returns:
+             List of rollout results formatted for APO processing.
+         """
+         rollout_results: List[RolloutResultForAPO] = []
+         adapter = self.get_adapter()
+         for r in rollout:
+             spans = await store.query_spans(r.rollout_id)
+             messages = adapter.adapt(spans)
+             rollout_result = RolloutResultForAPO(
+                 status=r.status,
+                 final_reward=find_final_reward(spans),
+                 spans=[span.model_dump() for span in spans],
+                 messages=messages,
+             )
+             self._log(
+                 logging.DEBUG,
+                 f"Rollout result for {r.rollout_id}: status {rollout_result['status']} with final reward {rollout_result['final_reward']}. "
+                 f"{len(rollout_result['spans'])} spans and {len(rollout_result['messages'])} messages.",
+                 prefix=prefix,
+             )
+             rollout_results.append(rollout_result)
+         return rollout_results
+
+     async def evaluate_prompt_on_batch(
+         self,
+         prompt: VersionedPromptTemplate,
+         resource_name: str,
+         dataset: Sequence[T_task],
+         mode: RolloutMode,
+         *,
+         prefix: Optional[str] = None,
+     ) -> Tuple[List[RolloutResultForAPO], float]:
+         """
+         Evaluate a prompt on a batch of tasks by running rollouts and computing average reward.
+
+         This method:
+
+         1. Adds the prompt as a named resource to the store
+         2. Enqueues rollouts for each task in the dataset
+         3. Waits for rollouts to complete (with timeout)
+         4. Computes and returns the average reward
+
+         Args:
+             prompt: The versioned prompt template to evaluate.
+             resource_name: The name to register the prompt under in the store.
+             dataset: Sequence of tasks to evaluate the prompt on.
+             mode: Rollout mode ("train" or "val") for logging/tracking.
+
+         Returns:
+             A tuple of (rollout_results, average_reward) where rollout_results contains
+             detailed information for each rollout and average_reward is the mean final reward.
+         """
+         store = self.get_store()
+         preview = prompt.prompt_template.template[:50]
+         self._log(
+             logging.INFO,
+             f'Evaluating prompt "{preview}..." on {len(dataset)} tasks in {mode} mode',
+             prefix=prefix,
+         )
+
+         # Install prompt as named resource
+         resources: NamedResources = {resource_name: prompt.prompt_template}
+         resource_update = await store.update_resources(prompt.version, resources)
+
+         rollout_ids: List[str] = []
+         for t in dataset:
+             r = await store.enqueue_rollout(input=t, mode=mode, resources_id=resource_update.resources_id)
+             rollout_ids.append(r.rollout_id)
+
+         deadline = time.time() + self.rollout_batch_timeout
+         finished: List[Rollout] = []
+         while time.time() < deadline:
+             finished = await store.wait_for_rollouts(rollout_ids=rollout_ids, timeout=0.0)
+             if len(finished) >= len(rollout_ids):
+                 self._log(
+                     logging.INFO,
+                     f"All {len(rollout_ids)} rollouts finished within timeout.",
+                     prefix=prefix,
+                 )
+                 break
+             else:
+                 self._log(
+                     logging.DEBUG,
+                     f"Only {len(finished)} of {len(rollout_ids)} rollouts finished so far. Waiting for the remaining {len(rollout_ids) - len(finished)}.",
+                     prefix=prefix,
+                 )
+                 # Sleep to avoid busy-waiting
+                 await asyncio.sleep(2.0)
+
+         rollout_results = await self.get_rollout_results(
+             finished,
+             prefix=prefix,
+         )
+         final_rewards = [rr["final_reward"] for rr in rollout_results]
+
+         avg = float(sum([r or 0.0 for r in final_rewards]) / max(1, len(final_rewards)))
+         status_counter = Counter([rr["status"] for rr in rollout_results])
+
+         self._log(
+             logging.INFO,
+             f"Evaluated {len(rollout_results)} rollouts. Statuses: {status_counter}. Rewards: {final_rewards}, average is {avg}",
+             prefix=prefix,
+         )
+         return rollout_results, avg
+
+     def _initialize_beam(
+         self,
+         train_dataset: Optional[Dataset[T_task]],
+         val_dataset: Optional[Dataset[T_task]],
+     ) -> Tuple[str, PromptTemplate, Iterator[Sequence[T_task]], Iterator[Sequence[T_task]]]:
+         """
+         Initialize the beam search with seed prompt and dataset iterators.
+
+         Args:
+             train_dataset: Dataset for computing gradients.
+             val_dataset: Dataset for evaluating prompts.
+
+         Returns:
+             Tuple of (resource_name, seed_prompt, grad_iterator, val_iterator).
+
+         Raises:
+             ValueError: If either dataset is None.
+         """
+         resource_name, seed_prompt = self.get_seed_prompt_template()
+
+         if train_dataset is None:
+             raise ValueError("train_dataset is required for APO algorithm")
+         if val_dataset is None:
+             raise ValueError("val_dataset is required for APO algorithm")
+
+         grad_dataset_iterator = batch_iter_over_dataset(train_dataset, self.gradient_batch_size)
+         val_dataset_iterator = batch_iter_over_dataset(val_dataset, self.val_batch_size)
+
+         # Initialize history tracking
+         self._history_best_prompt = seed_prompt
+         self._history_best_score = float("-inf")
+
+         return resource_name, seed_prompt, grad_dataset_iterator, val_dataset_iterator
+
+     def _sample_parent_prompts(
+         self,
+         beam: List[VersionedPromptTemplate],
+         round_num: int,
+     ) -> List[Tuple[int, VersionedPromptTemplate]]:
+         """
+         Sample parent prompts from the current beam for generating new candidates.
+
+         If the beam has fewer prompts than beam_width, replicates existing prompts.
+         Otherwise, randomly samples beam_width prompts.
+
+         Args:
+             beam: Current list of prompt templates in the beam.
+             round_num: Current round number (for logging, 0-indexed).
+
+         Returns:
+             List of parent prompts to generate children from.
+         """
+         display_round = round_num + 1
+         if len(beam) < self.beam_width:
+             prefix = self._format_log_prefix(round_num=display_round)
+             self._log(
+                 logging.WARNING,
+                 f"Beam width is {self.beam_width}, but the beam only has {len(beam)} prompt(s). Replicating existing prompts.",
+                 prefix=prefix,
+             )
+             return [(i % len(beam), beam[i % len(beam)]) for i in range(self.beam_width)]
+
+         selected_indices = random.sample(range(len(beam)), self.beam_width)
+         return [(idx, beam[idx]) for idx in selected_indices]
+
+     async def _generate_candidate_prompts(
+         self,
+         parent_prompts: List[Tuple[int, VersionedPromptTemplate]],
+         resource_name: str,
+         grad_dataset_iterator: Iterator[Sequence[T_task]],
+         round_num: int,
+     ) -> List[VersionedPromptTemplate]:
+         """
+         Generate new candidate prompts from parents using textual gradients.
+
+         For each parent prompt, generates branch_factor new candidates by:
+
+         1. Evaluating the parent on a training batch
+         2. Computing textual gradient
+         3. Applying edit to generate improved prompt
+
+         Args:
+             parent_prompts: List of parent prompts to generate children from.
+             resource_name: Name to register prompts under in the store.
+             grad_dataset_iterator: Iterator over training data batches.
+             round_num: Current round number (for logging, 0-indexed).
+
+         Returns:
+             List of newly generated prompt templates.
+         """
+         display_round = round_num + 1
+         round_prefix = self._format_log_prefix(round_num=display_round)
+         self._log(
+             logging.INFO,
+             f"Applying {self.branch_factor} edits to each of the {len(parent_prompts)} parents based on "
+             "gradients computed on training dataset",
+             prefix=round_prefix,
+         )
+
+         parent_prompts_str = [
+             f"{p.version}:{p.score:.3f}" if p.score is not None else p.version for _, p in parent_prompts
+         ]
+         self._log(
+             logging.INFO,
+             f"Parent prompts: {', '.join(parent_prompts_str)}",
+             prefix=round_prefix,
+         )
+
+         candidates: List[VersionedPromptTemplate] = []
+         used_beam_indices: Set[int] = set()
+         for real_beam_idx, (beam_idx, prompt) in enumerate(parent_prompts):
+             if beam_idx in used_beam_indices:
+                 beam_prefix = self._format_log_prefix(
+                     round_num=display_round,
+                     beam_idx=beam_idx + 1,
+                     prompt_version=prompt.version,
+                 )
+                 self._log(
+                     logging.WARNING,
+                     "Duplicated beam index found. This can happen when beam_width exceeds the beam size. "
+                     + f"The real index of this beam is {real_beam_idx + 1}.",
+                     prefix=beam_prefix,
+                 )
+             else:
+                 used_beam_indices.add(beam_idx)
+             for branch_idx in range(self.branch_factor):
+                 parent_prefix = self._format_log_prefix(
+                     round_num=display_round,
+                     beam_idx=beam_idx + 1,
+                     branch_idx=branch_idx + 1,
+                     prompt_version=prompt.version,
+                 )
+                 baseline_score = f"{prompt.score:.3f}" if prompt.score is not None else "N/A"
+                 self._log(
+                     logging.INFO,
+                     f"Using parent prompt {prompt.version} as a baseline to generate a new prompt. Baseline score: {baseline_score}",
+                     prefix=parent_prefix,
+                 )
+                 grad_samples = next(grad_dataset_iterator)
+                 rollout_results, _ = await self.evaluate_prompt_on_batch(
+                     prompt,
+                     resource_name,
+                     grad_samples,
+                     mode="train",
+                     prefix=parent_prefix,
+                 )
+                 new_prompt = await self.textual_gradient_and_apply_edit(
+                     prompt,
+                     rollout_results,
+                     prefix=parent_prefix,
+                 )
+                 if not new_prompt:
+                     self._log(
+                         logging.ERROR,
+                         f"Failed to compute edit for prompt: {prompt.prompt_template.template}",
+                         prefix=parent_prefix,
+                     )
+                     continue
+                 new_prompt_template = PromptTemplate(template=new_prompt, engine="f-string")
+                 versioned_candidate = self._create_versioned_prompt(new_prompt_template)
+                 self._log(
+                     logging.INFO,
+                     f"New prompt template created from parent {prompt.version}: {versioned_candidate.version}",
+                     prefix=parent_prefix,
+                 )
+                 candidate_prefix = self._format_log_prefix(
+                     round_num=display_round, prompt_version=versioned_candidate.version
+                 )
+                 self._log(
+                     logging.INFO,
+                     f"New prompt template created from parent {prompt.version}:\n```\n{new_prompt}\n```",
+                     prefix=candidate_prefix,
+                 )
+                 candidates.append(versioned_candidate)
+
+         return candidates
+
+     async def _evaluate_and_select_beam(
+         self,
+         candidates: List[VersionedPromptTemplate],
+         resource_name: str,
+         val_dataset_iterator: Iterator[Sequence[T_task]],
+         round_num: int,
+     ) -> List[VersionedPromptTemplate]:
+         """
+         Evaluate all candidate prompts on validation data and select top-k for the beam.
+
+         Args:
+             candidates: List of candidate prompts to evaluate.
+             resource_name: Name to register prompts under in the store.
+             val_dataset_iterator: Iterator over validation data batches.
+             round_num: Current round number (for logging, 0-indexed).
+
+         Returns:
+             List of top beam_width prompts sorted by validation score (best first).
+
+         Raises:
+             ValueError: If no candidates remain after evaluation.
+         """
+         display_round = round_num + 1
+         round_prefix = self._format_log_prefix(round_num=display_round)
+         self._log(
+             logging.INFO,
+             f"Evaluating {len(candidates)} candidates on validation dataset",
+             prefix=round_prefix,
+         )
+
+         val_batch = next(val_dataset_iterator)
+
+         for prompt in candidates:
+             candidate_prefix = self._format_log_prefix(
+                 round_num=display_round,
+                 prompt_version=prompt.version,
+             )
+             _, score = await self.evaluate_prompt_on_batch(
+                 prompt,
+                 resource_name,
+                 val_batch,
+                 mode="val",
+                 prefix=candidate_prefix,
+             )
+             prompt.score = score
+             self._log(
+                 logging.INFO,
+                 f"Candidate score: {score:.3f}",
+                 prefix=candidate_prefix,
+             )
+
+         # Sort by score (descending) and select top beam_width
+         sorted_prompts = sorted(candidates, key=lambda x: cast(float, x.score), reverse=True)
+         selected_prompts = sorted_prompts[: self.beam_width]
+         selected_versions = [
+             f"{prompt.version}:{prompt.score:.3f}" if prompt.score is not None else prompt.version
+             for prompt in selected_prompts
+         ]
+         self._log(
+             logging.INFO,
+             f"Top {len(selected_prompts)} candidates on validation dataset: {selected_versions}",
+             prefix=round_prefix,
+         )
+
+         if len(selected_prompts) == 0:
+             raise ValueError("No beam candidates remaining")
+
+         return selected_prompts
+
+     async def _update_best_prompt(
+         self,
+         beam: List[VersionedPromptTemplate],
+         resource_name: str,
+         val_dataset: Dataset[T_task],
+         round_num: int,
+     ) -> None:
+         """
+         Evaluate the best prompt in the beam on the full validation set and update history.
+
+         Args:
+             beam: Current beam of prompts (sorted, best first).
+             resource_name: Name to register prompts under in the store.
+             val_dataset: Full validation dataset.
+             round_num: Current round number (for logging, 0-indexed).
+         """
+         display_round = round_num + 1
+         best_prompt = beam[0]
+         prefix = self._format_log_prefix(round_num=display_round, prompt_version=best_prompt.version)
+         _, best_score = await self.evaluate_prompt_on_batch(
+             best_prompt,
+             resource_name,
+             cast(Sequence[T_task], val_dataset),
+             mode="val",
+             prefix=prefix,
+         )
+         self._log(
+             logging.INFO,
+             f"Beam leader score: {best_score:.3f}",
+             prefix=prefix,
+         )
+
+         if best_score > self._history_best_score:
+             prev = self._history_best_score
+             self._log(
+                 logging.INFO,
+                 f"Best prompt updated. New best score: {best_score:.3f} (prev: {prev:.3f})",
+                 prefix=prefix,
+             )
+             self._history_best_prompt = best_prompt.prompt_template
+             self._history_best_score = best_score
+             self._history_best_version = best_prompt.version
+         else:
+             self._log(
+                 logging.WARNING,
+                 f"Best prompt not updated. Current score: {best_score:.3f} vs. history best: {self._history_best_score:.3f}",
+                 prefix=prefix,
+             )
+
+     @with_llm_proxy()
+     @with_store
+     async def run(
+         self,
+         store: LightningStore,  # Injected by decorator - callers should not provide this parameter
+         llm_proxy: Optional[LLMProxy],  # Injected by decorator - callers should not provide this parameter
+         train_dataset: Optional[Dataset[T_task]] = None,
+         val_dataset: Optional[Dataset[T_task]] = None,
+     ) -> None:
+         """
+         Execute the APO algorithm to optimize prompts through beam search with textual gradients.
+
+         The algorithm performs iterative prompt optimization over multiple rounds:
+
+         - Each round: samples parent prompts, generates new candidates via textual gradients,
+           evaluates all candidates on validation data, and keeps the top performers
+         - Tracks the historically best prompt across all rounds
+         - Uses different training data samples for each gradient computation to ensure diversity
+
+         Args:
+             train_dataset: Dataset of tasks for computing textual gradients. Required.
+             val_dataset: Dataset of tasks for evaluating and selecting prompts. Required.
+
+         Raises:
+             ValueError: If train_dataset or val_dataset is None, or if resources are not set.
+         """
+         # Initialize beam search
+         resource_name, seed_prompt, grad_iterator, val_iterator = self._initialize_beam(train_dataset, val_dataset)
+
+         if self._poml_trace:
+             poml.set_trace(trace_dir="pomltrace")
+
+         # Validation datasets are guaranteed to be non-None after initialization
+         assert val_dataset is not None
+
+         # Start with seed prompt in the beam
+         seed_versioned = self._create_versioned_prompt(seed_prompt)
+         beam: List[VersionedPromptTemplate] = [seed_versioned]
+         self._history_best_prompt = seed_prompt
+         self._history_best_version = seed_versioned.version
+
+         # Optionally evaluate seed prompt on validation set to establish baseline
+         if self.run_initial_validation:
+             seed_prefix = self._format_log_prefix(round_num=0, prompt_version=seed_versioned.version)
+             self._log(
+                 logging.INFO,
+                 "Evaluating seed prompt on validation dataset before optimization...",
+                 prefix=seed_prefix,
+             )
+             _, seed_score = await self.evaluate_prompt_on_batch(
+                 seed_versioned,
+                 resource_name,
+                 cast(Sequence[T_task], val_dataset),
+                 mode="val",
+                 prefix=seed_prefix,
+             )
+             self._log(
+                 logging.INFO,
+                 f"Seed prompt baseline score: {seed_score:.3f}",
+                 prefix=seed_prefix,
+             )
+             self._history_best_prompt = seed_prompt
+             self._history_best_score = seed_score
+             self._history_best_version = seed_versioned.version
+
+         # Run beam search for specified number of rounds
+         for rnd in range(self.beam_rounds):
+             display_round = rnd + 1
+             round_prefix = self._format_log_prefix(round_num=display_round)
+             self._log(
+                 logging.INFO,
+                 f"Round {display_round}/{self.beam_rounds}...",
+                 prefix=round_prefix,
+             )
+
+             # Sample parent prompts from current beam
+             parent_prompts = self._sample_parent_prompts(beam, rnd)
+
+             # Generate new candidate prompts from parents
+             new_candidates = await self._generate_candidate_prompts(parent_prompts, resource_name, grad_iterator, rnd)
+
+             # Combine existing beam with new candidates
+             all_candidates = [*beam, *new_candidates]
+
+             # Evaluate and select top-k prompts for next beam
+             beam = await self._evaluate_and_select_beam(all_candidates, resource_name, val_iterator, rnd)
+
+             # Update historically best prompt if improved
+             await self._update_best_prompt(beam, resource_name, val_dataset, rnd)
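
For orientation, below is a minimal usage sketch based only on the signatures visible in this diff; it is not part of the released wheel. Assumptions: APO is importable from mantisdk.algorithm.apo (the package ships an __init__.py there), a plain list of dict tasks is acceptable as a Dataset, and "system_prompt" is an arbitrary resource name. The store and llm_proxy arguments of run() are injected by the @with_store/@with_llm_proxy decorators, which presume the surrounding store/trainer setup is configured.

# Hypothetical usage sketch -- not part of the released package.
import asyncio

from openai import AsyncOpenAI  # requires OPENAI_API_KEY in the environment

from mantisdk.algorithm.apo import APO  # assumed re-export from the apo subpackage
from mantisdk.types import PromptTemplate


async def main() -> None:
    algorithm = APO(
        async_openai_client=AsyncOpenAI(),
        beam_width=2,      # small values for a smoke test
        branch_factor=2,
        beam_rounds=1,
    )
    # APO only requires that some PromptTemplate exists among the initial
    # resources; the name "system_prompt" is illustrative.
    algorithm.set_initial_resources(
        {"system_prompt": PromptTemplate(template="You are a helpful assistant.", engine="f-string")}
    )
    # Tasks are generic (T_task); dicts are assumed here for illustration.
    # store/llm_proxy are injected by the decorators on run().
    await algorithm.run(
        train_dataset=[{"question": "1 + 1 = ?"}, {"question": "2 + 2 = ?"}],
        val_dataset=[{"question": "3 + 3 = ?"}],
    )
    print(algorithm.get_best_prompt().template)


asyncio.run(main())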