mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ from collections.abc import Mapping, Sequence
5
+ from typing import Any
6
+
7
+ from mantisdk.algorithm.gepa.lib.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
8
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader, ensure_loader
9
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
10
+ from mantisdk.algorithm.gepa.lib.proposer.base import CandidateProposal, ProposeNewCandidate
11
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import (
12
+ CandidateSelector,
13
+ LanguageModel,
14
+ ReflectionComponentSelector,
15
+ )
16
+ from mantisdk.algorithm.gepa.lib.strategies.batch_sampler import BatchSampler
17
+ from mantisdk.algorithm.gepa.lib.strategies.instruction_proposal import InstructionProposalSignature
18
+
19
+
20
class ReflectiveMutationProposer(ProposeNewCandidate[DataId]):
    """
    Proposer that mutates one existing candidate per call by reflecting on a minibatch.

    A single ``propose`` call walks through these steps:
    - ask the candidate selector for the index of the program to mutate
    - ask the batch sampler for minibatch ids and fetch them from the trainset
    - evaluate that program on the minibatch with ``capture_traces=True``
    - return None when no trajectories came back, or when ``skip_perfect_score``
      is set and every minibatch score reaches ``perfect_score``
    - build a reflective dataset and obtain replacement texts for the chosen components
    - evaluate the mutated candidate on the same minibatch ids
    - return a ``CandidateProposal`` carrying the scores before and after the mutation
    """

    def __init__(
        self,
        logger: Any,
        trainset: list[DataInst] | DataLoader[DataId, DataInst],
        adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput],
        candidate_selector: CandidateSelector,
        module_selector: ReflectionComponentSelector,
        batch_sampler: BatchSampler[DataId, DataInst],
        perfect_score: float,
        skip_perfect_score: bool,
        experiment_tracker: Any,
        reflection_lm: LanguageModel | None = None,
        reflection_prompt_template: str | None = None,
    ):
        # Data source: plain lists are wrapped by ensure_loader.
        self.trainset = ensure_loader(trainset)

        # Collaborators.
        self.logger = logger
        self.experiment_tracker = experiment_tracker
        self.adapter = adapter
        self.candidate_selector = candidate_selector
        self.module_selector = module_selector
        self.batch_sampler = batch_sampler
        self.reflection_lm = reflection_lm

        # Early-exit settings used by propose().
        self.perfect_score = perfect_score
        self.skip_perfect_score = skip_perfect_score

        # Reject a custom template that lacks the required placeholders before storing it.
        InstructionProposalSignature.validate_prompt_template(reflection_prompt_template)
        self.reflection_prompt_template = reflection_prompt_template

    def propose_new_texts(
        self,
        candidate: dict[str, str],
        reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
        components_to_update: list[str],
    ) -> dict[str, str]:
        """
        Produce replacement texts for the requested components.

        The adapter's own ``propose_new_texts`` is used when it is not None.
        Otherwise each component with a non-empty reflective dataset is sent
        through ``InstructionProposalSignature`` using ``reflection_lm``;
        components without data are logged and left out of the result.

        Raises:
            ValueError: neither an adapter hook nor ``reflection_lm`` is available.
        """
        adapter_hook = self.adapter.propose_new_texts
        if adapter_hook is not None:
            return adapter_hook(candidate, reflective_dataset, components_to_update)

        if self.reflection_lm is None:
            raise ValueError("reflection_lm must be provided when adapter.propose_new_texts is None.")

        proposed: dict[str, str] = {}
        for component in components_to_update:
            records = reflective_dataset.get(component)
            if not records:
                # A missing key and an empty record list are both skipped.
                self.logger.log(f"Component '{component}' is not in reflective dataset. Skipping.")
                continue

            result = InstructionProposalSignature.run(
                lm=self.reflection_lm,
                input_dict={
                    "current_instruction_doc": candidate[component],
                    "dataset_with_feedback": records,
                    "prompt_template": self.reflection_prompt_template,
                },
            )
            proposed[component] = result["new_instruction"]
        return proposed

    def propose(self, state: GEPAState) -> CandidateProposal | None:
        """
        Run one reflective-mutation step against ``state``.

        Writes the selected candidate, minibatch ids and both score lists into
        ``state.full_program_trace[-1]`` and adds to ``state.total_num_evals``.

        Returns:
            A ``CandidateProposal`` tagged ``"reflective_mutation"``, or None when
            the step is skipped (no trajectories, all-perfect minibatch, or an
            exception during reflection/proposal).
        """
        iteration = state.i + 1

        parent_idx = self.candidate_selector.select_candidate_idx(state)
        parent = state.program_candidates[parent_idx]
        state.full_program_trace[-1]["selected_program_candidate"] = parent_idx
        parent_val_score = state.program_full_scores_val_set[parent_idx]
        self.logger.log(f"Iteration {iteration}: Selected program {parent_idx} score: {parent_val_score}")

        self.experiment_tracker.log_metrics(
            {"iteration": iteration, "selected_program_candidate": parent_idx}, step=iteration
        )

        batch_ids = self.batch_sampler.next_minibatch_ids(self.trainset, state)
        state.full_program_trace[-1]["subsample_ids"] = batch_ids
        batch = self.trainset.fetch(batch_ids)

        # Step 1: evaluate the parent with traces. The cache is not consulted here
        # because reflection needs freshly captured trajectories.
        parent_eval = self.adapter.evaluate(batch, parent, capture_traces=True)
        state.total_num_evals += len(batch_ids)
        state.full_program_trace[-1]["subsample_scores"] = parent_eval.scores

        # Store the parent's results so later capture_traces=False evaluations can reuse them.
        if state.evaluation_cache is not None:
            parent_objectives = list(parent_eval.objective_scores) if parent_eval.objective_scores else None
            state.evaluation_cache.put_batch(
                parent, batch_ids, parent_eval.outputs, parent_eval.scores, parent_objectives
            )

        if not parent_eval.trajectories:
            self.logger.log(f"Iteration {iteration}: No trajectories captured. Skipping.")
            return None

        if self.skip_perfect_score and all(score >= self.perfect_score for score in parent_eval.scores):
            self.logger.log(f"Iteration {iteration}: All subsample scores perfect. Skipping.")
            return None

        self.experiment_tracker.log_metrics({"subsample_score": sum(parent_eval.scores)}, step=iteration)

        # Step 2: pick the components to rewrite.
        components = self.module_selector(state, parent_eval.trajectories, parent_eval.scores, parent_idx, parent)

        # Step 3: reflect and propose. Any failure is logged and the iteration is abandoned.
        try:
            reflective_dataset = self.adapter.make_reflective_dataset(parent, parent_eval, components)
            new_texts = self.propose_new_texts(parent, reflective_dataset, components)
            for pname, text in new_texts.items():
                self.logger.log(f"Iteration {iteration}: Proposed new text for {pname}: {text}")
            self.experiment_tracker.log_metrics(
                {f"new_instruction_{pname}": text for pname, text in new_texts.items()}, step=iteration
            )
        except Exception as e:
            self.logger.log(f"Iteration {iteration}: Exception during reflection/proposal: {e}")
            import traceback

            self.logger.log(traceback.format_exc())
            return None

        # Step 4: build the child and score it on the same minibatch (traces not needed).
        child = parent.copy()
        for pname, text in new_texts.items():
            assert pname in child, f"{pname} missing in candidate"
            child[pname] = text

        def evaluate_without_traces(b, c):
            outcome = self.adapter.evaluate(b, c, capture_traces=False)
            objectives = list(outcome.objective_scores) if outcome.objective_scores else None
            return outcome.outputs, outcome.scores, objectives

        child_scores, fresh_eval_count = state.cached_evaluate(
            child, batch_ids, self.trainset.fetch, evaluate_without_traces
        )
        state.total_num_evals += fresh_eval_count
        state.full_program_trace[-1]["new_subsample_scores"] = child_scores

        self.experiment_tracker.log_metrics({"new_subsample_score": sum(child_scores)}, step=iteration)

        return CandidateProposal(
            candidate=child,
            parent_program_ids=[parent_idx],
            subsample_indices=batch_ids,
            subsample_scores_before=parent_eval.scores,
            subsample_scores_after=child_scores,
            tag="reflective_mutation",
        )
