mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+
5
+ from mantisdk.algorithm.gepa.lib.core.adapter import DataInst
6
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
7
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState, ValsetEvaluation
8
+ from mantisdk.algorithm.gepa.lib.strategies.eval_policy import EvaluationPolicy
9
+
10
+
11
def log_detailed_metrics_after_discovering_new_program(
    logger,
    gepa_state: GEPAState,
    new_program_idx,
    valset_evaluation: ValsetEvaluation,
    objective_scores,
    experiment_tracker,
    linear_pareto_front_program_idx,
    valset_size: int,
    val_evaluation_policy: EvaluationPolicy[DataId, DataInst],
    log_individual_valset_scores_and_programs: bool = False,
):
    """Log per-iteration diagnostics after a new candidate program is discovered.

    Emits human-readable lines through ``logger.log`` and a structured metrics
    dict through ``experiment_tracker.log_metrics`` (with ``step`` set to the
    1-based iteration number).

    Args:
        logger: Object exposing ``log(str)``.
        gepa_state: Optimizer state; pareto fronts and per-program scores are read.
        new_program_idx: Index of the newly discovered candidate program.
        valset_evaluation: Validation-subset evaluation of the new program.
        objective_scores: Optional aggregate per-objective scores for the new program.
        experiment_tracker: Object exposing ``log_metrics(metrics, step=...)``.
        linear_pareto_front_program_idx: Program index on the linear pareto front.
        valset_size: Total number of validation examples.
        val_evaluation_policy: Policy used to score/rank programs on the valset.
        log_individual_valset_scores_and_programs: When True, also emit the
            per-example score maps (these can be large).
    """
    iteration = gepa_state.i + 1  # 1-based iteration number used in every message

    best_prog_per_agg_val_score = val_evaluation_policy.get_best_program(gepa_state)
    best_score_on_valset = val_evaluation_policy.get_valset_score(best_prog_per_agg_val_score, gepa_state)

    # Aggregate valset score of the new program. The original code queried the
    # policy twice for the exact same (program, state) pair; compute it once
    # and reuse it for both log lines and the metrics dict.
    valset_score = val_evaluation_policy.get_valset_score(new_program_idx, gepa_state)
    valset_scores = valset_evaluation.scores_by_val_id
    coverage = len(valset_scores)  # number of val examples actually evaluated
    logger.log(
        f"Iteration {iteration}: Valset score for new program: {valset_score}"
        f" (coverage {coverage} / {valset_size})"
    )

    logger.log(f"Iteration {iteration}: Val aggregate for new program: {valset_score}")
    logger.log(f"Iteration {iteration}: Individual valset scores for new program: {valset_scores}")
    if objective_scores:
        logger.log(f"Iteration {iteration}: Objective aggregate scores for new program: {objective_scores}")
    logger.log(f"Iteration {iteration}: New valset pareto front scores: {gepa_state.pareto_front_valset}")
    if gepa_state.objective_pareto_front:
        logger.log(f"Iteration {iteration}: Objective pareto front scores: {gepa_state.objective_pareto_front}")

    pareto_scores = list(gepa_state.pareto_front_valset.values())
    assert all(score > float("-inf") for score in pareto_scores), (
        "Should have at least one valid score per validation example"
    )
    assert len(pareto_scores) > 0
    pareto_avg = sum(pareto_scores) / len(pareto_scores)

    logger.log(f"Iteration {iteration}: Valset pareto front aggregate score: {pareto_avg}")
    logger.log(
        f"Iteration {iteration}: Updated valset pareto front programs: {gepa_state.program_at_pareto_front_valset}"
    )
    if gepa_state.program_at_pareto_front_objectives:
        logger.log(
            f"Iteration {iteration}: Updated objective pareto front programs: {gepa_state.program_at_pareto_front_objectives}"
        )
    logger.log(
        f"Iteration {iteration}: Best valset aggregate score so far: {max(gepa_state.program_full_scores_val_set)}"
    )
    logger.log(
        f"Iteration {iteration}: Best program as per aggregate score on valset: {best_prog_per_agg_val_score}"
    )
    logger.log(f"Iteration {iteration}: Best score on valset: {best_score_on_valset}")
    logger.log(f"Iteration {iteration}: Linear pareto front program index: {linear_pareto_front_program_idx}")
    logger.log(f"Iteration {iteration}: New program candidate index: {new_program_idx}")

    metrics = {
        "iteration": iteration,
        "new_program_idx": new_program_idx,
        "valset_pareto_front_agg": pareto_avg,
        "valset_pareto_front_programs": {k: list(v) for k, v in gepa_state.program_at_pareto_front_valset.items()},
        # NOTE: the next two keys intentionally carry the same value; both are
        # kept for backward compatibility with existing dashboards/consumers.
        "best_valset_agg_score": best_score_on_valset,
        "linear_pareto_front_program_idx": linear_pareto_front_program_idx,
        "best_program_as_per_agg_score_valset": best_prog_per_agg_val_score,
        "best_score_on_valset": best_score_on_valset,
        "val_evaluated_count_new_program": coverage,
        "val_total_count": valset_size,
        "val_program_average": valset_score,
    }
    if log_individual_valset_scores_and_programs:
        metrics.update(
            {
                "valset_pareto_front_scores": dict(gepa_state.pareto_front_valset),
                "individual_valset_score_new_program": dict(valset_scores),
            }
        )
    if objective_scores:
        metrics["objective_scores_new_program"] = dict(objective_scores)
    if valset_evaluation.objective_scores_by_val_id:
        metrics["objective_scores_by_val_new_program"] = {
            val_id: dict(scores) for val_id, scores in valset_evaluation.objective_scores_by_val_id.items()
        }
    if gepa_state.objective_pareto_front:
        metrics["objective_pareto_front_scores"] = dict(gepa_state.objective_pareto_front)
        metrics["objective_pareto_front_programs"] = {
            k: list(v) for k, v in gepa_state.program_at_pareto_front_objectives.items()
        }

    experiment_tracker.log_metrics(metrics, step=iteration)
