mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py — new file (@@ -0,0 +1,209 @@):

```python
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

from collections.abc import Mapping, Sequence
from typing import Any, NamedTuple, Protocol, TypedDict, cast

from mantisdk.algorithm.gepa.lib.core.adapter import EvaluationBatch, GEPAAdapter


# DataInst, Trajectory, RolloutOutput
class DefaultDataInst(TypedDict):
    input: str
    additional_context: dict[str, str]
    answer: str


class EvaluationResult(NamedTuple):
    score: float
    feedback: str
    objective_scores: dict[str, float] | None = None


class DefaultTrajectory(TypedDict):
    data: DefaultDataInst
    full_assistant_response: str
    feedback: str


class DefaultRolloutOutput(TypedDict):
    full_assistant_response: str


DefaultReflectiveRecord = TypedDict(
    "DefaultReflectiveRecord",
    {
        "Inputs": str,
        "Generated Outputs": str,
        "Feedback": str,
    },
)


class ChatMessage(TypedDict):
    role: str
    content: str


class ChatCompletionCallable(Protocol):
    """Protocol for chat completion callables (duck typing for custom model wrappers)."""

    def __call__(self, messages: Sequence[ChatMessage]) -> str: ...


# Callable that evaluates a response and returns (score, feedback, optional objective_scores)
class Evaluator(Protocol):
    def __call__(self, data: DefaultDataInst, response: str) -> EvaluationResult:
        """
        Evaluates a response and returns a score, feedback, and optional objective scores.
        """
        ...


class ContainsAnswerEvaluator:
    """Default evaluator that checks if the expected answer is contained in the response."""

    def __init__(self, failure_score: float = 0.0):
        self.failure_score = failure_score

    def __call__(self, data: DefaultDataInst, response: str) -> EvaluationResult:
        is_correct = data["answer"] in response
        score = 1.0 if is_correct else self.failure_score

        if is_correct:
            feedback = f"The generated response is correct. The response include the correct answer '{data['answer']}'"
        else:
            additional_context_str = "\n".join(f"{k}: {v}" for k, v in data["additional_context"].items())
            feedback = (
                f"The generated response is incorrect. The correct answer is '{data['answer']}'. "
                "Ensure that the correct answer is included in the response exactly as it is."
            )
            if additional_context_str:
                feedback += f" Here is some additional context that might be helpful:\n{additional_context_str}"

        return EvaluationResult(score=score, feedback=feedback, objective_scores=None)


class DefaultAdapter(GEPAAdapter[DefaultDataInst, DefaultTrajectory, DefaultRolloutOutput]):
    def __init__(
        self,
        model: str | ChatCompletionCallable,
        evaluator: Evaluator | None = None,
        max_litellm_workers: int = 10,
        litellm_batch_completion_kwargs: dict[str, Any] | None = None,
    ):
        if isinstance(model, str):
            import litellm

            self.litellm = litellm
        self.model = model
        self.evaluator = evaluator or ContainsAnswerEvaluator()
        self.max_litellm_workers = max_litellm_workers
        self.litellm_batch_completion_kwargs = litellm_batch_completion_kwargs or {}

    def evaluate(
        self,
        batch: list[DefaultDataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch[DefaultTrajectory, DefaultRolloutOutput]:
        outputs: list[DefaultRolloutOutput] = []
        scores: list[float] = []
        objective_scores: list[dict[str, float] | None] = []
        trajectories: list[DefaultTrajectory] | None = [] if capture_traces else None

        system_content = next(iter(candidate.values()))

        litellm_requests = []

        for data in batch:
            user_content = f"{data['input']}"

            messages: list[ChatMessage] = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content},
            ]

            litellm_requests.append(messages)

        if isinstance(self.model, str):
            responses = [
                resp.choices[0].message.content.strip()
                for resp in self.litellm.batch_completion(
                    model=self.model,
                    messages=litellm_requests,
                    max_workers=self.max_litellm_workers,
                    **self.litellm_batch_completion_kwargs,
                )
            ]
        else:
            responses = [self.model(messages) for messages in litellm_requests]

        for data, assistant_response in zip(batch, responses, strict=True):
            eval_result = self.evaluator(data, assistant_response)
            score = eval_result.score
            feedback = eval_result.feedback
            obj_scores = eval_result.objective_scores

            output: DefaultRolloutOutput = {"full_assistant_response": assistant_response}

            outputs.append(output)
            scores.append(score)
            objective_scores.append(obj_scores)

            if trajectories is not None:
                trajectories.append(
                    {
                        "data": data,
                        "full_assistant_response": assistant_response,
                        "feedback": feedback,
                    }
                )

        objective_scores_arg: list[dict[str, float]] | None = None
        if objective_scores:
            all_none = all(x is None for x in objective_scores)
            all_not_none = all(x is not None for x in objective_scores)
            if not (all_none or all_not_none):
                raise ValueError("Objective scores must either be all None or all not None.")
            if all_not_none:
                objective_scores_arg = cast(list[dict[str, float]], objective_scores)

        return EvaluationBatch(
            outputs=outputs,
            scores=scores,
            trajectories=trajectories,
            objective_scores=objective_scores_arg,
        )

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[DefaultTrajectory, DefaultRolloutOutput],
        components_to_update: list[str],
    ) -> Mapping[str, Sequence[Mapping[str, Any]]]:
        ret_d: dict[str, list[DefaultReflectiveRecord]] = {}

        assert len(components_to_update) == 1
        comp = components_to_update[0]

        trajectories = eval_batch.trajectories
        assert trajectories is not None, "Trajectories are required to build a reflective dataset."

        items: list[DefaultReflectiveRecord] = []

        for traj in trajectories:
            d: DefaultReflectiveRecord = {
                "Inputs": traj["data"]["input"],
                "Generated Outputs": traj["full_assistant_response"],
                "Feedback": traj["feedback"],
            }

            items.append(d)

        ret_d[comp] = items

        if len(items) == 0:
            raise Exception("No valid predictions found for any module.")

        return ret_d
```
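To make the adapter's contract concrete, here is a minimal usage sketch (not part of the package diff): it plugs a stub chat-completion callable into `DefaultAdapter`, so no LLM is actually called, and evaluates one candidate system prompt. The import path is assumed from the wheel layout above; the stub model and toy data are hypothetical.

```python
# Minimal sketch; import path assumed from the wheel layout above, stub model
# and data are hypothetical.
from mantisdk.algorithm.gepa.lib.adapters.default_adapter.default_adapter import (
    DefaultAdapter,
    DefaultDataInst,
)


def stub_model(messages) -> str:
    # Stands in for a ChatCompletionCallable: returns a canned answer instead of calling an LLM.
    return "The answer is 4."


batch: list[DefaultDataInst] = [
    {"input": "What is 2 + 2?", "additional_context": {"hint": "basic arithmetic"}, "answer": "4"},
]

adapter = DefaultAdapter(model=stub_model)  # no evaluator given, so ContainsAnswerEvaluator is used
result = adapter.evaluate(
    batch,
    candidate={"system_prompt": "You are a careful math assistant."},
    capture_traces=True,
)
print(result.scores)  # [1.0] here, since "4" appears in the stub response
```

Because `evaluate` only reads the first value in `candidate`, the key name is arbitrary; GEPA treats that single text component as the system prompt to evolve.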
mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md — new file (@@ -0,0 +1,7 @@):

# DSPy <> GEPA

[dspy_adapter.py](dspy_adapter.py) provides an example adapter that lets GEPA optimize the signature instructions of any DSPy program. The most up-to-date version of this adapter lives in the DSPy repository at [gepa_utils.py](https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/gepa/gepa_utils.py).

> The best way to use this adapter is from within DSPy itself.

Extensive tutorials on using GEPA with DSPy are available at [dspy.GEPA tutorials](https://dspy.ai/tutorials/gepa_ai_program/).

mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py — file without changes.
mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py — new file (@@ -0,0 +1,307 @@):

```python
"""
This file provides an example adapter allowing GEPA to optimize text components of DSPy programs (instructions and prompts).
The most up-to-date version of this file is in the DSPy repository: https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/gepa/gepa_utils.py
"""

import logging
import random
from typing import Any, Callable, Protocol

from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.types import History
from dspy.evaluate import Evaluate
from dspy.primitives import Example, Prediction
from dspy.teleprompt.bootstrap_trace import TraceData

from mantisdk.algorithm.gepa.lib.core.adapter import EvaluationBatch, GEPAAdapter


class LoggerAdapter:
    def __init__(self, logger: logging.Logger):
        self.logger = logger

    def log(self, x: str):
        self.logger.info(x)


DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]]


class ScoreWithFeedback(Prediction):
    score: float
    feedback: str | None = None
    subscores: dict[str, float] | None = None


class PredictorFeedbackFn(Protocol):
    def __call__(
        self,
        predictor_output: dict[str, Any],
        predictor_inputs: dict[str, Any],
        module_inputs: Example,
        module_outputs: Prediction,
        captured_trace: DSPyTrace,
    ) -> ScoreWithFeedback:
        """
        This function is used to provide feedback to a specific predictor.
        The function is called with the following arguments:
        - predictor_output: The output of the predictor.
        - predictor_inputs: The inputs to the predictor.
        - module_inputs: The inputs to the whole program --- `Example`.
        - module_outputs: The outputs of the whole program --- `Prediction`.
        - captured_trace: The trace of the module's execution.
        # Shape of trace is: [predictor_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]]
        # Each trace is a tuple of (Predictor, PredictorInputs, Prediction)

        The function should return a `ScoreWithFeedback` object.
        The feedback is a string that is used to guide the evolution of the predictor.
        """
        ...


class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]):
    def __init__(
        self,
        student_module,
        metric_fn: Callable,
        feedback_map: dict[str, Callable],
        failure_score=0.0,
        num_threads: int | None = None,
        add_format_failure_as_feedback: bool = False,
        rng: random.Random | None = None,
    ):
        self.student = student_module
        self.metric_fn = metric_fn
        self.feedback_map = feedback_map
        self.failure_score = failure_score
        self.num_threads = num_threads
        self.add_format_failure_as_feedback = add_format_failure_as_feedback
        self.rng = rng or random.Random(0)

        # Cache predictor names/signatures
        self.named_predictors = list(self.student.named_predictors())

    def build_program(self, candidate: dict[str, str]):
        new_prog = self.student.deepcopy()
        for name, pred in new_prog.named_predictors():
            if name in candidate:
                pred.signature = pred.signature.with_instructions(candidate[name])
        return new_prog

    def evaluate(self, batch, candidate, capture_traces=False):
        program = self.build_program(candidate)

        outputs: list[Prediction] = []
        scores: list[float] = []
        subscores: list[dict[str, float]] = []
        trajs: list[TraceData] | None = None

        if capture_traces:
            # bootstrap_trace_data-like flow with trace capture
            from dspy.teleprompt.bootstrap_trace import bootstrap_trace_data

            trajs = bootstrap_trace_data(
                program=program,
                dataset=batch,
                metric=self.metric_fn,
                num_threads=self.num_threads,
                raise_on_error=False,
                capture_failed_parses=True,
                failure_score=self.failure_score,
                format_failure_score=self.failure_score,
            )
            for t in trajs:
                outputs.append(t["prediction"])
                score_val, subscore_dict = self._extract_score_and_subscores(t.get("score"))
                if score_val is None:
                    score_val = self.failure_score
                scores.append(score_val)
                subscores.append(subscore_dict)
        else:
            evaluator = Evaluate(
                devset=batch,
                metric=self.metric_fn,
                num_threads=self.num_threads,
                return_all_scores=True,
                return_outputs=True,
                failure_score=self.failure_score,
                provide_traceback=True,
                max_errors=len(batch) * 100,
            )
            res = evaluator(program)
            outputs = [r[1] for r in res.results]
            raw_scores = [r[2] for r in res.results]
            for raw_score in raw_scores:
                score_val, subscore_dict = self._extract_score_and_subscores(raw_score)
                if score_val is None:
                    score_val = self.failure_score
                scores.append(score_val)
                subscores.append(subscore_dict)

        has_subscores = any(subscores)
        # Map DSPy "subscores" into GEPA objective score payloads.
        return EvaluationBatch(
            outputs=outputs,
            scores=scores,
            trajectories=trajs,
            objective_scores=subscores if has_subscores else None,
        )

    def make_reflective_dataset(self, candidate, eval_batch, components_to_update):
        from dspy.teleprompt.bootstrap_trace import FailedPrediction

        program = self.build_program(candidate)

        ret_d: dict[str, list[dict[str, Any]]] = {}
        for pred_name in components_to_update:
            module = None
            for name, m in program.named_predictors():
                if name == pred_name:
                    module = m
                    break
            assert module is not None

            items: list[dict[str, Any]] = []
            for data in eval_batch.trajectories or []:
                trace = data["trace"]
                example = data["example"]
                prediction = data["prediction"]
                module_score_obj = data.get("score")
                module_score, _ = self._extract_score_and_subscores(module_score_obj)

                trace_instances = [t for t in trace if t[0].signature.equals(module.signature)]
                if not self.add_format_failure_as_feedback:
                    trace_instances = [t for t in trace_instances if not isinstance(t[2], FailedPrediction)]
                if len(trace_instances) == 0:
                    continue

                selected = None
                for t in trace_instances:
                    if isinstance(t[2], FailedPrediction):
                        selected = t
                        break

                if selected is None:
                    if isinstance(prediction, FailedPrediction):
                        continue
                    selected = self.rng.choice(trace_instances)

                inputs = selected[1]
                outputs = selected[2]

                new_inputs = {}
                new_outputs = {}

                contains_history = False
                history_key_name = None
                for input_key, input_val in inputs.items():
                    if isinstance(input_val, History):
                        contains_history = True
                        assert history_key_name is None
                        history_key_name = input_key

                if contains_history:
                    s = "```json\n"
                    for i, message in enumerate(inputs[history_key_name].messages):
                        s += f" {i}: {message}\n"
                    s += "```"
                    new_inputs["Context"] = s

                for input_key, input_val in inputs.items():
                    if contains_history and input_key == history_key_name:
                        continue
                    new_inputs[input_key] = str(input_val)

                if isinstance(outputs, FailedPrediction):
                    s = "Couldn't parse the output as per the expected output format. The model's raw response was:\n"
                    s += "```\n"
                    s += outputs.completion_text + "\n"
                    s += "```\n\n"
                    new_outputs = s
                else:
                    for output_key, output_val in outputs.items():
                        new_outputs[output_key] = str(output_val)

                d = {"Inputs": new_inputs, "Generated Outputs": new_outputs}
                if isinstance(outputs, FailedPrediction):
                    adapter = ChatAdapter()
                    structure_instruction = ""
                    for dd in adapter.format(module.signature, [], {}):
                        structure_instruction += dd["role"] + ": " + dd["content"] + "\n"
                    d["Feedback"] = "Your output failed to parse. Follow this structure:\n" + structure_instruction
                    # d['score'] = self.failure_score
                else:
                    feedback_fn = self.feedback_map[pred_name]
                    fb = feedback_fn(
                        predictor_output=outputs,
                        predictor_inputs=inputs,
                        module_inputs=example,
                        module_outputs=prediction,
                        captured_trace=trace,
                    )
                    if isinstance(fb, dict):
                        feedback_score = fb.get("score")
                        feedback_text = fb.get("feedback", "")
                    else:
                        feedback_score = getattr(fb, "score", None)
                        feedback_text = getattr(fb, "feedback", "")
                    d["Feedback"] = feedback_text
                    if module_score is not None and feedback_score is not None:
                        assert abs(feedback_score - module_score) < 1e-8, (
                            "Currently, GEPA only supports feedback functions that return the same score as the module's score. "
                            f"However, the module-level score is {module_score} and the feedback score is {feedback_score}."
                        )
                items.append(d)

            if len(items) == 0:
                # raise Exception(f"No valid predictions found for module {module.signature}.")
                continue
            ret_d[pred_name] = items

        if len(ret_d) == 0:
            raise Exception("No valid predictions found for any module.")

        return ret_d

    @staticmethod
    def _extract_score_and_subscores(score_obj: Any) -> tuple[float | None, dict[str, float]]:
        if score_obj is None:
            return None, {}
        if isinstance(score_obj, dict):
            score_val = score_obj.get("score")
            subscores = score_obj.get("subscores") or {}
            return score_val, dict(subscores)
        if hasattr(score_obj, "score"):
            score_val = getattr(score_obj, "score", None)
            subscores = getattr(score_obj, "subscores", None) or {}
            return score_val, dict(subscores)
        try:
            return float(score_obj), {}
        except (TypeError, ValueError):
            return None, {}

    # TODO: The current DSPyAdapter implementation uses the GEPA default propose_new_texts.
    # We can potentially override this, to use the instruction proposal similar to MIPROv2.

    # def propose_new_texts(
    #     self,
    #     candidate: Dict[str, str],
    #     reflective_dataset: Dict[str, List[Dict[str, Any]]],
    #     components_to_update: List[str]
    # ) -> Dict[str, str]:
    #     if self.adapter.propose_new_texts is not None:
    #         return self.adapter.propose_new_texts(candidate, reflective_dataset, components_to_update)

    #     from .instruction_proposal import InstructionProposalSignature
    #     new_texts: Dict[str, str] = {}
    #     for name in components_to_update:
    #         base_instruction = candidate[name]
    #         dataset_with_feedback = reflective_dataset[name]
    #         new_texts[name] = InstructionProposalSignature.run(
    #             lm=self.reflection_lm,
    #             input_dict={
    #                 "current_instruction_doc": base_instruction,
    #                 "dataset_with_feedback": dataset_with_feedback
    #             }
    #         )['new_instruction']
    #     return new_texts
```
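For orientation, a rough sketch of how this DspyAdapter might be wired up (not from the package): a single-predictor DSPy program, an exact-match metric, and a feedback function matching the `PredictorFeedbackFn` protocol above. The program, metric, and data are placeholders, and running it requires DSPy with a configured LM.

```python
# Hedged sketch; placeholder program, metric, and data. Requires dspy and a
# configured LM (e.g. dspy.configure(lm=...)).
import dspy

from mantisdk.algorithm.gepa.lib.adapters.dspy_adapter.dspy_adapter import DspyAdapter, ScoreWithFeedback

program = dspy.ChainOfThought("question -> answer")


def metric_fn(example, pred, trace=None):
    return float(example.answer.strip() == pred.answer.strip())


def answer_feedback(predictor_output, predictor_inputs, module_inputs, module_outputs, captured_trace):
    # Matches PredictorFeedbackFn: score the whole-program output and explain the result.
    score = metric_fn(module_inputs, module_outputs)
    feedback = "Correct." if score == 1.0 else f"Expected '{module_inputs.answer}', got '{module_outputs.answer}'."
    return ScoreWithFeedback(score=score, feedback=feedback)


# One feedback function per named predictor (here just the single ChainOfThought predictor).
feedback_map = {name: answer_feedback for name, _ in program.named_predictors()}
adapter = DspyAdapter(student_module=program, metric_fn=metric_fn, feedback_map=feedback_map)

batch = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]
candidate = {name: "Answer the question with just the final number." for name, _ in program.named_predictors()}
eval_batch = adapter.evaluate(batch, candidate, capture_traces=True)
reflective = adapter.make_reflective_dataset(candidate, eval_batch, list(candidate.keys()))
```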
mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md — new file (@@ -0,0 +1,99 @@):

# DSPy Full Program Adapter

This adapter lets GEPA evolve entire DSPy programs—including signatures, modules, and control flow—not just prompts or instructions.

## Usage

First, install DSPy (version 3.0.2 or higher) with: `pip install 'dspy>=3.0.2'`.

Now, let's use GEPA to generate a DSPy program that solves the MATH benchmark. We start with a very simple DSPy program, `dspy.ChainOfThought("question -> answer")`:

```python
from gepa import optimize
from gepa.adapters.dspy_full_program_adapter.full_program_adapter import DspyAdapter
import dspy

# Standard DSPy metric function
def metric_fn(example, pred, trace=None):
    ...

# Start with a basic program. This code block must export a `program` that shows how the task should be performed
seed_program = """import dspy
program = dspy.ChainOfThought("question -> answer")"""

# Run optimization
reflection_lm = dspy.LM(model="openai/gpt-4.1", max_tokens=32000)  # <-- This LM will only be used to propose new DSPy programs
result = optimize(
    seed_candidate={"program": seed_program},
    trainset=train_data,
    valset=val_data,
    adapter=DspyAdapter(
        task_lm=dspy.LM(model="openai/gpt-4.1-nano", max_tokens=32000),  # <-- This LM will be used for the downstream task
        metric_fn=metric_fn,
        reflection_lm=lambda x: reflection_lm(x)[0],
    ),
    max_metric_calls=2000,
)

# Get the evolved program
optimized_program_code = result.best_candidate["program"]
print(optimized_program_code)
```

Using dspy.ChainOfThought with GPT-4.1 Nano achieves a score of **67%**, while the following GEPA-optimized program boosts performance to **93%**!

```python
import dspy
from typing import Optional

class MathQAReasoningSignature(dspy.Signature):
    """
    Solve the given math word problem step by step, showing all necessary reasoning and calculations.
    - First, provide a clear, detailed, and logically ordered reasoning chain, using equations and algebraic steps as needed.
    - Then, extract the final answer in the required format, strictly following these rules:
        * If the answer should be a number, output only the number (no units, unless explicitly requested).
        * If the answer should be an algebraic expression, output it in LaTeX math mode (e.g., \frac{h^2}{m}).
        * Do not include explanatory text, units, or extra formatting in the answer field unless the question explicitly requests it.
    Common pitfalls:
    - Including units when not required.
    - Restating the answer with extra words or formatting.
    - Failing to simplify expressions or extract the final answer.
    Edge cases:
    - If the answer is a sum or list, output only the final value(s) as required.
    - If the answer is an expression, ensure it is fully simplified.
    Successful strategies:
    - Use step-by-step algebraic manipulation.
    - Double-check the final answer for correct format and content.
    """
    question: str = dspy.InputField(desc="A math word problem to solve.")
    reasoning: str = dspy.OutputField(desc="Step-by-step solution, with equations and logic.")
    answer: str = dspy.OutputField(desc="Final answer, strictly in the required format (see instructions).")

class MathQAExtractSignature(dspy.Signature):
    """
    Given a math word problem and a detailed step-by-step solution, extract ONLY the final answer in the required format.
    - If the answer should be a number, output only the number (no units, unless explicitly requested).
    - If the answer should be an algebraic expression, output it in LaTeX math mode (e.g., \frac{h^2}{m}).
    - Do not include explanatory text, units, or extra formatting in the answer field unless the question explicitly requests it.
    - If the answer is a sum or list, output only the final value(s) as required.
    """
    question: str = dspy.InputField(desc="The original math word problem.")
    reasoning: str = dspy.InputField(desc="A detailed, step-by-step solution to the problem.")
    answer: str = dspy.OutputField(desc="Final answer, strictly in the required format.")

class MathQAModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.reasoner = dspy.ChainOfThought(MathQAReasoningSignature)
        self.extractor = dspy.Predict(MathQAExtractSignature)

    def forward(self, question: str):
        reasoning_pred = self.reasoner(question=question)
        extract_pred = self.extractor(question=question, reasoning=reasoning_pred.reasoning)
        return dspy.Prediction(
            reasoning=reasoning_pred.reasoning,
            answer=extract_pred.answer
        )

program = MathQAModule()
```
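Two details the snippet above leaves open are filled in here only as illustrative sketches, not as the adapter's own code. First, `metric_fn` is shown as a stub; a naive exact-match version (not the benchmark's official checker, which normally needs LaTeX-aware equivalence) might look like:

```python
# Illustrative only: exact match after light normalization; real MATH grading
# usually requires LaTeX-aware answer equivalence.
def metric_fn(example, pred, trace=None):
    def normalize(ans: str) -> str:
        return ans.strip().strip("$").replace(" ", "")

    return 1.0 if normalize(example.answer) == normalize(pred.answer) else 0.0
```

Second, since every candidate is a self-contained code block that exports a `program` object, the evolved source can be executed directly. This is hypothetical usage and assumes an LM has already been configured via `dspy.configure`:

```python
# Hypothetical: run the evolved candidate by executing its source and calling
# the exported `program`.
namespace: dict = {}
exec(optimized_program_code, namespace)
evolved_program = namespace["program"]

prediction = evolved_program(question="If 3x + 5 = 20, what is x?")
print(prediction.answer)  # the model's final answer string
```
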
A fully executable notebook to run this example is in [src/gepa/examples/dspy_full_program_evolution/example.ipynb](../../examples/dspy_full_program_evolution/example.ipynb).