File without changes
File without changes
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ import random
5
+ from collections import Counter
6
+ from typing import Protocol
7
+
8
+ from mantisdk.algorithm.gepa.lib.core.adapter import DataInst
9
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader
10
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
11
+
12
+
13
class BatchSampler(Protocol[DataId, DataInst]):
    """Structural interface for objects that choose the ids of the next training minibatch."""

    def next_minibatch_ids(
        self,
        loader: DataLoader[DataId, DataInst],
        state: GEPAState,
    ) -> list[DataId]:
        """Return the ids that make up the next minibatch drawn from ``loader``."""
        ...
15
+
16
+
17
class EpochShuffledBatchSampler(BatchSampler[DataId, DataInst]):
    """
    Serve fixed-size minibatches of ids from a per-epoch shuffled ordering.

    - The id list is reshuffled when a new epoch starts or the loader size changes.
    - When the trainset size is not a multiple of ``minibatch_size``, the shuffled
      list is padded with the currently least frequent ids.
    - Shuffling uses the sampler's own ``rng`` (``random.Random(0)`` unless one is
      passed to the constructor), so a fixed seed gives a reproducible order.
    """

    def __init__(self, minibatch_size: int, rng: random.Random | None = None):
        """
        Args:
            minibatch_size: number of ids returned by each ``next_minibatch_ids`` call.
            rng: random generator used for shuffling; defaults to ``random.Random(0)``.
        """
        self.minibatch_size = minibatch_size
        self.shuffled_ids: list[DataId] = []
        self.epoch = -1  # -1 means no shuffle has happened yet
        self.id_freqs: Counter = Counter()
        self.last_trainset_size = 0
        if rng is None:
            self.rng = random.Random(0)
        else:
            self.rng = rng

    def _update_shuffled(self, loader: DataLoader[DataId, DataInst]):
        """Rebuild ``shuffled_ids`` (shuffled, then padded) and ``id_freqs`` from ``loader``."""
        all_ids = list(loader.all_ids())
        trainset_size = len(loader)
        self.last_trainset_size = trainset_size

        if trainset_size == 0:
            self.shuffled_ids = []
            self.id_freqs = Counter()
            return

        self.shuffled_ids = list(all_ids)
        self.rng.shuffle(self.shuffled_ids)
        self.id_freqs = Counter(self.shuffled_ids)

        mod = trainset_size % self.minibatch_size
        num_to_pad = (self.minibatch_size - mod) if mod != 0 else 0
        for _ in range(num_to_pad):
            # Pick the least frequent id in one linear pass instead of re-sorting the
            # whole counter for every padded slot. Scanning the keys in reverse
            # insertion order keeps the previous tie-breaking: among equally rare ids,
            # the one inserted last is chosen.
            selected_id = min(reversed(self.id_freqs), key=self.id_freqs.__getitem__)
            self.shuffled_ids.append(selected_id)
            self.id_freqs[selected_id] += 1

    def next_minibatch_ids(self, loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]:
        """
        Return the slice of shuffled ids for iteration ``state.i``.

        Raises:
            ValueError: ``loader`` is empty.
        """
        trainset_size = len(loader)
        if trainset_size == 0:
            raise ValueError("Cannot sample a minibatch from an empty loader.")

        base_idx = state.i * self.minibatch_size
        # Before the first shuffle there is no list length to divide by, so start at epoch 0.
        curr_epoch = 0 if self.epoch == -1 else base_idx // max(len(self.shuffled_ids), 1)

        needs_refresh = not self.shuffled_ids or trainset_size != self.last_trainset_size or curr_epoch > self.epoch
        if needs_refresh:
            self.epoch = curr_epoch
            self._update_shuffled(loader)

        # Internal invariants guaranteed by the padding in _update_shuffled.
        assert len(self.shuffled_ids) >= self.minibatch_size
        assert len(self.shuffled_ids) % self.minibatch_size == 0

        # Wrap the absolute offset into the current (padded) epoch.
        base_idx = base_idx % len(self.shuffled_ids)
        end_idx = base_idx + self.minibatch_size
        assert end_idx <= len(self.shuffled_ids)
        return self.shuffled_ids[base_idx:end_idx]
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ import random
5
+
6
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
7
+ from mantisdk.algorithm.gepa.lib.gepa_utils import idxmax, select_program_candidate_from_pareto_front
8
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import CandidateSelector
9
+
10
+
11
class ParetoCandidateSelector(CandidateSelector):
    """Choose the next candidate by sampling from the state's Pareto front mapping."""

    def __init__(self, rng: random.Random | None):
        # Use a fixed-seed generator when the caller does not supply one.
        self.rng = rng if rng is not None else random.Random(0)

    def select_candidate_idx(self, state: GEPAState) -> int:
        """Return the index picked by ``select_program_candidate_from_pareto_front``."""
        assert len(state.program_full_scores_val_set) == len(state.program_candidates)
        pareto_front = state.get_pareto_front_mapping()
        tracked_scores = state.per_program_tracked_scores
        return select_program_candidate_from_pareto_front(pareto_front, tracked_scores, self.rng)