File without changes
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Generic, Protocol
6
+
7
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
8
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
9
+
10
+
11
@dataclass
class CandidateProposal(Generic[DataId]):
    """A new candidate program proposed by a :class:`ProposeNewCandidate` strategy.

    Carries the candidate text itself, lineage information, and optional
    mini-batch (subsample) evaluation results that the engine can use when
    deciding acceptance.
    """

    # Mapping of component/predictor name -> proposed text.
    candidate: dict[str, str]
    # Indices of the program(s) this candidate was derived from.
    parent_program_ids: list[int]
    # Optional mini-batch / subsample info
    subsample_indices: list[DataId] | None = None
    subsample_scores_before: list[float] | None = None
    subsample_scores_after: list[float] | None = None
    # Free-form metadata for logging/trace
    tag: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)
22
+
23
+
24
class ProposeNewCandidate(Protocol[DataId]):
    """
    Strategy that receives the current optimizer state and proposes a new candidate or returns None.
    It may compute subsample evaluations, set trace fields in state, etc.
    The engine will handle acceptance and full eval unless the strategy already did those and encoded in metadata.
    """

    # Returns None when no proposal could be produced this iteration.
    def propose(self, state: GEPAState[Any, DataId]) -> CandidateProposal[DataId] | None: ...
@@ -0,0 +1,357 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ import math
5
+ import random
6
+ from collections.abc import Callable, Iterable, Sequence
7
+ from copy import deepcopy
8
+
9
+ from mantisdk.algorithm.gepa.lib.core.adapter import Candidate, DataInst, RolloutOutput
10
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader
11
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState, ObjectiveScores, ProgramIdx
12
+ from mantisdk.algorithm.gepa.lib.gepa_utils import find_dominator_programs
13
+ from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol
14
+ from mantisdk.algorithm.gepa.lib.proposer.base import CandidateProposal, ProposeNewCandidate
15
+
16
# Bookkeeping type aliases for the merge proposer.
# (id1, id2, ancestor) triplets that have already been attempted.
AncestorLog = tuple[int, int, int]
# (id1, id2, per-predictor source program ids) describing a performed merge.
MergeDescription = tuple[int, int, tuple[int, ...]]
# Result of a merge attempt: (merged candidate, id1, id2, ancestor) or None.
MergeAttempt = tuple[Candidate, ProgramIdx, ProgramIdx, ProgramIdx] | None
19
+
20
+
21
def does_triplet_have_desirable_predictors(
    program_candidates: Sequence[Candidate],
    ancestor: ProgramIdx,
    id1: ProgramIdx,
    id2: ProgramIdx,
) -> bool:
    """Return True if the (ancestor, id1, id2) triplet has a mergeable predictor.

    A predictor is "desirable" when the two descendants disagree on it and
    exactly one of them still matches the ancestor's text — i.e. the other
    side diverged and plausibly carries an improvement worth grafting.
    """
    anc = program_candidates[ancestor]
    desc1 = program_candidates[id1]
    desc2 = program_candidates[id2]
    # any() short-circuits on the first desirable predictor instead of
    # materializing the full list like the previous implementation did.
    return any(
        (anc[name] == desc1[name] or anc[name] == desc2[name]) and desc1[name] != desc2[name]
        for name in anc
    )
