mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,356 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ import traceback
5
+ from collections.abc import Sequence
6
+ from typing import Generic
7
+
8
+ from mantisdk.algorithm.gepa.lib.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
9
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader, ensure_loader
10
+ from mantisdk.algorithm.gepa.lib.core.state import EvaluationCache, FrontierType, GEPAState, ValsetEvaluation, initialize_gepa_state
11
+ from mantisdk.algorithm.gepa.lib.logging.experiment_tracker import ExperimentTracker
12
+ from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol
13
+ from mantisdk.algorithm.gepa.lib.logging.utils import log_detailed_metrics_after_discovering_new_program
14
+ from mantisdk.algorithm.gepa.lib.proposer.merge import MergeProposer
15
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.reflective_mutation import (
16
+ ReflectiveMutationProposer,
17
+ )
18
+ from mantisdk.algorithm.gepa.lib.strategies.eval_policy import EvaluationPolicy, FullEvaluationPolicy
19
+ from mantisdk.algorithm.gepa.lib.utils import StopperProtocol
20
+
21
+ # Import tqdm for progress bar functionality
22
+ try:
23
+ from tqdm import tqdm
24
+ except ImportError:
25
+ tqdm = None
26
+
27
+
28
class GEPAEngine(Generic[DataId, DataInst, Trajectory, RolloutOutput]):
    """Orchestrates the optimization loop using pluggable candidate proposers.

    The engine alternates between a reflective-mutation proposer and an
    optional merge proposer, accepting candidates that improve on a
    subsample, then running a full validation evaluation on accepted
    candidates and recording them in a :class:`GEPAState`.
    """

    def __init__(
        self,
        adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput],
        run_dir: str | None,
        valset: list[DataInst] | DataLoader[DataId, DataInst] | None,
        seed_candidate: dict[str, str],
        # Controls
        perfect_score: float,
        seed: int,
        # Strategies and helpers
        reflective_proposer: ReflectiveMutationProposer,
        merge_proposer: MergeProposer | None,
        frontier_type: FrontierType,
        # Logging
        logger: LoggerProtocol,
        experiment_tracker: ExperimentTracker,
        # Optional parameters
        track_best_outputs: bool = False,
        display_progress_bar: bool = False,
        raise_on_exception: bool = True,
        use_cloudpickle: bool = False,
        # Budget and Stop Condition
        stop_callback: StopperProtocol | None = None,
        val_evaluation_policy: EvaluationPolicy[DataId, DataInst] | None = None,
        # Evaluation caching (stored in state, passed here for initialization)
        evaluation_cache: EvaluationCache[RolloutOutput, DataId] | None = None,
    ):
        self.logger = logger
        self.run_dir = run_dir

        # Graceful stopping mechanism; set via request_stop() or stop_callback.
        self._stop_requested = False

        # Set up stopping mechanism
        self.stop_callback = stop_callback
        self.adapter = adapter

        # Store cache reference for state initialization (actual cache lives in GEPAState)
        self._initial_evaluation_cache = evaluation_cache

        def evaluator(
            batch: list[DataInst], program: dict[str, str]
        ) -> tuple[list[RolloutOutput], list[float], Sequence[dict[str, float]] | None]:
            # Trace capture is disabled here: this path is used for scoring only.
            eval_result = adapter.evaluate(batch, program, capture_traces=False)
            return eval_result.outputs, eval_result.scores, eval_result.objective_scores

        self.evaluator = evaluator

        # Normalize a raw list valset into a DataLoader; None stays None and
        # is rejected later in run().
        self.valset = ensure_loader(valset) if valset is not None else None
        self.seed_candidate = seed_candidate

        self.perfect_score = perfect_score
        self.seed = seed
        self.experiment_tracker = experiment_tracker

        self.reflective_proposer = reflective_proposer
        self.merge_proposer = merge_proposer
        self.frontier_type: FrontierType = frontier_type

        # Merge scheduling flags (mirroring previous behavior)
        if self.merge_proposer is not None:
            self.merge_proposer.last_iter_found_new_program = False

        self.track_best_outputs = track_best_outputs
        self.display_progress_bar = display_progress_bar
        self.use_cloudpickle = use_cloudpickle

        self.raise_on_exception = raise_on_exception
        self.val_evaluation_policy: EvaluationPolicy[DataId, DataInst] = (
            val_evaluation_policy if val_evaluation_policy is not None else FullEvaluationPolicy()
        )

    def _evaluate_on_valset(
        self,
        program: dict[str, str],
        state: GEPAState[RolloutOutput, DataId],
    ) -> ValsetEvaluation[RolloutOutput, DataId]:
        """Evaluate *program* on the policy-selected validation batch.

        Uses the state's evaluation cache, so only uncached instances cost
        actual metric calls; ``state.total_num_evals`` is incremented by the
        number of evaluations that actually ran.
        """
        valset = self.valset
        assert valset is not None

        val_ids = self.val_evaluation_policy.get_eval_batch(valset, state)

        outputs_by_val_idx, scores_by_val_idx, objective_by_val_idx, num_actual_evals = state.cached_evaluate_full(
            program, list(val_ids), valset.fetch, self.evaluator
        )
        state.total_num_evals += num_actual_evals

        return ValsetEvaluation(
            outputs_by_val_id=outputs_by_val_idx,
            scores_by_val_id=scores_by_val_idx,
            objective_scores_by_val_id=objective_by_val_idx,
        )

    def _run_full_eval_and_add(
        self,
        new_program: dict[str, str],
        state: GEPAState[RolloutOutput, DataId],
        parent_program_idx: list[int],
    ) -> tuple[int, int]:
        """Run a full valset evaluation for an accepted candidate and add it to state.

        Returns ``(new_program_idx, linear_pareto_front_program_idx)``.
        """
        # Snapshot the eval budget consumed up to the moment of discovery so
        # the new program's cost can be attributed in the state.
        num_metric_calls_by_discovery = state.total_num_evals
        valset_evaluation = self._evaluate_on_valset(new_program, state)
        state.num_full_ds_evals += 1

        new_program_idx = state.update_state_with_new_program(
            parent_program_idx=parent_program_idx,
            new_program=new_program,
            valset_evaluation=valset_evaluation,
            run_dir=self.run_dir,
            num_metric_calls_by_discovery_of_new_program=num_metric_calls_by_discovery,
        )
        state.full_program_trace[-1]["new_program_idx"] = new_program_idx
        state.full_program_trace[-1]["evaluated_val_indices"] = sorted(valset_evaluation.scores_by_val_id.keys())

        valset_score = self.val_evaluation_policy.get_valset_score(new_program_idx, state)

        linear_pareto_front_program_idx = self.val_evaluation_policy.get_best_program(state)
        if new_program_idx == linear_pareto_front_program_idx:
            self.logger.log(f"Iteration {state.i + 1}: Found a better program on the valset with score {valset_score}.")

        valset = self.valset
        assert valset is not None

        log_detailed_metrics_after_discovering_new_program(
            logger=self.logger,
            gepa_state=state,
            new_program_idx=new_program_idx,
            valset_evaluation=valset_evaluation,
            objective_scores=state.prog_candidate_objective_scores[new_program_idx],
            experiment_tracker=self.experiment_tracker,
            linear_pareto_front_program_idx=linear_pareto_front_program_idx,
            valset_size=len(valset),
            val_evaluation_policy=self.val_evaluation_policy,
        )
        return new_program_idx, linear_pareto_front_program_idx

    def run(self) -> GEPAState[RolloutOutput, DataId]:
        """Execute the optimization loop until a stop condition fires.

        Raises:
            ImportError: if ``display_progress_bar`` is set but tqdm is absent.
            ValueError: if no valset was provided.
        """
        # Check tqdm availability if progress bar is enabled
        progress_bar = None
        if self.display_progress_bar:
            if tqdm is None:
                raise ImportError("tqdm must be installed when display_progress_bar is enabled")

            # Check if stop_callback contains MaxMetricCallsStopper so the bar
            # can display a total; duck-typed via the max_metric_calls attribute.
            total_calls: int | None = None
            stop_cb = self.stop_callback
            if stop_cb is not None:
                max_calls_attr = getattr(stop_cb, "max_metric_calls", None)
                if isinstance(max_calls_attr, int):
                    # Direct MaxMetricCallsStopper
                    total_calls = max_calls_attr
                else:
                    stoppers = getattr(stop_cb, "stoppers", None)
                    if stoppers is not None:
                        # CompositeStopper - iterate to find MaxMetricCallsStopper
                        for stopper in stoppers:
                            stopper_max = getattr(stopper, "max_metric_calls", None)
                            if isinstance(stopper_max, int):
                                total_calls = stopper_max
                                break

            if total_calls is not None:
                progress_bar = tqdm(total=total_calls, desc="GEPA Optimization", unit="rollouts")
            else:
                progress_bar = tqdm(desc="GEPA Optimization", unit="rollouts")
            progress_bar.update(0)

        # Prepare valset
        valset = self.valset
        if valset is None:
            raise ValueError("valset must be provided to GEPAEngine.run()")

        def valset_evaluator(
            program: dict[str, str],
        ) -> ValsetEvaluation[RolloutOutput, DataId]:
            # Full (uncached) evaluation over every validation id; used only
            # for state initialization of the seed candidate.
            all_ids = list(valset.all_ids())
            outputs, scores, objective_scores = self.evaluator(valset.fetch(all_ids), program)
            outputs_dict = dict(zip(all_ids, outputs, strict=False))
            scores_dict = dict(zip(all_ids, scores, strict=False))
            objective_scores_dict = (
                dict(zip(all_ids, objective_scores, strict=False)) if objective_scores is not None else None
            )
            return ValsetEvaluation(
                outputs_by_val_id=outputs_dict,
                scores_by_val_id=scores_dict,
                objective_scores_by_val_id=objective_scores_dict,
            )

        # Initialize state
        state = initialize_gepa_state(
            run_dir=self.run_dir,
            logger=self.logger,
            seed_candidate=self.seed_candidate,
            valset_evaluator=valset_evaluator,
            track_best_outputs=self.track_best_outputs,
            frontier_type=self.frontier_type,
            evaluation_cache=self._initial_evaluation_cache,
        )

        # Log base program score
        base_val_avg, base_val_coverage = state.get_program_average_val_subset(0)
        self.experiment_tracker.log_metrics(
            {
                "base_program_full_valset_score": base_val_avg,
                "base_program_val_coverage": base_val_coverage,
                "iteration": state.i + 1,
            },
            step=state.i + 1,
        )

        self.logger.log(
            f"Iteration {state.i + 1}: Base program full valset score: {base_val_avg} "
            f"over {base_val_coverage} / {len(valset)} examples"
        )

        # Merge scheduling
        if self.merge_proposer is not None:
            self.merge_proposer.last_iter_found_new_program = False

        # Main loop
        last_pbar_val = 0
        while not self._should_stop(state):
            if self.display_progress_bar and progress_bar is not None:
                delta = state.total_num_evals - last_pbar_val
                progress_bar.update(delta)
                last_pbar_val = state.total_num_evals

            assert state.is_consistent()
            try:
                state.save(self.run_dir, use_cloudpickle=self.use_cloudpickle)
                state.i += 1
                state.full_program_trace.append({"i": state.i})

                # 1) Attempt merge first if scheduled and last iter found new program
                if self.merge_proposer is not None and self.merge_proposer.use_merge:
                    if self.merge_proposer.merges_due > 0 and self.merge_proposer.last_iter_found_new_program:
                        proposal = self.merge_proposer.propose(state)
                        self.merge_proposer.last_iter_found_new_program = False  # old behavior

                        if proposal is not None and proposal.tag == "merge":
                            parent_sums = proposal.subsample_scores_before or [
                                float("-inf"),
                                float("-inf"),
                            ]
                            new_sum = sum(proposal.subsample_scores_after or [])

                            if new_sum >= max(parent_sums):
                                # ACCEPTED: consume one merge attempt and record it
                                self._run_full_eval_and_add(
                                    new_program=proposal.candidate,
                                    state=state,
                                    parent_program_idx=proposal.parent_program_ids,
                                )
                                self.merge_proposer.merges_due -= 1
                                self.merge_proposer.total_merges_tested += 1
                                continue  # skip reflective this iteration
                            else:
                                # REJECTED: do NOT consume merges_due or total_merges_tested
                                self.logger.log(
                                    f"Iteration {state.i + 1}: New program subsample score {new_sum} "
                                    f"is worse than both parents {parent_sums}, skipping merge"
                                )
                                # Skip reflective this iteration (old behavior)
                                continue

                        # Old behavior: regardless of whether we attempted, clear the flag before reflective
                        self.merge_proposer.last_iter_found_new_program = False

                # 2) Reflective mutation proposer
                proposal = self.reflective_proposer.propose(state)
                if proposal is None:
                    self.logger.log(f"Iteration {state.i + 1}: Reflective mutation did not propose a new candidate")
                    continue

                # Acceptance: require strict improvement on subsample
                old_sum = sum(proposal.subsample_scores_before or [])
                new_sum = sum(proposal.subsample_scores_after or [])
                if new_sum <= old_sum:
                    self.logger.log(
                        f"Iteration {state.i + 1}: New subsample score {new_sum} is not better than old score {old_sum}, skipping"
                    )
                    continue
                else:
                    self.logger.log(
                        f"Iteration {state.i + 1}: New subsample score {new_sum} is better than old score {old_sum}. Continue to full eval and add to candidate pool."
                    )

                # Accept: full eval + add
                self._run_full_eval_and_add(
                    new_program=proposal.candidate,
                    state=state,
                    parent_program_idx=proposal.parent_program_ids,
                )

                # Schedule merge attempts like original behavior
                if self.merge_proposer is not None:
                    self.merge_proposer.last_iter_found_new_program = True
                    if self.merge_proposer.total_merges_tested < self.merge_proposer.max_merge_invocations:
                        self.merge_proposer.merges_due += 1

            except Exception as e:
                self.logger.log(f"Iteration {state.i + 1}: Exception during optimization: {e}")
                self.logger.log(traceback.format_exc())
                if self.raise_on_exception:
                    # Bare raise preserves the original traceback unchanged.
                    raise

        # Close progress bar if it exists, flushing evals from the last iteration first
        if self.display_progress_bar and progress_bar is not None:
            progress_bar.update(state.total_num_evals - last_pbar_val)
            progress_bar.close()

        # Final checkpoint: honor the configured pickler, matching the in-loop saves.
        state.save(self.run_dir, use_cloudpickle=self.use_cloudpickle)
        return state

    def _should_stop(self, state: GEPAState[RolloutOutput, DataId]) -> bool:
        """Check if the optimization should stop."""
        if self._stop_requested:
            return True
        if self.stop_callback and self.stop_callback(state):
            return True
        return False

    def request_stop(self) -> None:
        """Manually request the optimization to stop gracefully."""
        self.logger.log("Stop requested manually. Initiating graceful shutdown...")
        self._stop_requested = True