25
+
26
+
27
class CurrentBestCandidateSelector(CandidateSelector):
    """Always choose the candidate that ``idxmax`` picks from the full validation scores."""

    def __init__(self):
        # Stateless selector: nothing to configure.
        pass

    def select_candidate_idx(self, state: GEPAState) -> int:
        """Return ``idxmax`` over ``state.program_full_scores_val_set``."""
        val_scores = state.program_full_scores_val_set
        assert len(val_scores) == len(state.program_candidates)
        return idxmax(val_scores)
34
+
35
+
36
class EpsilonGreedyCandidateSelector(CandidateSelector):
    """Pick a uniformly random candidate with probability ``epsilon``, else the ``idxmax`` one."""

    def __init__(self, epsilon: float, rng: random.Random | None):
        assert 0.0 <= epsilon <= 1.0
        self.epsilon = epsilon
        # Use a fixed-seed generator when the caller does not supply one.
        self.rng = rng if rng is not None else random.Random(0)

    def select_candidate_idx(self, state: GEPAState) -> int:
        """Return a random index (explore) or the ``idxmax`` index (exploit)."""
        assert len(state.program_full_scores_val_set) == len(state.program_candidates)
        explore = self.rng.random() < self.epsilon
        if not explore:
            return idxmax(state.program_full_scores_val_set)
        last_idx = len(state.program_candidates) - 1
        return self.rng.randint(0, last_idx)