38
+
39
+
40
def filter_ancestors(
    i: ProgramIdx,
    j: ProgramIdx,
    common_ancestors: Iterable[ProgramIdx],
    merges_performed: tuple[list[AncestorLog], list[MergeDescription]],
    agg_scores: Sequence[float],
    program_candidates: Sequence[Candidate],
) -> list[ProgramIdx]:
    """Keep only the common ancestors of (i, j) that are still worth merging through.

    An ancestor survives when the (i, j, ancestor) triplet has not been
    attempted before, the ancestor does not outscore either descendant, and
    the triplet exposes at least one desirable predictor to graft.
    """

    def _is_viable(candidate_ancestor: ProgramIdx) -> bool:
        # Skip triplets that were already attempted.
        if (i, j, candidate_ancestor) in merges_performed[0]:
            return False
        # An ancestor that beats either descendant is not a useful merge base.
        if agg_scores[candidate_ancestor] > agg_scores[i] or agg_scores[candidate_ancestor] > agg_scores[j]:
            return False
        # Require at least one predictor where exactly one side diverged.
        return does_triplet_have_desirable_predictors(program_candidates, candidate_ancestor, i, j)

    return [ancestor for ancestor in common_ancestors if _is_viable(ancestor)]
61
+
62
+
63
def find_common_ancestor_pair(
    rng: random.Random,
    parent_list: Sequence[Sequence[int | None]],
    program_indexes: Sequence[int],
    merges_performed: tuple[list[AncestorLog], list[MergeDescription]],
    agg_scores: Sequence[float],
    program_candidates: Sequence[Candidate],
    max_attempts: int = 10,
) -> tuple[int, int, int] | None:
    """Sample a pair of distinct programs that share a usable common ancestor.

    Repeats up to ``max_attempts`` times: samples two programs from
    ``program_indexes``, rejects pairs where one is an ancestor of the other,
    filters their common ancestors via ``filter_ancestors``, and finally picks
    one surviving ancestor with probability proportional to its aggregate
    score. Returns ``(i, j, ancestor)`` with ``i < j``, or ``None`` if no
    suitable triplet is found.
    """

    def get_ancestors(node: int, ancestors_found: set[int]) -> list[int]:
        # Recursively collect every transitive parent of `node` into
        # `ancestors_found` (the set also serves as the visited marker).
        parents = parent_list[node]
        for parent in parents:
            if parent is not None and parent not in ancestors_found:
                ancestors_found.add(parent)
                get_ancestors(parent, ancestors_found)

        return list(ancestors_found)

    for _ in range(max_attempts):
        if len(program_indexes) < 2:
            # Cannot sample a pair from fewer than two programs.
            return None
        i, j = rng.sample(list(program_indexes), 2)
        if i == j:
            continue

        # Normalize ordering so the returned pair always has i < j.
        if j < i:
            i, j = j, i

        ancestors_i = get_ancestors(i, set())
        ancestors_j = get_ancestors(j, set())

        if j in ancestors_i or i in ancestors_j:
            # If one is an ancestor of the other, we cannot merge them
            continue

        common_ancestors = set(ancestors_i) & set(ancestors_j)
        common_ancestors = filter_ancestors(i, j, common_ancestors, merges_performed, agg_scores, program_candidates)
        if common_ancestors:
            # Select a random common ancestor, weighted by aggregate score.
            common_ancestor = rng.choices(
                list(common_ancestors),
                k=1,
                weights=[agg_scores[ancestor] for ancestor in common_ancestors],
            )[0]
            return (i, j, common_ancestor)

    return None