@@ -0,0 +1,233 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic
6
+
7
+ from mantisdk.algorithm.gepa.lib.core.adapter import RolloutOutput
8
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
9
+ from mantisdk.algorithm.gepa.lib.core.state import ProgramIdx
10
+
11
+ if TYPE_CHECKING:
12
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class GEPAResult(Generic[RolloutOutput, DataId]):
17
+ """
18
+ Immutable snapshot of a GEPA run with convenience accessors.
19
+
20
+ - candidates: list of proposed candidates (component_name -> component_text)
21
+ - parents: lineage info; for each candidate i, parents[i] is a list of parent indices or None
22
+ - val_aggregate_scores: per-candidate aggregate score on the validation set (higher is better)
23
+ - val_subscores: per-candidate mapping from validation id to score on the validation set (sparse dict)
24
+ - val_aggregate_subscores: optional per-candidate aggregate subscores across objectives
25
+ - per_val_instance_best_candidates: for each val instance t, a set of candidate indices achieving the current best score on t
26
+ - per_objective_best_candidates: optional per-objective set of candidate indices achieving best aggregate subscore
27
+ - discovery_eval_counts: number of metric calls accumulated up to the discovery of each candidate
28
+
29
+ Optional fields:
30
+ - best_outputs_valset: per-task best outputs on the validation set. [task_idx -> [(program_idx_1, output_1), (program_idx_2, output_2), ...]]
31
+
32
+ Run-level metadata:
33
+ - total_metric_calls: total number of metric calls made across the run
34
+ - num_full_val_evals: number of full validation evaluations performed
35
+ - run_dir: where artifacts were written (if any)
36
+ - seed: RNG seed for reproducibility (if known)
37
+ - tracked_scores: optional tracked aggregate scores (if different from val_aggregate_scores)
38
+
39
+ Convenience:
40
+ - best_idx: candidate index with the highest val_aggregate_scores
41
+ - best_candidate: the program text mapping for best_idx
42
+ - non_dominated_indices(): candidate indices that are not dominated across per-instance pareto fronts
43
+ - lineage(idx): parent chain from base to idx
44
+ - diff(parent_idx, child_idx, only_changed=True): component-wise diff between two candidates
45
+ - best_k(k): top-k candidates by aggregate val score
46
+ - instance_winners(t): set of candidates on the pareto front for val instance t
47
+ - to_dict(...), save_json(...): serialization helpers
48
+ """
49
+
50
+ # Core data
51
+ candidates: list[dict[str, str]]
52
+ parents: list[list[ProgramIdx | None]]
53
+ val_aggregate_scores: list[float]
54
+ val_subscores: list[dict[DataId, float]]
55
+ per_val_instance_best_candidates: dict[DataId, set[ProgramIdx]]
56
+ discovery_eval_counts: list[int]
57
+ val_aggregate_subscores: list[dict[str, float]] | None = None
58
+ per_objective_best_candidates: dict[str, set[ProgramIdx]] | None = None
59
+ objective_pareto_front: dict[str, float] | None = None
60
+
61
+ # Optional data
62
+ best_outputs_valset: dict[DataId, list[tuple[ProgramIdx, RolloutOutput]]] | None = None
63
+
64
+ # Run metadata (optional)
65
+ total_metric_calls: int | None = None
66
+ num_full_val_evals: int | None = None
67
+ run_dir: str | None = None
68
+ seed: int | None = None
69
+
70
+ _VALIDATION_SCHEMA_VERSION: ClassVar[int] = 2
71
+
72
+ # -------- Convenience properties --------
73
+ @property
74
+ def num_candidates(self) -> int:
75
+ return len(self.candidates)
76
+
77
+ @property
78
+ def num_val_instances(self) -> int:
79
+ return len(self.per_val_instance_best_candidates)
80
+
81
+ @property
82
+ def best_idx(self) -> int:
83
+ scores = self.val_aggregate_scores
84
+ return max(range(len(scores)), key=lambda i: scores[i])
85
+
86
+ @property
87
+ def best_candidate(self) -> dict[str, str]:
88
+ return self.candidates[self.best_idx]
89
+
90
+ def to_dict(self) -> dict[str, Any]:
91
+ cands = [dict(cand.items()) for cand in self.candidates]
92
+
93
+ return {
94
+ "candidates": cands,
95
+ "parents": self.parents,
96
+ "val_aggregate_scores": self.val_aggregate_scores,
97
+ "val_subscores": self.val_subscores,
98
+ "best_outputs_valset": self.best_outputs_valset,
99
+ "per_val_instance_best_candidates": {
100
+ val_id: list(front) for val_id, front in self.per_val_instance_best_candidates.items()
101
+ },
102
+ "val_aggregate_subscores": self.val_aggregate_subscores,
103
+ "per_objective_best_candidates": (
104
+ {k: list(v) for k, v in self.per_objective_best_candidates.items()}
105
+ if self.per_objective_best_candidates is not None
106
+ else None
107
+ ),
108
+ "objective_pareto_front": self.objective_pareto_front,
109
+ "discovery_eval_counts": self.discovery_eval_counts,
110
+ "total_metric_calls": self.total_metric_calls,
111
+ "num_full_val_evals": self.num_full_val_evals,
112
+ "run_dir": self.run_dir,
113
+ "seed": self.seed,
114
+ "best_idx": self.best_idx,
115
+ "validation_schema_version": GEPAResult._VALIDATION_SCHEMA_VERSION,
116
+ }
117
+
118
+ @staticmethod
119
+ def from_dict(d: dict[str, Any]) -> "GEPAResult[RolloutOutput, DataId]":
120
+ version = d.get("validation_schema_version") or 0
121
+ if version > GEPAResult._VALIDATION_SCHEMA_VERSION:
122
+ raise ValueError(
123
+ f"Unsupported GEPAResult validation schema version {version}; "
124
+ f"max supported is {GEPAResult._VALIDATION_SCHEMA_VERSION}"
125
+ )
126
+
127
+ if version <= 1:
128
+ return GEPAResult._migrate_from_dict_v0(d)
129
+
130
+ return GEPAResult._from_dict_v2(d)
131
+
132
+ @staticmethod
133
+ def _common_kwargs_from_dict(d: dict[str, Any]) -> dict[str, Any]:
134
+ return {
135
+ "candidates": [dict(candidate) for candidate in d.get("candidates", [])],
136
+ "parents": [list(parent_row) for parent_row in d.get("parents", [])],
137
+ "val_aggregate_scores": list(d.get("val_aggregate_scores", [])),
138
+ "discovery_eval_counts": list(d.get("discovery_eval_counts", [])),
139
+ "total_metric_calls": d.get("total_metric_calls"),
140
+ "num_full_val_evals": d.get("num_full_val_evals"),
141
+ "run_dir": d.get("run_dir"),
142
+ "seed": d.get("seed"),
143
+ }
144
+
145
+ @staticmethod
146
+ def _migrate_from_dict_v0(d: dict[str, Any]) -> "GEPAResult[RolloutOutput, DataId]":
147
+ kwargs = GEPAResult._common_kwargs_from_dict(d)
148
+ kwargs["val_subscores"] = [
149
+ {idx: score for idx, score in enumerate(scores)} for scores in d.get("val_subscores", [])
150
+ ]
151
+ kwargs["per_val_instance_best_candidates"] = {
152
+ idx: set(front) for idx, front in enumerate(d.get("per_val_instance_best_candidates", []))
153
+ }
154
+
155
+ best_outputs_valset = d.get("best_outputs_valset")
156
+ if best_outputs_valset is not None:
157
+ kwargs["best_outputs_valset"] = {
158
+ idx: [(program_idx, output) for program_idx, output in outputs]
159
+ for idx, outputs in enumerate(best_outputs_valset)
160
+ }
161
+ else:
162
+ kwargs["best_outputs_valset"] = None
163
+ return GEPAResult(**kwargs)
164
+
165
+ @staticmethod
166
+ def _from_dict_v2(d: dict[str, Any]) -> "GEPAResult[RolloutOutput, DataId]":
167
+ kwargs = GEPAResult._common_kwargs_from_dict(d)
168
+ kwargs["val_subscores"] = [dict(scores) for scores in d.get("val_subscores", [])]
169
+ per_val_instance_best_candidates_data = d.get("per_val_instance_best_candidates", {})
170
+ kwargs["per_val_instance_best_candidates"] = {
171
+ val_id: set(candidates_on_front)
172
+ for val_id, candidates_on_front in per_val_instance_best_candidates_data.items()
173
+ }
174
+
175
+ best_outputs_valset = d.get("best_outputs_valset")
176
+ if best_outputs_valset is not None:
177
+ kwargs["best_outputs_valset"] = {
178
+ val_id: [(program_idx, output) for program_idx, output in outputs]
179
+ for val_id, outputs in best_outputs_valset.items()
180
+ }
181
+ else:
182
+ kwargs["best_outputs_valset"] = None
183
+
184
+ val_aggregate_subscores = d.get("val_aggregate_subscores")
185
+ kwargs["val_aggregate_subscores"] = (
186
+ [dict(scores) for scores in val_aggregate_subscores] if val_aggregate_subscores is not None else None
187
+ )
188
+
189
+ per_objective_best_candidates = d.get("per_objective_best_candidates")
190
+ if per_objective_best_candidates is not None:
191
+ kwargs["per_objective_best_candidates"] = {
192
+ objective: set(program_indices) for objective, program_indices in per_objective_best_candidates.items()
193
+ }
194
+ else:
195
+ kwargs["per_objective_best_candidates"] = None
196
+
197
+ objective_pareto_front = d.get("objective_pareto_front")
198
+ kwargs["objective_pareto_front"] = dict(objective_pareto_front) if objective_pareto_front is not None else None
199
+
200
+ return GEPAResult(**kwargs)
201
+
202
+ @staticmethod
203
+ def from_state(
204
+ state: "GEPAState[RolloutOutput, DataId]",
205
+ run_dir: str | None = None,
206
+ seed: int | None = None,
207
+ ) -> "GEPAResult[RolloutOutput, DataId]":
208
+ """Build a GEPAResult from a GEPAState."""
209
+ objective_scores_list = [dict(scores) for scores in state.prog_candidate_objective_scores]
210
+ has_objective_scores = any(obj for obj in objective_scores_list)
211
+ per_objective_best = {
212
+ objective: set(front) for objective, front in state.program_at_pareto_front_objectives.items()
213
+ }
214
+ objective_front = dict(state.objective_pareto_front)
215
+
216
+ return GEPAResult(
217
+ candidates=list(state.program_candidates),
218
+ parents=list(state.parent_program_for_candidate),
219
+ val_aggregate_scores=list(state.program_full_scores_val_set),
220
+ best_outputs_valset=getattr(state, "best_outputs_valset", None),
221
+ val_subscores=[dict(scores) for scores in state.prog_candidate_val_subscores],
222
+ per_val_instance_best_candidates={
223
+ val_id: set(front) for val_id, front in state.program_at_pareto_front_valset.items()
224
+ },
225
+ val_aggregate_subscores=(objective_scores_list if has_objective_scores else None),
226
+ per_objective_best_candidates=(per_objective_best if per_objective_best else None),
227
+ objective_pareto_front=objective_front if objective_front else None,
228
+ discovery_eval_counts=list(state.num_metric_calls_by_discovery),
229
+ total_metric_calls=getattr(state, "total_num_evals", None),
230
+ num_full_val_evals=getattr(state, "num_full_ds_evals", None),
231
+ run_dir=run_dir,
232
+ seed=seed,
233
+ )