@@ -0,0 +1,36 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+
5
+ from mantisdk.algorithm.gepa.lib.core.adapter import Trajectory
6
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState
7
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import ReflectionComponentSelector
8
+
9
+
10
class RoundRobinReflectionComponentSelector(ReflectionComponentSelector):
    """Select one component per call, cycling through the state's named predictors per candidate."""

    def __call__(
        self,
        state: GEPAState,
        trajectories: list[Trajectory],
        subsample_scores: list[float],
        candidate_idx: int,
        candidate: dict[str, str],
    ) -> list[str]:
        """Return the component due for ``candidate_idx`` and advance that candidate's cursor."""
        cursors = state.named_predictor_id_to_update_next_for_program_candidate
        current = cursors[candidate_idx]
        # Advance (with wrap-around) before looking up the name, as the cursor is
        # stored per candidate in the shared state.
        cursors[candidate_idx] = (current + 1) % len(state.list_of_named_predictors)
        return [state.list_of_named_predictors[current]]
25
+
26
+
27
class AllReflectionComponentSelector(ReflectionComponentSelector):
    """Select every component of the candidate on each call."""

    def __call__(
        self,
        state: GEPAState,
        trajectories: list[Trajectory],
        subsample_scores: list[float],
        candidate_idx: int,
        candidate: dict[str, str],
    ) -> list[str]:
        """Return all component names of ``candidate`` in the candidate's own key order."""
        return [*candidate.keys()]