110
+
111
+
112
def sample_and_attempt_merge_programs_by_common_predictors(
    agg_scores: Sequence[float],
    rng: random.Random,
    merge_candidates: Sequence[int],
    merges_performed: tuple[list[AncestorLog], list[MergeDescription]],
    program_candidates: Sequence[Candidate],
    parent_program_for_candidate: Sequence[Sequence[int | None]],
    has_val_support_overlap: Callable[[ProgramIdx, ProgramIdx], bool] | None = None,
    max_attempts: int = 10,
) -> MergeAttempt:
    """Try up to ``max_attempts`` times to build a merged candidate program.

    Picks an (id1, id2, ancestor) triplet via ``find_common_ancestor_pair``,
    then constructs a new candidate starting from the ancestor: for each
    predictor it takes the descendant's text that diverged from the ancestor
    (or, when both diverged, the higher-scoring descendant — ties broken
    randomly). Successful merge descriptions are appended to
    ``merges_performed[1]`` as a side effect.

    Returns ``(new_program, id1, id2, ancestor)`` or ``None``.
    """
    if len(merge_candidates) < 2:
        return None
    if len(parent_program_for_candidate) < 3:
        # Need at least two descendants plus an ancestor for any merge.
        return None

    for _ in range(max_attempts):
        ids_to_merge = find_common_ancestor_pair(
            rng,
            parent_program_for_candidate,
            list(merge_candidates),
            merges_performed=merges_performed,
            agg_scores=agg_scores,
            program_candidates=program_candidates,
            max_attempts=max_attempts,
        )
        if ids_to_merge is None:
            continue
        id1, id2, ancestor = ids_to_merge

        if (id1, id2, ancestor) in merges_performed[0]:
            # This exact triplet was already attempted.
            continue
        assert agg_scores[ancestor] <= agg_scores[id1], "Ancestor should not be better than its descendants"
        assert agg_scores[ancestor] <= agg_scores[id2], "Ancestor should not be better than its descendants"
        assert id1 != id2, "Cannot merge the same program"

        # Now we have a common ancestor, which is outperformed by both its descendants

        new_program: Candidate = deepcopy(program_candidates[ancestor])

        # Per-predictor record of which program each merged value came from.
        new_prog_desc: tuple[ProgramIdx, ...] = ()

        pred_names = set(program_candidates[ancestor].keys())
        assert pred_names == set(program_candidates[id1].keys()) == set(program_candidates[id2].keys()), (
            "Predictors should be the same across all programs"
        )
        for pred_name in pred_names:
            pred_anc = program_candidates[ancestor][pred_name]
            pred_id1 = program_candidates[id1][pred_name]
            pred_id2 = program_candidates[id2][pred_name]
            if (pred_anc == pred_id1 or pred_anc == pred_id2) and pred_id1 != pred_id2:
                # We have a predictor that is the same as one of its ancestors, so we can update it with the other
                same_as_ancestor_id = 1 if pred_anc == pred_id1 else 2
                new_value_idx = id2 if same_as_ancestor_id == 1 else id1
                new_program[pred_name] = program_candidates[new_value_idx][pred_name]
                new_prog_desc = (*new_prog_desc, new_value_idx)
            elif pred_anc != pred_id1 and pred_anc != pred_id2:
                # Both predictors are different from the ancestor, and it is difficult to decide which one gives the benefits
                # We randomly select one of the descendants to update the predictor
                # The probability of selecting is proportional to the agg_scores of the descendants
                # prog_to_get_instruction_from = id1 if (rng.random() < (agg_scores[id1] / (agg_scores[id1] + agg_scores[id2]))) else id2
                prog_to_get_instruction_from = (
                    id1
                    if agg_scores[id1] > agg_scores[id2]
                    else (id2 if agg_scores[id2] > agg_scores[id1] else rng.choice([id1, id2]))
                )
                new_program[pred_name] = program_candidates[prog_to_get_instruction_from][pred_name]
                new_prog_desc = (*new_prog_desc, prog_to_get_instruction_from)
            elif pred_id1 == pred_id2:
                # Either both predictors are the same, or both are different from the ancestor
                # If both are different from the ancestor, we should use the new predictor, so selecting either one of the descendants is fine
                # If both are same as the ancesor, again selecting any one of the descendants is fine
                # So let's select id1
                new_program[pred_name] = program_candidates[id1][pred_name]
                new_prog_desc = (*new_prog_desc, id1)
            else:  # pragma: no cover - defensive
                raise AssertionError("Unexpected case in predictor merging logic")

        if (id1, id2, new_prog_desc) in merges_performed[1]:
            # This triplet has already been merged, so we skip it
            continue

        if has_val_support_overlap and not has_val_support_overlap(id1, id2):
            # Not enough overlapping validation support for candidates
            continue

        merges_performed[1].append((id1, id2, new_prog_desc))

        return new_program, id1, id2, ancestor

    return None
