mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0

mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py
@@ -0,0 +1,176 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

from collections.abc import Mapping, Sequence
from typing import Any

from mantisdk.algorithm.gepa.lib.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader, ensure_loader
from mantisdk.algorithm.gepa.lib.core.state import GEPAState
from mantisdk.algorithm.gepa.lib.proposer.base import CandidateProposal, ProposeNewCandidate
from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import (
    CandidateSelector,
    LanguageModel,
    ReflectionComponentSelector,
)
from mantisdk.algorithm.gepa.lib.strategies.batch_sampler import BatchSampler
from mantisdk.algorithm.gepa.lib.strategies.instruction_proposal import InstructionProposalSignature


class ReflectiveMutationProposer(ProposeNewCandidate[DataId]):
    """
    Implements current reflective mutation flow:
    - Select candidate via selector
    - Select minibatch via sampler
    - capture_traces_and_eval -> trajectories, subsample_scores
    - skip if all scores==perfect and skip_perfect_score
    - reflection + mutate -> new candidate
    - evaluate new candidate on same minibatch -> new_subsample_scores
    - Return proposal if improved; else None
    """

    def __init__(
        self,
        logger: Any,
        trainset: list[DataInst] | DataLoader[DataId, DataInst],
        adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput],
        candidate_selector: CandidateSelector,
        module_selector: ReflectionComponentSelector,
        batch_sampler: BatchSampler[DataId, DataInst],
        perfect_score: float,
        skip_perfect_score: bool,
        experiment_tracker: Any,
        reflection_lm: LanguageModel | None = None,
        reflection_prompt_template: str | None = None,
    ):
        self.logger = logger
        self.trainset = ensure_loader(trainset)
        self.adapter = adapter
        self.candidate_selector = candidate_selector
        self.module_selector = module_selector
        self.batch_sampler = batch_sampler
        self.perfect_score = perfect_score
        self.skip_perfect_score = skip_perfect_score
        self.experiment_tracker = experiment_tracker
        self.reflection_lm = reflection_lm

        InstructionProposalSignature.validate_prompt_template(reflection_prompt_template)
        self.reflection_prompt_template = reflection_prompt_template

    def propose_new_texts(
        self,
        candidate: dict[str, str],
        reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
        components_to_update: list[str],
    ) -> dict[str, str]:
        if self.adapter.propose_new_texts is not None:
            return self.adapter.propose_new_texts(candidate, reflective_dataset, components_to_update)

        if self.reflection_lm is None:
            raise ValueError("reflection_lm must be provided when adapter.propose_new_texts is None.")
        new_texts: dict[str, str] = {}
        for name in components_to_update:
            # Gracefully handle cases where a selected component has no data in reflective_dataset
            if name not in reflective_dataset or not reflective_dataset.get(name):
                self.logger.log(f"Component '{name}' is not in reflective dataset. Skipping.")
                continue

            base_instruction = candidate[name]
            dataset_with_feedback = reflective_dataset[name]
            new_texts[name] = InstructionProposalSignature.run(
                lm=self.reflection_lm,
                input_dict={
                    "current_instruction_doc": base_instruction,
                    "dataset_with_feedback": dataset_with_feedback,
                    "prompt_template": self.reflection_prompt_template,
                },
            )["new_instruction"]
        return new_texts

    def propose(self, state: GEPAState) -> CandidateProposal | None:
        i = state.i + 1

        curr_prog_id = self.candidate_selector.select_candidate_idx(state)
        curr_prog = state.program_candidates[curr_prog_id]
        state.full_program_trace[-1]["selected_program_candidate"] = curr_prog_id
        self.logger.log(
            f"Iteration {i}: Selected program {curr_prog_id} score: {state.program_full_scores_val_set[curr_prog_id]}"
        )

        self.experiment_tracker.log_metrics({"iteration": i, "selected_program_candidate": curr_prog_id}, step=i)

        subsample_ids = self.batch_sampler.next_minibatch_ids(self.trainset, state)
        state.full_program_trace[-1]["subsample_ids"] = subsample_ids
        minibatch = self.trainset.fetch(subsample_ids)

        # 1) Evaluate current program with traces
        # Note: We don't use cache for capture_traces=True evaluations since we need fresh traces for reflection
        eval_curr = self.adapter.evaluate(minibatch, curr_prog, capture_traces=True)
        state.total_num_evals += len(subsample_ids)
        state.full_program_trace[-1]["subsample_scores"] = eval_curr.scores

        # Update cache with current program evaluation results (for future reuse when capture_traces=False)
        if state.evaluation_cache is not None:
            objective_scores_list = list(eval_curr.objective_scores) if eval_curr.objective_scores else None
            state.evaluation_cache.put_batch(
                curr_prog, subsample_ids, eval_curr.outputs, eval_curr.scores, objective_scores_list
            )

        if not eval_curr.trajectories or len(eval_curr.trajectories) == 0:
            self.logger.log(f"Iteration {i}: No trajectories captured. Skipping.")
            return None

        if self.skip_perfect_score and all(s >= self.perfect_score for s in eval_curr.scores):
            self.logger.log(f"Iteration {i}: All subsample scores perfect. Skipping.")
            return None

        self.experiment_tracker.log_metrics({"subsample_score": sum(eval_curr.scores)}, step=i)

        # 2) Decide which predictors to update
        predictor_names_to_update = self.module_selector(
            state, eval_curr.trajectories, eval_curr.scores, curr_prog_id, curr_prog
        )

        # 3) Build reflective dataset and propose texts
        try:
            reflective_dataset = self.adapter.make_reflective_dataset(curr_prog, eval_curr, predictor_names_to_update)
            new_texts = self.propose_new_texts(curr_prog, reflective_dataset, predictor_names_to_update)
            for pname, text in new_texts.items():
                self.logger.log(f"Iteration {i}: Proposed new text for {pname}: {text}")
            self.experiment_tracker.log_metrics(
                {f"new_instruction_{pname}": text for pname, text in new_texts.items()}, step=i
            )
        except Exception as e:
            self.logger.log(f"Iteration {i}: Exception during reflection/proposal: {e}")
            import traceback

            self.logger.log(traceback.format_exc())
            return None

        # 4) Create candidate, evaluate on same minibatch (no need to capture traces)
        new_candidate = curr_prog.copy()
        for pname, text in new_texts.items():
            assert pname in new_candidate, f"{pname} missing in candidate"
            new_candidate[pname] = text

        def evaluator(b, c):
            r = self.adapter.evaluate(b, c, capture_traces=False)
            return r.outputs, r.scores, list(r.objective_scores) if r.objective_scores else None

        new_scores, actual_evals_count = state.cached_evaluate(
            new_candidate, subsample_ids, self.trainset.fetch, evaluator
        )
        state.total_num_evals += actual_evals_count
        state.full_program_trace[-1]["new_subsample_scores"] = new_scores

        new_sum = sum(new_scores)
        self.experiment_tracker.log_metrics({"new_subsample_score": new_sum}, step=i)

        return CandidateProposal(
            candidate=new_candidate,
            parent_program_ids=[curr_prog_id],
            subsample_indices=subsample_ids,
            subsample_scores_before=eval_curr.scores,
            subsample_scores_after=new_scores,
            tag="reflective_mutation",
        )
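
The hunk above (reflective_mutation.py, per the +176 entry in the file list) is the core GEPA proposal loop: select a candidate program, score a minibatch with traces, ask a reflection LM to rewrite the selected components, then re-score the mutated candidate on the same minibatch. The sketch below is a deliberately simplified, standalone rendering of that flow; `evaluate`, `reflect`, and the keyword-matching toy metric are invented stand-ins for the SDK's adapter, reflection LM, and scoring, not mantisdk APIs, and the final improvement check follows the docstring's "return proposal if improved" summary.

```python
from typing import Callable

def reflective_mutation_step(
    candidate: dict[str, str],
    minibatch: list[str],
    evaluate: Callable[[dict[str, str], list[str]], list[float]],
    reflect: Callable[[str, list[tuple[str, float]]], str],
    component: str,
    perfect_score: float = 1.0,
    skip_perfect_score: bool = True,
) -> dict[str, str] | None:
    """One iteration of the flow summarized above, on toy types."""
    scores = evaluate(candidate, minibatch)
    # Skip when every example is already solved: there is no feedback to reflect on.
    if skip_perfect_score and all(s >= perfect_score for s in scores):
        return None
    # "Reflection": rewrite one component's text from the minibatch and its scores.
    mutated = dict(candidate)
    mutated[component] = reflect(candidate[component], list(zip(minibatch, scores)))
    # Re-evaluate the mutated candidate on the same minibatch; keep it only if it improved.
    new_scores = evaluate(mutated, minibatch)
    return mutated if sum(new_scores) > sum(scores) else None

# Toy rollout: the "score" checks whether a keyword appears in the instruction,
# and "reflection" appends the keywords that were missed.
def evaluate(cand: dict[str, str], batch: list[str]) -> list[float]:
    return [1.0 if word in cand["system_prompt"] else 0.0 for word in batch]

def reflect(instruction: str, feedback: list[tuple[str, float]]) -> str:
    missing = [word for word, score in feedback if score < 1.0]
    return instruction + " Always mention: " + ", ".join(missing) + "."

candidate = {"system_prompt": "Answer concisely."}
print(reflective_mutation_step(candidate, ["units", "sources"], evaluate, reflect, "system_prompt"))
```

Running it prints the mutated candidate, since the rewritten prompt now scores higher on the same two toy examples.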

File without changes
File without changes

mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py
@@ -0,0 +1,77 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

import random
from collections import Counter
from typing import Protocol

from mantisdk.algorithm.gepa.lib.core.adapter import DataInst
from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader
from mantisdk.algorithm.gepa.lib.core.state import GEPAState


class BatchSampler(Protocol[DataId, DataInst]):
    def next_minibatch_ids(self, loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]: ...


class EpochShuffledBatchSampler(BatchSampler[DataId, DataInst]):
    """
    Mirrors the original batching logic:
    - Shuffle ids each epoch
    - Pad to minibatch size with least frequent ids
    - Deterministic via state.rng1
    """

    def __init__(self, minibatch_size: int, rng: random.Random | None = None):
        self.minibatch_size = minibatch_size
        self.shuffled_ids: list[DataId] = []
        self.epoch = -1
        self.id_freqs = Counter()
        self.last_trainset_size = 0
        if rng is None:
            self.rng = random.Random(0)
        else:
            self.rng = rng

    def _update_shuffled(self, loader: DataLoader[DataId, DataInst]):
        all_ids = list(loader.all_ids())
        trainset_size = len(loader)
        self.last_trainset_size = trainset_size

        if trainset_size == 0:
            self.shuffled_ids = []
            self.id_freqs = Counter()
            return

        self.shuffled_ids = list(all_ids)
        self.rng.shuffle(self.shuffled_ids)
        self.id_freqs = Counter(self.shuffled_ids)

        mod = trainset_size % self.minibatch_size
        num_to_pad = (self.minibatch_size - mod) if mod != 0 else 0
        if num_to_pad > 0:
            for _ in range(num_to_pad):
                selected_id = self.id_freqs.most_common()[::-1][0][0]
                self.shuffled_ids.append(selected_id)
                self.id_freqs[selected_id] += 1

    def next_minibatch_ids(self, loader: DataLoader[DataId, DataInst], state: GEPAState) -> list[DataId]:
        trainset_size = len(loader)
        if trainset_size == 0:
            raise ValueError("Cannot sample a minibatch from an empty loader.")

        base_idx = state.i * self.minibatch_size
        curr_epoch = 0 if self.epoch == -1 else base_idx // max(len(self.shuffled_ids), 1)

        needs_refresh = not self.shuffled_ids or trainset_size != self.last_trainset_size or curr_epoch > self.epoch
        if needs_refresh:
            self.epoch = curr_epoch
            self._update_shuffled(loader)

        assert len(self.shuffled_ids) >= self.minibatch_size
        assert len(self.shuffled_ids) % self.minibatch_size == 0

        base_idx = base_idx % len(self.shuffled_ids)
        end_idx = base_idx + self.minibatch_size
        assert end_idx <= len(self.shuffled_ids)
        return self.shuffled_ids[base_idx:end_idx]
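
EpochShuffledBatchSampler above shuffles the id list once per epoch and pads it with the least frequent ids until its length is a multiple of the minibatch size, so each iteration index maps onto a clean slice. A standalone sketch of that arithmetic, assuming a plain list of string ids and a made-up `step` counter in place of GEPAState (not the mantisdk class):

```python
import random
from collections import Counter

def epoch_minibatch(ids: list[str], minibatch_size: int, step: int, seed: int = 0) -> list[str]:
    """Minibatch for iteration `step`, mirroring the shuffle-then-pad logic above."""
    rng = random.Random(seed)
    shuffled = list(ids)
    rng.shuffle(shuffled)
    freqs = Counter(shuffled)
    # Pad with the least frequent ids until the list splits evenly into minibatches.
    mod = len(shuffled) % minibatch_size
    for _ in range((minibatch_size - mod) if mod else 0):
        least_common = freqs.most_common()[::-1][0][0]
        shuffled.append(least_common)
        freqs[least_common] += 1
    base = (step * minibatch_size) % len(shuffled)
    return shuffled[base:base + minibatch_size]

ids = ["a", "b", "c", "d", "e"]  # 5 ids with minibatch_size=3 -> padded to 6
print(epoch_minibatch(ids, minibatch_size=3, step=0))
print(epoch_minibatch(ids, minibatch_size=3, step=1))
```

Unlike this toy, the real sampler reshuffles lazily when the epoch advances or the trainset size changes, but the padding and slicing follow the same scheme.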

mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py
@@ -0,0 +1,50 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

import random

from mantisdk.algorithm.gepa.lib.core.state import GEPAState
from mantisdk.algorithm.gepa.lib.gepa_utils import idxmax, select_program_candidate_from_pareto_front
from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import CandidateSelector


class ParetoCandidateSelector(CandidateSelector):
    def __init__(self, rng: random.Random | None):
        if rng is None:
            self.rng = random.Random(0)
        else:
            self.rng = rng

    def select_candidate_idx(self, state: GEPAState) -> int:
        assert len(state.program_full_scores_val_set) == len(state.program_candidates)
        return select_program_candidate_from_pareto_front(
            state.get_pareto_front_mapping(),
            state.per_program_tracked_scores,
            self.rng,
        )


class CurrentBestCandidateSelector(CandidateSelector):
    def __init__(self):
        pass

    def select_candidate_idx(self, state: GEPAState) -> int:
        assert len(state.program_full_scores_val_set) == len(state.program_candidates)
        return idxmax(state.program_full_scores_val_set)


class EpsilonGreedyCandidateSelector(CandidateSelector):
    def __init__(self, epsilon: float, rng: random.Random | None):
        assert 0.0 <= epsilon <= 1.0
        self.epsilon = epsilon
        if rng is None:
            self.rng = random.Random(0)
        else:
            self.rng = rng

    def select_candidate_idx(self, state: GEPAState) -> int:
        assert len(state.program_full_scores_val_set) == len(state.program_candidates)
        if self.rng.random() < self.epsilon:
            return self.rng.randint(0, len(state.program_candidates) - 1)
        else:
            return idxmax(state.program_full_scores_val_set)
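
The three selectors above choose which program candidate to mutate next; EpsilonGreedyCandidateSelector in particular explores a uniformly random candidate with probability `epsilon` and otherwise exploits the best validation score. A self-contained sketch of that rule, with a local `idxmax` standing in for the helper imported from gepa_utils:

```python
import random

def idxmax(scores: list[float]) -> int:
    # First index achieving the maximum score; stand-in for gepa_utils.idxmax.
    return max(range(len(scores)), key=scores.__getitem__)

def epsilon_greedy_select(scores: list[float], epsilon: float, rng: random.Random) -> int:
    # Explore a random candidate with probability epsilon, otherwise exploit the best one.
    if rng.random() < epsilon:
        return rng.randint(0, len(scores) - 1)
    return idxmax(scores)

rng = random.Random(0)
val_scores = [0.42, 0.61, 0.58]
print([epsilon_greedy_select(val_scores, epsilon=0.2, rng=rng) for _ in range(10)])
# Mostly 1 (the current best), with occasional random exploration.
```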

mantisdk/algorithm/gepa/lib/strategies/component_selector.py
@@ -0,0 +1,36 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa


from mantisdk.algorithm.gepa.lib.core.adapter import Trajectory
from mantisdk.algorithm.gepa.lib.core.state import GEPAState
from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import ReflectionComponentSelector


class RoundRobinReflectionComponentSelector(ReflectionComponentSelector):
    def __call__(
        self,
        state: GEPAState,
        trajectories: list[Trajectory],
        subsample_scores: list[float],
        candidate_idx: int,
        candidate: dict[str, str],
    ) -> list[str]:
        pid = state.named_predictor_id_to_update_next_for_program_candidate[candidate_idx]
        state.named_predictor_id_to_update_next_for_program_candidate[candidate_idx] = (pid + 1) % len(
            state.list_of_named_predictors
        )
        name = state.list_of_named_predictors[pid]
        return [name]


class AllReflectionComponentSelector(ReflectionComponentSelector):
    def __call__(
        self,
        state: GEPAState,
        trajectories: list[Trajectory],
        subsample_scores: list[float],
        candidate_idx: int,
        candidate: dict[str, str],
    ) -> list[str]:
        return list(candidate.keys())
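
RoundRobinReflectionComponentSelector keeps one cursor per program candidate into `state.list_of_named_predictors` and advances it each call, so successive iterations rotate through components rather than always mutating the same one. A tiny standalone illustration of that cursor update (the predictor names and the bare dict are invented):

```python
# Rotate through component names, one cursor per program candidate.
predictors = ["planner", "solver", "critic"]
next_pid = {0: 0}  # candidate index -> which predictor to update next

def next_component(candidate_idx: int) -> str:
    pid = next_pid[candidate_idx]
    next_pid[candidate_idx] = (pid + 1) % len(predictors)
    return predictors[pid]

print([next_component(0) for _ in range(5)])
# ['planner', 'solver', 'critic', 'planner', 'solver']
```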

mantisdk/algorithm/gepa/lib/strategies/eval_policy.py
@@ -0,0 +1,64 @@
"""Validation evaluation policy protocols and helpers."""

from __future__ import annotations

from abc import abstractmethod
from typing import Protocol, runtime_checkable

from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataInst, DataLoader
from mantisdk.algorithm.gepa.lib.core.state import GEPAState, ProgramIdx


@runtime_checkable
class EvaluationPolicy(Protocol[DataId, DataInst]):  # type: ignore
    """Strategy for choosing validation ids to evaluate and identifying best programs for validation instances."""

    @abstractmethod
    def get_eval_batch(
        self, loader: DataLoader[DataId, DataInst], state: GEPAState, target_program_idx: ProgramIdx | None = None
    ) -> list[DataId]:
        """Select examples for evaluation for a program"""
        ...

    @abstractmethod
    def get_best_program(self, state: GEPAState) -> ProgramIdx:
        """Return "best" program given all validation results so far across candidates"""
        ...

    @abstractmethod
    def get_valset_score(self, program_idx: ProgramIdx, state: GEPAState) -> float:
        """Return the score of the program on the valset"""
        ...


class FullEvaluationPolicy(EvaluationPolicy[DataId, DataInst]):
    """Policy that evaluates all validation instances every time."""

    def get_eval_batch(
        self, loader: DataLoader[DataId, DataInst], state: GEPAState, target_program_idx: ProgramIdx | None = None
    ) -> list[DataId]:
        """Always return the full ordered list of validation ids."""
        return list(loader.all_ids())

    def get_best_program(self, state: GEPAState) -> ProgramIdx:
        """Pick the program whose evaluated validation scores achieve the highest average."""
        best_idx, best_score, best_coverage = -1, float("-inf"), -1
        for program_idx, scores in enumerate(state.prog_candidate_val_subscores):
            coverage = len(scores)
            avg = sum(scores.values()) / coverage if coverage else float("-inf")
            if avg > best_score or (avg == best_score and coverage > best_coverage):
                best_score = avg
                best_idx = program_idx
                best_coverage = coverage
        return best_idx

    def get_valset_score(self, program_idx: ProgramIdx, state: GEPAState) -> float:
        """Return the score of the program on the valset"""
        return state.get_program_average_val_subset(program_idx)[0]


__all__ = [
    "DataLoader",
    "EvaluationPolicy",
    "FullEvaluationPolicy",
]
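
FullEvaluationPolicy.get_best_program above averages whatever validation subscores have been recorded for each candidate and, when averages tie, prefers the candidate evaluated on more instances. A standalone sketch of that selection over plain dicts (the example score maps are invented):

```python
def best_program(per_program_scores: list[dict[str, float]]) -> int:
    """Highest average wins; ties go to the program evaluated on more instances."""
    best_idx, best_score, best_coverage = -1, float("-inf"), -1
    for idx, scores in enumerate(per_program_scores):
        coverage = len(scores)
        avg = sum(scores.values()) / coverage if coverage else float("-inf")
        if avg > best_score or (avg == best_score and coverage > best_coverage):
            best_idx, best_score, best_coverage = idx, avg, coverage
    return best_idx

subscores = [
    {"ex1": 1.0, "ex2": 0.0},              # average 0.5 over 2 instances
    {"ex1": 0.5, "ex2": 0.5, "ex3": 0.5},  # average 0.5 over 3 instances -> wins the tie
]
print(best_program(subscores))  # 1
```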

mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py
@@ -0,0 +1,127 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

import re
from collections.abc import Mapping, Sequence
from typing import Any, ClassVar

from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import Signature


class InstructionProposalSignature(Signature):
    default_prompt_template = """I provided an assistant with the following instructions to perform a task for me:
```
<curr_instructions>
```

The following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better:
```
<inputs_outputs_feedback>
```

Your task is to write a new instruction for the assistant.

Read the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant.

Read all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well.

Provide the new instructions within ``` blocks."""

    input_keys: ClassVar[list[str]] = ["current_instruction_doc", "dataset_with_feedback", "prompt_template"]
    output_keys: ClassVar[list[str]] = ["new_instruction"]

    @classmethod
    def validate_prompt_template(cls, prompt_template: str | None) -> None:
        if prompt_template is None:
            return
        missing_placeholders = [
            placeholder
            for placeholder in ("<curr_instructions>", "<inputs_outputs_feedback>")
            if placeholder not in prompt_template
        ]
        if missing_placeholders:
            raise ValueError(
                f"Missing placeholder(s) in prompt template: {', '.join(missing_placeholders)}"
            )

    @classmethod
    def prompt_renderer(cls, input_dict: Mapping[str, Any]) -> str:
        current_instruction = input_dict.get("current_instruction_doc")
        if not isinstance(current_instruction, str):
            raise TypeError("current_instruction_doc must be a string")

        dataset = input_dict.get("dataset_with_feedback")
        if not isinstance(dataset, Sequence) or isinstance(dataset, (str, bytes)):
            raise TypeError("dataset_with_feedback must be a sequence of records")
        def format_samples(samples):
            def render_value(value, level=3):
                # level controls markdown header depth (###, ####, etc.)
                if isinstance(value, dict):
                    s = ""
                    for k, v in value.items():
                        s += f"{'#' * level} {k}\n"
                        s += render_value(v, min(level + 1, 6))
                    if not value:
                        s += "\n"
                    return s
                elif isinstance(value, list | tuple):
                    s = ""
                    for i, item in enumerate(value):
                        s += f"{'#' * level} Item {i + 1}\n"
                        s += render_value(item, min(level + 1, 6))
                    if not value:
                        s += "\n"
                    return s
                else:
                    return f"{str(value).strip()}\n\n"

            def convert_sample_to_markdown(sample, examplenum):
                s = f"# Example {examplenum}\n"
                for key, val in sample.items():
                    s += f"## {key}\n"
                    s += render_value(val, level=3)
                return s

            return "\n\n".join(convert_sample_to_markdown(sample, i + 1) for i, sample in enumerate(samples))

        prompt_template = input_dict.get("prompt_template")
        if prompt_template is None:
            prompt_template = cls.default_prompt_template

        cls.validate_prompt_template(prompt_template)

        prompt = prompt_template.replace("<curr_instructions>", current_instruction)
        prompt = prompt.replace("<inputs_outputs_feedback>", format_samples(dataset))

        return prompt

    @classmethod
    def output_extractor(cls, lm_out: str) -> dict[str, str]:
        def extract_instruction_text() -> str:
            # Find the first and last backtick positions (if any)
            start = lm_out.find("```") + 3
            end = lm_out.rfind("```")

            # Handle if the first and last backticks are the same or overlap
            if start >= end:
                # Handle incomplete blocks
                stripped = lm_out.strip()
                if stripped.startswith("```"):
                    # Remove opening ``` and optional language specifier
                    match = re.match(r"^```\S*\n?", lm_out)
                    if match:
                        return lm_out[match.end() :].strip()
                elif stripped.endswith("```"):
                    # Remove closing ```
                    return stripped[:-3].strip()
                return stripped

            # Skip optional language specifier
            content = lm_out[start:end]
            match = re.match(r"^\S*\n", content)
            if match:
                content = content[match.end() :]

            return content.strip()

        return {"new_instruction": extract_instruction_text()}
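
output_extractor above recovers the proposed instruction from the reflection model's reply by taking the text between the first and last triple-backtick fences, tolerating a language specifier and unterminated blocks. A standalone re-implementation of that idea (not an import from mantisdk), exercised on two invented model outputs:

```python
import re

def extract_fenced(lm_out: str) -> str:
    # Take the text between the first and last fences; tolerate a language tag or a missing fence.
    start = lm_out.find("```") + 3
    end = lm_out.rfind("```")
    if start >= end:  # zero fences, or a single/unterminated one
        stripped = lm_out.strip()
        if stripped.startswith("```"):
            stripped = re.sub(r"^```\S*\n?", "", stripped)
        elif stripped.endswith("```"):
            stripped = stripped[:-3]
        return stripped.strip()
    content = lm_out[start:end]
    content = re.sub(r"^\S*\n", "", content)  # drop an optional language specifier line
    return content.strip()

print(extract_fenced("Here is the new instruction:\n```text\nAlways cite your sources.\n```\nHope it helps."))
print(extract_fenced("```\nAnswer step by step."))  # unterminated block still yields its body
```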