@@ -0,0 +1,64 @@
1
+ """Validation evaluation policy protocols and helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import abstractmethod
6
+ from typing import Protocol, runtime_checkable
7
+
8
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataInst, DataLoader
9
+ from mantisdk.algorithm.gepa.lib.core.state import GEPAState, ProgramIdx
10
+
11
+
12
@runtime_checkable
class EvaluationPolicy(Protocol[DataId, DataInst]):  # type: ignore
    """Interface for deciding which validation ids to evaluate and which program counts as best."""

    @abstractmethod
    def get_eval_batch(
        self,
        loader: DataLoader[DataId, DataInst],
        state: GEPAState,
        target_program_idx: ProgramIdx | None = None,
    ) -> list[DataId]:
        """Return the validation ids to evaluate for a program."""
        ...

    @abstractmethod
    def get_best_program(self, state: GEPAState) -> ProgramIdx:
        """Return the index of the "best" program given the validation results recorded so far."""
        ...

    @abstractmethod
    def get_valset_score(self, program_idx: ProgramIdx, state: GEPAState) -> float:
        """Return the valset score of the program at ``program_idx``."""
        ...
32
+
33
+
34
class FullEvaluationPolicy(EvaluationPolicy[DataId, DataInst]):
    """Evaluation policy that uses every validation instance on every call."""

    def get_eval_batch(
        self,
        loader: DataLoader[DataId, DataInst],
        state: GEPAState,
        target_program_idx: ProgramIdx | None = None,
    ) -> list[DataId]:
        """Return all ids of ``loader`` in the order the loader reports them."""
        return list(loader.all_ids())

    def get_best_program(self, state: GEPAState) -> ProgramIdx:
        """
        Return the index of the program with the highest mean validation score.

        Ties on the mean go to the program with more evaluated instances; the
        earliest program wins when both are equal. A program with no scores is
        ranked at negative infinity. Returns -1 when there are no programs.
        """
        winner = -1
        winner_key = (float("-inf"), -1)  # (mean score, number of scored instances)
        for program_idx, scores in enumerate(state.prog_candidate_val_subscores):
            coverage = len(scores)
            mean = sum(scores.values()) / coverage if coverage else float("-inf")
            # Lexicographic comparison: higher mean first, then higher coverage.
            if (mean, coverage) > winner_key:
                winner = program_idx
                winner_key = (mean, coverage)
        return winner

    def get_valset_score(self, program_idx: ProgramIdx, state: GEPAState) -> float:
        """Return the first element of ``state.get_program_average_val_subset(program_idx)``."""
        average_and_rest = state.get_program_average_val_subset(program_idx)
        return average_and_rest[0]
58
+
59
+
60
+ __all__ = [
61
+ "DataLoader",
62
+ "EvaluationPolicy",
63
+ "FullEvaluationPolicy",
64
+ ]
@@ -0,0 +1,127 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ import re
5
+ from collections.abc import Mapping, Sequence
6
+ from typing import Any, ClassVar
7
+
8
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import Signature
9
+
10
+
11
class InstructionProposalSignature(Signature):
    """
    Signature that asks a language model to rewrite an instruction from feedback.

    The prompt is built from a template containing the placeholders
    ``<curr_instructions>`` and ``<inputs_outputs_feedback>``; the model's reply is
    expected to carry the new instruction inside a triple-backtick block.
    """

    default_prompt_template = """I provided an assistant with the following instructions to perform a task for me:
```
<curr_instructions>
```

The following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better:
```
<inputs_outputs_feedback>
```

Your task is to write a new instruction for the assistant.

Read the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant.

Read all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well.