202
+
203
+
204
class MergeProposer(ProposeNewCandidate[DataId]):
    """
    Implements merge flow that combines compatible descendants of a common ancestor.

    - Find merge candidates among Pareto front dominators
    - Attempt a merge via sample_and_attempt_merge_programs_by_common_predictors
    - Subsample eval on valset-driven selected indices
    - Return proposal if merge's subsample score >= max(parents)
    The engine handles full eval + adding to state.
    """

    def __init__(
        self,
        logger: LoggerProtocol,
        valset: DataLoader[DataId, DataInst],
        evaluator: Callable[
            [list[DataInst], dict[str, str]],
            tuple[list[RolloutOutput], list[float], Sequence[ObjectiveScores] | None],
        ],
        use_merge: bool,
        max_merge_invocations: int,
        val_overlap_floor: int = 5,
        rng: random.Random | None = None,
    ):
        """Configure the merge proposer.

        Args:
            logger: Sink for human-readable progress messages.
            valset: Data loader whose ``fetch`` is used for subsample evaluation.
            evaluator: Callable mapping (instances, candidate) to
                (outputs, scores, optional per-objective scores).
            use_merge: Master switch; when False, ``propose`` never merges.
            max_merge_invocations: Budget for scheduled merge attempts.
            val_overlap_floor: Minimum number of validation examples both
                parents must have been scored on (must be > 0).
            rng: Random source; defaults to a deterministic ``Random(0)``.
        """
        self.logger = logger
        self.valset = valset
        self.evaluator = evaluator
        self.use_merge = use_merge
        self.max_merge_invocations = max_merge_invocations
        self.rng = rng if rng is not None else random.Random(0)

        if val_overlap_floor <= 0:
            raise ValueError("val_overlap_floor should be a positive integer")
        self.val_overlap_floor = val_overlap_floor
        # Internal counters matching original behavior
        self.merges_due = 0
        self.total_merges_tested = 0
        self.merges_performed: tuple[list[AncestorLog], list[MergeDescription]] = ([], [])

        # Toggle controlled by engine: set True when last iter found new program
        self.last_iter_found_new_program = False

    def schedule_if_needed(self) -> None:
        # Queue one more merge attempt, unless merging is disabled or the
        # invocation budget has been spent.
        if self.use_merge and self.total_merges_tested < self.max_merge_invocations:
            self.merges_due += 1

    def select_eval_subsample_for_merged_program(
        self,
        scores1: dict[DataId, float],
        scores2: dict[DataId, float],
        num_subsample_ids: int = 5,
    ) -> list[DataId]:
        """Pick up to ``num_subsample_ids`` validation ids both parents were scored on.

        Balances three buckets — ids where parent 1 scored higher, where
        parent 2 scored higher, and ties — then tops up from the remaining
        common ids (sampling with replacement only as a last resort).
        """
        common_ids = list(set(scores1.keys()) & set(scores2.keys()))

        # Bucket the shared ids by which parent wins on each example.
        p1 = [idx for idx in common_ids if scores1[idx] > scores2[idx]]
        p2 = [idx for idx in common_ids if scores2[idx] > scores1[idx]]
        p3 = [idx for idx in common_ids if idx not in p1 and idx not in p2]

        n_each = max(1, math.ceil(num_subsample_ids / 3))
        selected: list[DataId] = []
        for bucket in (p1, p2, p3):
            if len(selected) >= num_subsample_ids:
                break
            available = [idx for idx in bucket if idx not in selected]
            take = min(len(available), n_each, num_subsample_ids - len(selected))
            if take > 0:
                selected += self.rng.sample(available, k=take)

        # Top up to the target size from ids not yet chosen; fall back to
        # sampling with replacement when too few distinct ids remain.
        remaining = num_subsample_ids - len(selected)
        if remaining > 0:
            unused = [idx for idx in common_ids if idx not in selected]
            if len(unused) >= remaining:
                selected += self.rng.sample(unused, k=remaining)
            elif common_ids:
                selected += self.rng.choices(common_ids, k=remaining)

        return selected[:num_subsample_ids]

    def propose(self, state: GEPAState[RolloutOutput, DataId]) -> CandidateProposal[DataId] | None:
        """Attempt one merge and return a proposal, or None when not possible.

        Side effects: records merge bookkeeping in ``state.full_program_trace[-1]``,
        appends to ``self.merges_performed[0]``, and bumps ``state.total_num_evals``
        by the number of subsample evaluations actually executed.
        """
        i = state.i + 1
        # Record in the trace that a merge was attempted this iteration.
        state.full_program_trace[-1]["invoked_merge"] = True

        # Only attempt when scheduled by engine and after a new program in last iteration
        if not (self.use_merge and self.last_iter_found_new_program and self.merges_due > 0):
            self.logger.log(f"Iteration {i}: No merge candidates scheduled")
            return None

        pareto_front_programs = state.get_pareto_front_mapping()

        # Prefer the state's tracked scores when available; fall back to the
        # full valset scores otherwise.
        tracked_scores: Sequence[float] = getattr(
            state, "per_program_tracked_scores", state.program_full_scores_val_set
        )
        merge_candidates = find_dominator_programs(pareto_front_programs, list(tracked_scores))

        def has_val_support_overlap(id1: ProgramIdx, id2: ProgramIdx) -> bool:
            # Both parents must share at least `val_overlap_floor` scored examples.
            common_ids = set(state.prog_candidate_val_subscores[id1].keys()) & set(
                state.prog_candidate_val_subscores[id2].keys()
            )
            return len(common_ids) >= self.val_overlap_floor

        merge_output = sample_and_attempt_merge_programs_by_common_predictors(
            agg_scores=list(tracked_scores),
            rng=self.rng,
            merge_candidates=merge_candidates,
            merges_performed=self.merges_performed,
            program_candidates=state.program_candidates,
            parent_program_for_candidate=state.parent_program_for_candidate,
            has_val_support_overlap=has_val_support_overlap,
        )

        if merge_output is None:
            self.logger.log(f"Iteration {i}: No merge candidates found")
            return None

        new_program, id1, id2, ancestor = merge_output
        state.full_program_trace[-1]["merged"] = True
        state.full_program_trace[-1]["merged_entities"] = (id1, id2, ancestor)
        self.merges_performed[0].append((id1, id2, ancestor))
        self.logger.log(f"Iteration {i}: Merged programs {id1} and {id2} via ancestor {ancestor}")

        subsample_ids = self.select_eval_subsample_for_merged_program(
            state.prog_candidate_val_subscores[id1],
            state.prog_candidate_val_subscores[id2],
        )
        if not subsample_ids:
            self.logger.log(
                f"Iteration {i}: Skipping merge of {id1} and {id2} due to insufficient overlapping val coverage"
            )
            return None

        assert set(subsample_ids).issubset(state.prog_candidate_val_subscores[id1].keys())
        assert set(subsample_ids).issubset(state.prog_candidate_val_subscores[id2].keys())
        id1_sub_scores = [state.prog_candidate_val_subscores[id1][k] for k in subsample_ids]
        id2_sub_scores = [state.prog_candidate_val_subscores[id2][k] for k in subsample_ids]
        state.full_program_trace[-1]["subsample_ids"] = subsample_ids

        # Evaluate the merged program on the selected subsample (cached, so
        # only new evaluations count toward the eval budget).
        new_sub_scores, actual_evals_count = state.cached_evaluate(
            new_program, subsample_ids, self.valset.fetch, self.evaluator
        )
        state.full_program_trace[-1]["id1_subsample_scores"] = id1_sub_scores
        state.full_program_trace[-1]["id2_subsample_scores"] = id2_sub_scores
        state.full_program_trace[-1]["new_program_subsample_scores"] = new_sub_scores
        state.total_num_evals += actual_evals_count

        # Acceptance will be evaluated by engine (>= max(parents))
        return CandidateProposal(
            candidate=new_program,
            parent_program_ids=[id1, id2],
            subsample_indices=subsample_ids,
            subsample_scores_before=[sum(id1_sub_scores), sum(id2_sub_scores)],
            subsample_scores_after=new_sub_scores,
            tag="merge",
            metadata={"ancestor": ancestor},
        )
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Any, ClassVar, Mapping, Protocol, runtime_checkable
6
+
7
+ from mantisdk.algorithm.gepa.lib.core.adapter import Trajectory
8
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
9
+
10
+
11
@runtime_checkable
class CandidateSelector(Protocol):
    """Strategy that picks which existing candidate program to evolve next."""

    # Returns an index into the state's candidate program list.
    def select_candidate_idx(self, state: GEPAState) -> int: ...
