mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,636 @@
+# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+# https://github.com/gepa-ai/gepa
+
+import hashlib
+import json
+import os
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass, field
+from typing import Any, ClassVar, Generic, Literal, TypeAlias
+
+from mantisdk.algorithm.gepa.lib.core.adapter import RolloutOutput
+from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
+from mantisdk.algorithm.gepa.lib.gepa_utils import json_default
+from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol
+
+# Types for GEPAState
+ProgramIdx = int
+
+# Type aliases
+ObjectiveScores: TypeAlias = dict[str, float]
+FrontierType: TypeAlias = Literal["instance", "objective", "hybrid", "cartesian"]
+"""Strategy for tracking Pareto frontiers: 'instance' (per validation example), 'objective' (per objective metric), 'hybrid' (both), or 'cartesian' (per example × objective)."""
+FrontierKey: TypeAlias = DataId | str | tuple[str, DataId] | tuple[str, DataId, str]
+"""Key type for frontier mappings depending on frontier_type."""
+
+CandidateHash: TypeAlias = str
+CacheKey: TypeAlias = tuple[CandidateHash, DataId]
+
+
+def _candidate_hash(candidate: dict[str, str]) -> CandidateHash:
+    """Compute a deterministic hash of a candidate dictionary."""
+    return hashlib.sha256(json.dumps(sorted(candidate.items())).encode()).hexdigest()
+
+
+@dataclass
+class CachedEvaluation(Generic[RolloutOutput]):
+    """Cached evaluation result for a (candidate, example) pair."""
+
+    output: RolloutOutput
+    score: float
+    objective_scores: ObjectiveScores | None
+
+
+@dataclass
+class EvaluationCache(Generic[RolloutOutput, DataId]):
+    """Cache for storing evaluation results of (candidate, example) pairs."""
+
+    _cache: dict[CacheKey, CachedEvaluation[RolloutOutput]] = field(default_factory=dict)
+
+    def get(self, candidate: dict[str, str], example_id: DataId) -> CachedEvaluation[RolloutOutput] | None:
+        """Retrieve cached evaluation result if it exists."""
+        return self._cache.get((_candidate_hash(candidate), example_id))
+
+    def put(
+        self,
+        candidate: dict[str, str],
+        example_id: DataId,
+        output: RolloutOutput,
+        score: float,
+        objective_scores: ObjectiveScores | None = None,
+    ) -> None:
+        """Store an evaluation result in the cache."""
+        self._cache[(_candidate_hash(candidate), example_id)] = CachedEvaluation(output, score, objective_scores)
+
+    def get_batch(
+        self, candidate: dict[str, str], example_ids: list[DataId]
+    ) -> tuple[dict[DataId, CachedEvaluation[RolloutOutput]], list[DataId]]:
+        """Look up cached results for a batch. Returns (cached_results, uncached_ids)."""
+        h = _candidate_hash(candidate)
+        cached, uncached = {}, []
+        for eid in example_ids:
+            if entry := self._cache.get((h, eid)):
+                cached[eid] = entry
+            else:
+                uncached.append(eid)
+        return cached, uncached
+
+    def put_batch(
+        self,
+        candidate: dict[str, str],
+        example_ids: list[DataId],
+        outputs: list[RolloutOutput],
+        scores: list[float],
+        objective_scores_list: Sequence[ObjectiveScores] | None = None,
+    ) -> None:
+        """Store evaluation results for a batch of examples."""
+        h = _candidate_hash(candidate)
+        for i, eid in enumerate(example_ids):
+            self._cache[(h, eid)] = CachedEvaluation(
+                outputs[i], scores[i], objective_scores_list[i] if objective_scores_list else None
+            )
+
+    def evaluate_with_cache_full(
+        self,
+        candidate: dict[str, str],
+        example_ids: list[DataId],
+        fetcher: Callable[[list[DataId]], Any],
+        evaluator: Callable[[Any, dict[str, str]], tuple[Any, list[float], Sequence[ObjectiveScores] | None]],
+    ) -> tuple[dict[DataId, RolloutOutput], dict[DataId, float], dict[DataId, ObjectiveScores] | None, int]:
+        """
+        Evaluate using cache, returning full results.
+
+        Returns (outputs_by_id, scores_by_id, objective_scores_by_id, num_actual_evals).
+        """
+        cached, uncached_ids = self.get_batch(candidate, example_ids)
+
+        outputs_by_id: dict[DataId, RolloutOutput] = {eid: c.output for eid, c in cached.items()}
+        scores_by_id: dict[DataId, float] = {eid: c.score for eid, c in cached.items()}
+        objective_by_id: dict[DataId, ObjectiveScores] | None = None
+
+        # Populate objective scores from cache
+        for eid, c in cached.items():
+            if c.objective_scores is not None:
+                objective_by_id = objective_by_id or {}
+                objective_by_id[eid] = c.objective_scores
+
+        # Evaluate uncached examples
+        if uncached_ids:
+            batch = fetcher(uncached_ids)
+            outputs, scores, obj_scores = evaluator(batch, candidate)
+            for idx, eid in enumerate(uncached_ids):
+                outputs_by_id[eid] = outputs[idx]
+                scores_by_id[eid] = scores[idx]
+                if obj_scores is not None:
+                    objective_by_id = objective_by_id or {}
+                    objective_by_id[eid] = obj_scores[idx]
+            self.put_batch(candidate, uncached_ids, outputs, scores, obj_scores)
+
+        return outputs_by_id, scores_by_id, objective_by_id, len(uncached_ids)
+
+
+@dataclass(slots=True)
+class ValsetEvaluation(Generic[RolloutOutput, DataId]):
+    """Container for evaluation results on a validation set batch."""
+
+    outputs_by_val_id: dict[DataId, RolloutOutput]
+    scores_by_val_id: dict[DataId, float]
+    objective_scores_by_val_id: dict[DataId, ObjectiveScores] | None = None
+
+
+class GEPAState(Generic[RolloutOutput, DataId]):
+    """Persistent optimizer state tracking candidates, sparse validation coverage, and objective frontiers."""
+
+    _VALIDATION_SCHEMA_VERSION: ClassVar[int] = 4
+
+    program_candidates: list[dict[str, str]]
+    parent_program_for_candidate: list[list[ProgramIdx | None]]
+    prog_candidate_val_subscores: list[dict[DataId, float]]
+    prog_candidate_objective_scores: list[ObjectiveScores]
+
+    pareto_front_valset: dict[DataId, float]
+    program_at_pareto_front_valset: dict[DataId, set[ProgramIdx]]
+    objective_pareto_front: ObjectiveScores
+    program_at_pareto_front_objectives: dict[str, set[ProgramIdx]]
+    pareto_front_cartesian: dict[tuple[DataId, str], float]
+    program_at_pareto_front_cartesian: dict[tuple[DataId, str], set[ProgramIdx]]
+
+    list_of_named_predictors: list[str]
+    named_predictor_id_to_update_next_for_program_candidate: list[int]
+
+    i: int
+    num_full_ds_evals: int
+
+    total_num_evals: int
+
+    num_metric_calls_by_discovery: list[int]
+
+    full_program_trace: list[dict[str, Any]]
+    best_outputs_valset: dict[DataId, list[tuple[ProgramIdx, RolloutOutput]]] | None
+
+    validation_schema_version: int
+
+    # Optional evaluation cache for (candidate, example) pairs
+    evaluation_cache: "EvaluationCache[RolloutOutput, DataId] | None"
+
+    def __init__(
+        self,
+        seed_candidate: dict[str, str],
+        base_evaluation: ValsetEvaluation[RolloutOutput, DataId],
+        track_best_outputs: bool = False,
+        frontier_type: FrontierType = "instance",
+        evaluation_cache: "EvaluationCache[RolloutOutput, DataId] | None" = None,
+    ):
+        self.program_candidates = [dict(seed_candidate)]
+        self.prog_candidate_val_subscores = [dict(base_evaluation.scores_by_val_id)]
+
+        base_objective_aggregates = self._aggregate_objective_scores(base_evaluation.objective_scores_by_val_id)
+        self.prog_candidate_objective_scores = [base_objective_aggregates]
+
+        self.parent_program_for_candidate = [[None]]
+
+        self.frontier_type: FrontierType = frontier_type
+        self.pareto_front_valset = {val_id: score for val_id, score in base_evaluation.scores_by_val_id.items()}
+        self.program_at_pareto_front_valset = {val_id: {0} for val_id in base_evaluation.scores_by_val_id.keys()}
+        self.objective_pareto_front = dict(base_objective_aggregates)
+        self.program_at_pareto_front_objectives = {objective: {0} for objective in base_objective_aggregates.keys()}
+
+        # Validate that objective scores are provided for frontier types that require them
+        if frontier_type in ("objective", "hybrid", "cartesian"):
+            if not base_evaluation.objective_scores_by_val_id:
+                raise ValueError(
+                    f"frontier_type='{frontier_type}' requires objective_scores to be provided by the evaluator, "
+                    f"but none were found. Use an evaluator that returns objective_scores or use frontier_type='instance'."
+                )
+
+        # Cartesian frontier will be base_evaluation.objective_scores_by_val_id
+        if frontier_type == "cartesian":
+            assert base_evaluation.objective_scores_by_val_id is not None  # Already validated above
+            self.pareto_front_cartesian = {
+                (val_id, objective): objective_score
+                for val_id, objective_scores in base_evaluation.objective_scores_by_val_id.items()
+                for objective, objective_score in objective_scores.items()
+            }
+            self.program_at_pareto_front_cartesian = {
+                (val_id, objective): {0}
+                for val_id, objective_scores in base_evaluation.objective_scores_by_val_id.items()
+                for objective in objective_scores.keys()
+            }
+        else:
+            self.pareto_front_cartesian = {}
+            self.program_at_pareto_front_cartesian = {}
+
+        self.list_of_named_predictors = list(seed_candidate.keys())
+        self.named_predictor_id_to_update_next_for_program_candidate = [0]
+        self.i = -1
+
+        self.num_metric_calls_by_discovery = [0]
+
+        if track_best_outputs:
+            self.best_outputs_valset = {
+                val_id: [(0, output)] for val_id, output in base_evaluation.outputs_by_val_id.items()
+            }
+        else:
+            self.best_outputs_valset = None
+
+        self.full_program_trace = []
+        self.validation_schema_version = self._VALIDATION_SCHEMA_VERSION
+        self.evaluation_cache = evaluation_cache
+
+    def is_consistent(self) -> bool:
+        assert len(self.program_candidates) == len(self.parent_program_for_candidate)
+        assert len(self.program_candidates) == len(self.named_predictor_id_to_update_next_for_program_candidate)
+        assert len(self.program_candidates) == len(self.prog_candidate_val_subscores)
+        assert len(self.program_candidates) == len(self.prog_candidate_objective_scores)
+        assert len(self.program_candidates) == len(self.num_metric_calls_by_discovery)
+
+        assert len(self.pareto_front_valset) == len(self.program_at_pareto_front_valset)
+        assert set(self.pareto_front_valset.keys()) == set(self.program_at_pareto_front_valset.keys())
+        assert set(self.objective_pareto_front.keys()) == set(self.program_at_pareto_front_objectives.keys())
+
+        for front in self.program_at_pareto_front_valset.values():
+            for prog_idx in front:
+                assert prog_idx < len(self.program_candidates), (
+                    "Program index in valset pareto front exceeds number of program candidates"
+                )
+
+        return True
+
+    def save(self, run_dir: str | None, *, use_cloudpickle: bool = False) -> None:
+        if run_dir is None:
+            return
+        with open(os.path.join(run_dir, "gepa_state.bin"), "wb") as f:
+            if use_cloudpickle:
+                import cloudpickle as pickle  # type: ignore[import-not-found]
+            else:
+                import pickle
+            serialized = dict(self.__dict__.items())
+            serialized["validation_schema_version"] = GEPAState._VALIDATION_SCHEMA_VERSION
+            pickle.dump(serialized, f)
+
+    @staticmethod
+    def load(run_dir: str) -> "GEPAState[RolloutOutput, DataId]":
+        with open(os.path.join(run_dir, "gepa_state.bin"), "rb") as f:
+            import pickle
+
+            data = pickle.load(f)
+
+        # handle schema migration
+        version = data.get("validation_schema_version")
+        if version is None or version < 2:
+            GEPAState._migrate_from_legacy_state_v0(data)
+            version = data.get("validation_schema_version")
+        if version is None or version < GEPAState._VALIDATION_SCHEMA_VERSION:
+            GEPAState._upgrade_state_dict(data)
+
+        state = GEPAState.__new__(GEPAState)
+        state.__dict__.update(data)
+
+        state.validation_schema_version = GEPAState._VALIDATION_SCHEMA_VERSION
+        assert len(state.program_candidates) == len(state.prog_candidate_val_subscores)
+        assert len(state.program_candidates) == len(state.prog_candidate_objective_scores)
+        assert len(state.program_candidates) == len(state.num_metric_calls_by_discovery)
+        assert len(state.program_candidates) == len(state.parent_program_for_candidate)
+        assert len(state.program_candidates) == len(state.named_predictor_id_to_update_next_for_program_candidate)
+        assert len(state.pareto_front_valset) == len(state.program_at_pareto_front_valset)
+        assert set(state.pareto_front_valset.keys()) == set(state.program_at_pareto_front_valset.keys())
+        assert set(state.objective_pareto_front.keys()) == set(state.program_at_pareto_front_objectives.keys())
+        return state
+
+    @staticmethod
+    def _migrate_from_legacy_state_v0(d: dict[str, Any]) -> None:
+        assert isinstance(d, dict)
+        assert "prog_candidate_val_subscores" in d
+        assert isinstance(d["prog_candidate_val_subscores"], list)
+        assert all(isinstance(scores, list) for scores in d["prog_candidate_val_subscores"])
+        legacy_scores: list[list[float]] = d.pop("prog_candidate_val_subscores", [])
+        d["prog_candidate_val_subscores"] = [
+            {idx: score for idx, score in enumerate(scores)} for scores in legacy_scores
+        ]
+
+        pareto_front = d.get("pareto_front_valset")
+        if isinstance(pareto_front, list):
+            d["pareto_front_valset"] = {idx: score for idx, score in enumerate(pareto_front)}
+
+        program_at_front = d.get("program_at_pareto_front_valset")
+        if isinstance(program_at_front, list):
+            d["program_at_pareto_front_valset"] = {idx: set(front) for idx, front in enumerate(program_at_front)}
+
+        best_outputs = d.get("best_outputs_valset")
+        if isinstance(best_outputs, list):
+            d["best_outputs_valset"] = {idx: list(outputs) for idx, outputs in enumerate(best_outputs)}
+
+        d["validation_schema_version"] = 2
+
+    @staticmethod
+    def _upgrade_state_dict(d: dict[str, Any]) -> None:
+        num_candidates = len(d.get("program_candidates", []))
+        if "prog_candidate_objective_scores" not in d:
+            d["prog_candidate_objective_scores"] = [{} for _ in range(num_candidates)]
+        if "objective_pareto_front" not in d:
+            d["objective_pareto_front"] = {}
+        if "program_at_pareto_front_objectives" not in d:
+            d["program_at_pareto_front_objectives"] = {}
+        if "frontier_type" not in d:
+            d["frontier_type"] = "instance"
+            # Since frontier_type instance does not require "pareto_front_cartesian" and "program_at_pareto_front_cartesian", we can safely set them to empty dicts.
+            d["pareto_front_cartesian"] = {}
+            d["program_at_pareto_front_cartesian"] = {}
+        # evaluation_cache is not persisted across runs by default; initialize to None if missing
+        if "evaluation_cache" not in d:
+            d["evaluation_cache"] = None
+        d["validation_schema_version"] = GEPAState._VALIDATION_SCHEMA_VERSION
+
+    @staticmethod
+    def _aggregate_objective_scores(
+        val_objective_scores: dict[DataId, ObjectiveScores] | None,
+    ) -> ObjectiveScores:
+        if not val_objective_scores:
+            return {}
+        totals: dict[str, float] = {}
+        counts: dict[str, int] = {}
+        for objective_dict in val_objective_scores.values():
+            for objective, score in objective_dict.items():
+                totals[objective] = totals.get(objective, 0.0) + score
+                counts[objective] = counts.get(objective, 0) + 1
+        return {
+            objective: totals[objective] / counts[objective] for objective in totals.keys() if counts[objective] > 0
+        }
+
+    def get_program_average_val_subset(self, program_idx: int) -> tuple[float, int]:
+        # TODO: This should be only used/handled by the val_evaluation_policy, and never used directly.
+        scores = self.prog_candidate_val_subscores[program_idx]
+        if not scores:
+            return float("-inf"), 0
+        num_samples = len(scores)
+        avg = sum(scores.values()) / num_samples
+        return avg, num_samples
+
+    @property
+    def valset_evaluations(self) -> dict[DataId, list[ProgramIdx]]:
+        """
+        Valset examples by id and programs that have evaluated them. Keys include only validation
+        ids that have been scored at least once.
+        """
+        result: dict[DataId, list[ProgramIdx]] = defaultdict(list)
+        for program_idx, val_scores in enumerate(self.prog_candidate_val_subscores):
+            for val_id in val_scores.keys():
+                result[val_id].append(program_idx)
+        return result
+
+    @property
+    def program_full_scores_val_set(self) -> list[float]:
+        # TODO: This should be using the val_evaluation_policy instead of the get_program_average_val_subset method to calculate the scores.
+        return [
+            self.get_program_average_val_subset(program_idx)[0]
+            for program_idx in range(len(self.prog_candidate_val_subscores))
+        ]
+
+    @property
+    def per_program_tracked_scores(self) -> list[float]:
+        return [
+            self.get_program_average_val_subset(program_idx)[0]
+            for program_idx in range(len(self.prog_candidate_val_subscores))
+        ]
+
+    def _update_objective_pareto_front(self, objective_scores: ObjectiveScores, program_idx: ProgramIdx) -> None:
+        if not objective_scores:
+            return
+        for objective, score in objective_scores.items():
+            prev_score = self.objective_pareto_front.get(objective, float("-inf"))
+            if score > prev_score:
+                self.objective_pareto_front[objective] = score
+                self.program_at_pareto_front_objectives[objective] = {program_idx}
+            elif score == prev_score:
+                front = self.program_at_pareto_front_objectives.setdefault(objective, set())
+                front.add(program_idx)
+
+    def _update_pareto_front_for_val_id(
+        self,
+        val_id: DataId,
+        score: float,
+        program_idx: ProgramIdx,
+        output: RolloutOutput | None,
+        run_dir: str | None,
+        iteration: int,
+    ) -> None:
+        prev_score = self.pareto_front_valset.get(val_id, float("-inf"))
+        if score > prev_score:
+            self.pareto_front_valset[val_id] = score
+            self.program_at_pareto_front_valset[val_id] = {program_idx}
+            if self.best_outputs_valset is not None and output is not None:
+                self.best_outputs_valset[val_id] = [(program_idx, output)]
+            if run_dir is not None:
+                task_dir = os.path.join(run_dir, "generated_best_outputs_valset", f"task_{val_id}")
+                os.makedirs(task_dir, exist_ok=True)
+                with open(os.path.join(task_dir, f"iter_{iteration}_prog_{program_idx}.json"), "w") as fout:
+                    json.dump(output, fout, indent=4, default=json_default)
+        elif score == prev_score:
+            pareto_front = self.program_at_pareto_front_valset.setdefault(val_id, set())
+            pareto_front.add(program_idx)
+            if self.best_outputs_valset is not None and output is not None:
+                self.best_outputs_valset[val_id].append((program_idx, output))
+
+    def _update_pareto_front_for_cartesian(
+        self,
+        val_id: DataId,
+        objective: str,
+        objective_score: float,
+        program_idx: ProgramIdx,
+    ) -> None:
+        prev_score = self.pareto_front_cartesian.get((val_id, objective), float("-inf"))
+        if objective_score > prev_score:
+            self.pareto_front_cartesian[(val_id, objective)] = objective_score
+            self.program_at_pareto_front_cartesian[(val_id, objective)] = {program_idx}
+        elif objective_score == prev_score:
+            front = self.program_at_pareto_front_cartesian.setdefault((val_id, objective), set())
+            front.add(program_idx)
+
+    def update_state_with_new_program(
+        self,
+        parent_program_idx: list[ProgramIdx],
+        new_program: dict[str, str],
+        valset_evaluation: ValsetEvaluation,
+        run_dir: str | None,
+        num_metric_calls_by_discovery_of_new_program: int,
+    ) -> ProgramIdx:
+        new_program_idx = len(self.program_candidates)
+        self.program_candidates.append(dict(new_program))
+        self.num_metric_calls_by_discovery.append(num_metric_calls_by_discovery_of_new_program)
+
+        max_predictor_id = max(
+            [self.named_predictor_id_to_update_next_for_program_candidate[p] for p in parent_program_idx],
+            default=0,
+        )
+        self.named_predictor_id_to_update_next_for_program_candidate.append(max_predictor_id)
+        self.parent_program_for_candidate.append(list(parent_program_idx))
+
+        valset_scores = dict(valset_evaluation.scores_by_val_id)
+        self.prog_candidate_val_subscores.append(valset_scores)
+        objective_scores = self._aggregate_objective_scores(valset_evaluation.objective_scores_by_val_id)
+        self.prog_candidate_objective_scores.append(objective_scores)
+
+        for val_id, score in valset_scores.items():
+            output = valset_evaluation.outputs_by_val_id.get(val_id) if valset_evaluation.outputs_by_val_id else None
+            self._update_pareto_front_for_val_id(
+                val_id,
+                score,
+                new_program_idx,
+                output,
+                run_dir,
+                self.i + 1,
+            )
+
+        self._update_objective_pareto_front(objective_scores, new_program_idx)
+
+        if self.frontier_type in ("objective", "hybrid", "cartesian"):
+            if not valset_evaluation.objective_scores_by_val_id:
+                raise ValueError(
+                    f"frontier_type='{self.frontier_type}' requires objective_scores to be provided by the evaluator, "
+                    f"but none were found in the evaluation result."
+                )
+
+        if self.frontier_type == "cartesian":
+            assert valset_evaluation.objective_scores_by_val_id is not None  # Validated above
+            for val_id, objective_scores in valset_evaluation.objective_scores_by_val_id.items():
+                for objective, objective_score in objective_scores.items():
+                    self._update_pareto_front_for_cartesian(
+                        val_id,
+                        objective,
+                        objective_score,
+                        new_program_idx,
+                    )
+
+        return new_program_idx
+
+    def _get_pareto_front_mapping(self, frontier_type: FrontierType) -> dict[FrontierKey, set[ProgramIdx]]:
+        if frontier_type == "instance":
+            return {val_id: set(front) for val_id, front in self.program_at_pareto_front_valset.items()}
+        if frontier_type == "objective":
+            return {objective: set(front) for objective, front in self.program_at_pareto_front_objectives.items()}
+        if frontier_type == "hybrid":
+            combined: dict[FrontierKey, set[ProgramIdx]] = {
+                ("val_id", val_id): set(front) for val_id, front in self.program_at_pareto_front_valset.items()
+            }
+            for objective, front in self.program_at_pareto_front_objectives.items():
+                combined[("objective", objective)] = set(front)
+            return combined
+        if frontier_type == "cartesian":
+            return {
+                ("cartesian", val_id, objective): set(front)
+                for (val_id, objective), front in self.program_at_pareto_front_cartesian.items()
+            }
+        raise ValueError(f"Unknown frontier_type: {frontier_type}")
+
+    def get_pareto_front_mapping(self) -> dict[FrontierKey, set[ProgramIdx]]:
+        """Return frontier key to best-program-indices mapping based on configured frontier_type."""
+        return self._get_pareto_front_mapping(self.frontier_type)
+
+    def cached_evaluate(
+        self,
+        candidate: dict[str, str],
+        example_ids: list[DataId],
+        fetcher: Callable[[list[DataId]], Any],
+        evaluator: Callable[[Any, dict[str, str]], tuple[Any, list[float], Sequence[ObjectiveScores] | None]],
+    ) -> tuple[list[float], int]:
+        """Evaluate with optional caching. Returns (scores, num_actual_evals)."""
+        _, scores_by_id, _, num_actual_evals = self.cached_evaluate_full(candidate, example_ids, fetcher, evaluator)
+        return [scores_by_id[eid] for eid in example_ids], num_actual_evals

+    def cached_evaluate_full(
+        self,
+        candidate: dict[str, str],
+        example_ids: list[DataId],
+        fetcher: Callable[[list[DataId]], Any],
+        evaluator: Callable[[Any, dict[str, str]], tuple[Any, list[float], Sequence[ObjectiveScores] | None]],
+    ) -> tuple[dict[DataId, RolloutOutput], dict[DataId, float], dict[DataId, ObjectiveScores] | None, int]:
+        """Evaluate with optional caching, returning full results."""
+        if self.evaluation_cache is not None:
+            return self.evaluation_cache.evaluate_with_cache_full(candidate, example_ids, fetcher, evaluator)
+        batch = fetcher(example_ids)
+        outputs, scores, objective_scores = evaluator(batch, candidate)
+        outputs_by_id = dict(zip(example_ids, outputs, strict=False))
+        scores_by_id = dict(zip(example_ids, scores, strict=False))
+        objective_by_id = dict(zip(example_ids, objective_scores, strict=False)) if objective_scores else None
+        return outputs_by_id, scores_by_id, objective_by_id, len(example_ids)
+
+
+def write_eval_scores_to_directory(scores: dict[DataId, float], output_dir: str) -> None:
+    for val_id, score in scores.items():
+        task_dir = os.path.join(output_dir, f"task_{val_id}")
+        os.makedirs(task_dir, exist_ok=True)
+        with open(os.path.join(task_dir, f"iter_{0}_prog_0.json"), "w") as f:
+            json.dump(score, f, indent=4, default=json_default)
+
+
+def write_eval_outputs_to_directory(outputs, output_dir: str) -> None:
+    """
+    Write generated rollout outputs (not scalar scores) to disk.
+
+    Structure:
+        {output_dir}/task_{val_id}/iter_0_prog_0.json
+
+    This directory is used to store best outputs for inspection/reuse.
+    """
+    for val_id, output in outputs.items():
+        task_dir = os.path.join(output_dir, f"task_{val_id}")
+        os.makedirs(task_dir, exist_ok=True)
+        with open(os.path.join(task_dir, "iter_0_prog_0.json"), "w") as f:
+            json.dump(output, f, indent=4, default=json_default)
+
+
+def initialize_gepa_state(
+    run_dir: str | None,
+    logger: LoggerProtocol,
+    seed_candidate: dict[str, str],
+    valset_evaluator: Callable[
+        [dict[str, str]],
+        ValsetEvaluation[RolloutOutput, DataId],
+    ],
+    track_best_outputs: bool = False,
+    frontier_type: FrontierType = "instance",
+    evaluation_cache: "EvaluationCache[RolloutOutput, DataId] | None" = None,
+) -> GEPAState[RolloutOutput, DataId]:
+    if run_dir is not None and os.path.exists(os.path.join(run_dir, "gepa_state.bin")):
+        logger.log("Loading gepa state from run dir")
+        gepa_state = GEPAState.load(run_dir)
+        if gepa_state.frontier_type != frontier_type:
+            raise ValueError(
+                f"Frontier type mismatch: requested '{frontier_type}' but loaded state has '{gepa_state.frontier_type}'. "
+                f"Use a different run_dir or match the frontier_type parameter."
+            )
+        # Sync cache with current run's cache_evaluation setting:
+        # - If caching is disabled (evaluation_cache is None), clear any loaded cache
+        #   to respect the current run's cache_evaluation=False setting
+        # - If caching is enabled and the loaded state has a cache, preserve it
+        #   (allows resuming with cached results from previous run)
+        # - If caching is enabled but no cache exists in loaded state, use the new empty cache
+        if evaluation_cache is None:
+            gepa_state.evaluation_cache = None
+        elif gepa_state.evaluation_cache is None:
+            gepa_state.evaluation_cache = evaluation_cache
+        # else: keep the loaded cache (gepa_state.evaluation_cache is already set)
+    else:
+        num_evals_run = 0
+
+        eval_result = valset_evaluator(seed_candidate)
+        if run_dir is not None:
+            write_eval_outputs_to_directory(
+                eval_result.outputs_by_val_id, os.path.join(run_dir, "generated_best_outputs_valset")
+            )
+
+        num_evals_run += len(eval_result.scores_by_val_id)
+
+        gepa_state = GEPAState(
+            seed_candidate,
+            eval_result,
+            track_best_outputs=track_best_outputs,
+            frontier_type=frontier_type,
+            evaluation_cache=evaluation_cache,
+        )
+
+        gepa_state.num_full_ds_evals = 1
+        gepa_state.total_num_evals = num_evals_run
+
+    return gepa_state
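
The state module above combines an optional EvaluationCache with several Pareto-frontier modes ('instance', 'objective', 'hybrid', 'cartesian'). The sketch below, which is not part of the published wheel, shows how these pieces might be wired together; only the import path and call signatures come from the diff, while the toy fetcher, evaluator, logger, and example data are hypothetical stand-ins.

# Illustrative usage sketch (assumes mantisdk 0.1.0 is installed).
from mantisdk.algorithm.gepa.lib.core.state import (
    EvaluationCache,
    ValsetEvaluation,
    initialize_gepa_state,
)

# Hypothetical fetcher/evaluator pair matching the Callable signatures expected by
# EvaluationCache.evaluate_with_cache_full: the fetcher materializes a batch from
# example ids; the evaluator returns (outputs, scalar scores, per-objective scores).
EXAMPLES = {0: "2 + 2", 1: "3 * 3"}

def fetcher(example_ids):
    return [EXAMPLES[eid] for eid in example_ids]

def evaluator(batch, candidate):
    outputs = [f"{candidate['system_prompt']} -> {item}" for item in batch]
    scores = [1.0 for _ in batch]  # pretend every rollout is correct
    objective_scores = [{"accuracy": 1.0, "brevity": 0.5} for _ in batch]
    return outputs, scores, objective_scores

class PrintLogger:
    def log(self, message):
        print(message)

cache = EvaluationCache()

def valset_evaluator(candidate):
    outputs_by_id, scores_by_id, objective_by_id, _ = cache.evaluate_with_cache_full(
        candidate, list(EXAMPLES), fetcher, evaluator
    )
    return ValsetEvaluation(outputs_by_id, scores_by_id, objective_by_id)

state = initialize_gepa_state(
    run_dir=None,  # run_dir=None skips loading/saving state from disk
    logger=PrintLogger(),
    seed_candidate={"system_prompt": "Solve the problem step by step."},
    valset_evaluator=valset_evaluator,
    frontier_type="hybrid",  # track both per-example and per-objective fronts
    evaluation_cache=cache,
)
# Hybrid mode yields keys like ('val_id', 0) and ('objective', 'accuracy'),
# each mapped to the set of program indices on that frontier (here just {0}).
print(state.get_pareto_front_mapping())
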
File without changes
@@ -0,0 +1,24 @@
+# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+# https://github.com/gepa-ai/gepa
+
+
+def init_dataset():
+    import random
+
+    from datasets import load_dataset
+
+    train_split = [
+        {"input": x["problem"], "additional_context": {"solution": x["solution"]}, "answer": "### " + str(x["answer"])}
+        for x in load_dataset("AI-MO/aimo-validation-aime")["train"]
+    ]
+    random.Random(0).shuffle(train_split)
+    test_split = [
+        {"input": x["problem"], "answer": "### " + str(x["answer"])}
+        for x in load_dataset("MathArena/aime_2025")["train"]
+    ]
+
+    trainset = train_split[: len(train_split) // 2]
+    valset = train_split[len(train_split) // 2 :]
+    testset = test_split * 5
+
+    return trainset, valset, testset
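
This example builds AIME train/validation splits from AI-MO/aimo-validation-aime and a test split from MathArena/aime_2025 (repeated five times). A hedged sketch of consuming it, assuming the wheel and the datasets dependency are installed and the Hugging Face hub is reachable; the import path follows the file listing above:

# Illustrative only; not part of the published wheel.
from mantisdk.algorithm.gepa.lib.examples.aime import init_dataset

trainset, valset, testset = init_dataset()

# Each record is a plain dict; answers are prefixed with "### " so a downstream
# scorer can extract them with a simple string match.
print(len(trainset), len(valset), len(testset))
print(trainset[0]["input"][:80], trainset[0]["answer"])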