Provide the new instructions within ``` blocks."""

    input_keys: ClassVar[list[str]] = ["current_instruction_doc", "dataset_with_feedback", "prompt_template"]
    output_keys: ClassVar[list[str]] = ["new_instruction"]

    @classmethod
    def validate_prompt_template(cls, prompt_template: str | None) -> None:
        """
        Check that a custom template contains both required placeholders.

        ``None`` is accepted without checks.

        Raises:
            ValueError: one or both placeholders are missing.
        """
        if prompt_template is None:
            return
        missing_placeholders = [
            placeholder
            for placeholder in ("<curr_instructions>", "<inputs_outputs_feedback>")
            if placeholder not in prompt_template
        ]
        if missing_placeholders:
            raise ValueError(
                f"Missing placeholder(s) in prompt template: {', '.join(missing_placeholders)}"
            )

    @classmethod
    def prompt_renderer(cls, input_dict: Mapping[str, Any]) -> str:
        """
        Render the reflection prompt from ``input_dict``.

        Reads ``current_instruction_doc`` (str), ``dataset_with_feedback``
        (sequence of mapping records) and the optional ``prompt_template``;
        the class default template is used when the latter is None.

        Raises:
            TypeError: the instruction is not a string or the dataset is not a
                non-string sequence.
            ValueError: the template lacks a required placeholder.
        """
        current_instruction = input_dict.get("current_instruction_doc")
        if not isinstance(current_instruction, str):
            raise TypeError("current_instruction_doc must be a string")

        dataset = input_dict.get("dataset_with_feedback")
        if not isinstance(dataset, Sequence) or isinstance(dataset, (str, bytes)):
            raise TypeError("dataset_with_feedback must be a sequence of records")

        def format_samples(samples):
            def render_value(value, level=3):
                # level controls markdown header depth (###, ####, etc.), capped at 6.
                if isinstance(value, dict):
                    if not value:
                        return "\n"
                    return "".join(
                        f"{'#' * level} {k}\n" + render_value(v, min(level + 1, 6)) for k, v in value.items()
                    )
                elif isinstance(value, list | tuple):
                    if not value:
                        return "\n"
                    return "".join(
                        f"{'#' * level} Item {i + 1}\n" + render_value(item, min(level + 1, 6))
                        for i, item in enumerate(value)
                    )
                else:
                    return f"{str(value).strip()}\n\n"

            def convert_sample_to_markdown(sample, examplenum):
                s = f"# Example {examplenum}\n"
                for key, val in sample.items():
                    s += f"## {key}\n"
                    s += render_value(val, level=3)
                return s

            return "\n\n".join(convert_sample_to_markdown(sample, i + 1) for i, sample in enumerate(samples))

        prompt_template = input_dict.get("prompt_template")
        if prompt_template is None:
            prompt_template = cls.default_prompt_template

        cls.validate_prompt_template(prompt_template)

        # Substitute both placeholders in a single pass over the template. Two chained
        # str.replace calls would also expand an "<inputs_outputs_feedback>" literal that
        # happens to appear inside the inserted instruction text.
        substitutions = {
            "<curr_instructions>": current_instruction,
            "<inputs_outputs_feedback>": format_samples(dataset),
        }
        placeholder_pattern = "|".join(re.escape(placeholder) for placeholder in substitutions)
        # A function replacement is used so backslashes in the inserted text stay literal.
        return re.sub(placeholder_pattern, lambda m: substitutions[m.group(0)], prompt_template)

    @classmethod
    def output_extractor(cls, lm_out: str) -> dict[str, str]:
        """
        Extract the proposed instruction from the model output.

        Takes the text between the first and last triple-backtick fence, dropping
        an optional language specifier on the opening fence. When only one fence
        (or none) is present, the lone fence is removed and the rest is returned.

        Returns:
            ``{"new_instruction": <extracted text>}``
        """

        def extract_instruction_text() -> str:
            # Position just after the first fence and position of the last fence.
            # When no fence exists, find/rfind give start=2, end=-1 and the
            # incomplete-block branch below handles it.
            start = lm_out.find("```") + 3
            end = lm_out.rfind("```")

            # First and last fence are the same one (or there is none).
            if start >= end:
                stripped = lm_out.strip()
                if stripped.startswith("```"):
                    # Remove opening ``` and optional language specifier. Match against
                    # the stripped text: the unstripped output may begin with whitespace,
                    # in which case an anchored match would fail and leave the fence in.
                    match = re.match(r"^```\S*\n?", stripped)
                    if match:
                        return stripped[match.end() :].strip()
                elif stripped.endswith("```"):
                    # Remove closing ```
                    return stripped[:-3].strip()
                return stripped

            # Skip optional language specifier
            content = lm_out[start:end]
            match = re.match(r"^\S*\n", content)
            if match:
                content = content[match.end() :]

            return content.strip()

        return {"new_instruction": extract_instruction_text()}
@@ -0,0 +1,10 @@
1
+ from .stop_condition import (
2
+ CompositeStopper,
3
+ FileStopper,
4
+ MaxMetricCallsStopper,
5
+ NoImprovementStopper,
6
+ ScoreThresholdStopper,
7
+ SignalStopper,
8
+ StopperProtocol,
9
+ TimeoutStopCondition,
10
+ )