14
+
15
+
16
class ReflectionComponentSelector(Protocol):
    """Strategy that chooses which candidate components (by name) to rewrite
    during a reflective-mutation step, given the rollout trajectories and
    their subsample scores."""

    def __call__(
        self,
        state: GEPAState,
        trajectories: list[Trajectory],
        subsample_scores: list[float],
        candidate_idx: int,
        candidate: dict[str, str],
    ) -> list[str]: ...
25
+
26
+
27
class LanguageModel(Protocol):
    """Minimal LM interface: a callable mapping a prompt string to its completion text."""

    def __call__(self, prompt: str) -> str: ...
29
+
30
+
31
@dataclass
class Signature:
    """Abstract prompt signature: render inputs to a prompt, run an LM, and
    parse the completion back into named outputs.

    Subclasses provide the class-level template/key metadata and override
    ``prompt_renderer`` and ``output_extractor``.
    """

    prompt_template: ClassVar[str]
    input_keys: ClassVar[list[str]]
    output_keys: ClassVar[list[str]]

    @classmethod
    def prompt_renderer(cls, input_dict: Mapping[str, Any]) -> str:
        """Render ``input_dict`` into the full prompt string. Must be overridden."""
        raise NotImplementedError

    @classmethod
    def output_extractor(cls, lm_out: str) -> dict[str, str]:
        """Parse the LM completion into a dict of named outputs. Must be overridden."""
        raise NotImplementedError

    @classmethod
    def run(cls, lm: LanguageModel, input_dict: Mapping[str, Any]) -> dict[str, str]:
        """Render the prompt, invoke ``lm``, and extract the structured outputs."""
        rendered = cls.prompt_renderer(input_dict)
        raw_completion = lm(rendered)
        # Whitespace is stripped before extraction so parsers see clean text.
        return cls.output_extractor(raw_completion.strip())