mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0

mantisdk/algorithm/gepa/lib/logging/utils.py
@@ -0,0 +1,103 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa


from mantisdk.algorithm.gepa.lib.core.adapter import DataInst
from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
from mantisdk.algorithm.gepa.lib.core.state import GEPAState, ValsetEvaluation
from mantisdk.algorithm.gepa.lib.strategies.eval_policy import EvaluationPolicy


def log_detailed_metrics_after_discovering_new_program(
    logger,
    gepa_state: GEPAState,
    new_program_idx,
    valset_evaluation: ValsetEvaluation,
    objective_scores,
    experiment_tracker,
    linear_pareto_front_program_idx,
    valset_size: int,
    val_evaluation_policy: EvaluationPolicy[DataId, DataInst],
    log_individual_valset_scores_and_programs: bool = False,
):
    # best_prog_per_agg_val_score = idxmax(gepa_state.program_full_scores_val_set)
    best_prog_per_agg_val_score = val_evaluation_policy.get_best_program(gepa_state)
    best_score_on_valset = val_evaluation_policy.get_valset_score(best_prog_per_agg_val_score, gepa_state)

    # avg, coverage = gepa_state.get_program_average_val_subset(new_program_idx)
    valset_score = val_evaluation_policy.get_valset_score(new_program_idx, gepa_state)
    valset_scores = valset_evaluation.scores_by_val_id
    coverage = len(valset_scores)
    logger.log(
        f"Iteration {gepa_state.i + 1}: Valset score for new program: {valset_score}"
        f" (coverage {coverage} / {valset_size})"
    )

    agg_valset_score_new_program = val_evaluation_policy.get_valset_score(new_program_idx, gepa_state)

    logger.log(f"Iteration {gepa_state.i + 1}: Val aggregate for new program: {agg_valset_score_new_program}")
    logger.log(f"Iteration {gepa_state.i + 1}: Individual valset scores for new program: {valset_scores}")
    if objective_scores:
        logger.log(f"Iteration {gepa_state.i + 1}: Objective aggregate scores for new program: {objective_scores}")
    logger.log(f"Iteration {gepa_state.i + 1}: New valset pareto front scores: {gepa_state.pareto_front_valset}")
    if gepa_state.objective_pareto_front:
        logger.log(f"Iteration {gepa_state.i + 1}: Objective pareto front scores: {gepa_state.objective_pareto_front}")

    pareto_scores = list(gepa_state.pareto_front_valset.values())
    assert all(score > float("-inf") for score in pareto_scores), (
        "Should have at least one valid score per validation example"
    )
    assert len(pareto_scores) > 0
    pareto_avg = sum(pareto_scores) / len(pareto_scores)

    logger.log(f"Iteration {gepa_state.i + 1}: Valset pareto front aggregate score: {pareto_avg}")
    logger.log(
        f"Iteration {gepa_state.i + 1}: Updated valset pareto front programs: {gepa_state.program_at_pareto_front_valset}"
    )
    if gepa_state.program_at_pareto_front_objectives:
        logger.log(
            f"Iteration {gepa_state.i + 1}: Updated objective pareto front programs: {gepa_state.program_at_pareto_front_objectives}"
        )
    logger.log(
        f"Iteration {gepa_state.i + 1}: Best valset aggregate score so far: {max(gepa_state.program_full_scores_val_set)}"
    )
    logger.log(
        f"Iteration {gepa_state.i + 1}: Best program as per aggregate score on valset: {best_prog_per_agg_val_score}"
    )
    logger.log(f"Iteration {gepa_state.i + 1}: Best score on valset: {best_score_on_valset}")
    logger.log(f"Iteration {gepa_state.i + 1}: Linear pareto front program index: {linear_pareto_front_program_idx}")
    logger.log(f"Iteration {gepa_state.i + 1}: New program candidate index: {new_program_idx}")

    metrics = {
        "iteration": gepa_state.i + 1,
        "new_program_idx": new_program_idx,
        "valset_pareto_front_agg": pareto_avg,
        "valset_pareto_front_programs": {k: list(v) for k, v in gepa_state.program_at_pareto_front_valset.items()},
        "best_valset_agg_score": best_score_on_valset,
        "linear_pareto_front_program_idx": linear_pareto_front_program_idx,
        "best_program_as_per_agg_score_valset": best_prog_per_agg_val_score,
        "best_score_on_valset": best_score_on_valset,
        "val_evaluated_count_new_program": coverage,
        "val_total_count": valset_size,
        "val_program_average": valset_score,
    }
    if log_individual_valset_scores_and_programs:
        metrics.update(
            {
                "valset_pareto_front_scores": dict(gepa_state.pareto_front_valset),
                "individual_valset_score_new_program": dict(valset_scores),
            }
        )
    if objective_scores:
        metrics["objective_scores_new_program"] = dict(objective_scores)
    if valset_evaluation.objective_scores_by_val_id:
        metrics["objective_scores_by_val_new_program"] = {
            val_id: dict(scores) for val_id, scores in valset_evaluation.objective_scores_by_val_id.items()
        }
    if gepa_state.objective_pareto_front:
        metrics["objective_pareto_front_scores"] = dict(gepa_state.objective_pareto_front)
        metrics["objective_pareto_front_programs"] = {
            k: list(v) for k, v in gepa_state.program_at_pareto_front_objectives.items()
        }

    experiment_tracker.log_metrics(metrics, step=gepa_state.i + 1)
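
The logger and experiment_tracker parameters above are untyped; the function only relies on logger.log(message) and experiment_tracker.log_metrics(metrics, step=...). A minimal sketch of stand-ins that would satisfy those call sites (the class names below are illustrative and not part of mantisdk):

class PrintLogger:
    def log(self, message: str) -> None:
        print(message)


class InMemoryTracker:
    def __init__(self) -> None:
        self.history: list[tuple[int, dict]] = []

    def log_metrics(self, metrics: dict, step: int) -> None:
        # Keep a snapshot of the metrics dict keyed by iteration number.
        self.history.append((step, dict(metrics)))


tracker = InMemoryTracker()
tracker.log_metrics({"iteration": 1, "valset_pareto_front_agg": 0.5}, step=1)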

mantisdk/algorithm/gepa/lib/proposer/__init__.py: File without changes.

mantisdk/algorithm/gepa/lib/proposer/base.py
@@ -0,0 +1,31 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

from dataclasses import dataclass, field
from typing import Any, Generic, Protocol

from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
from mantisdk.algorithm.gepa.lib.core.state import GEPAState


@dataclass
class CandidateProposal(Generic[DataId]):
    candidate: dict[str, str]
    parent_program_ids: list[int]
    # Optional mini-batch / subsample info
    subsample_indices: list[DataId] | None = None
    subsample_scores_before: list[float] | None = None
    subsample_scores_after: list[float] | None = None
    # Free-form metadata for logging/trace
    tag: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)


class ProposeNewCandidate(Protocol[DataId]):
    """
    Strategy that receives the current optimizer state and proposes a new candidate, or returns None.
    It may compute subsample evaluations, set trace fields in state, etc.
    The engine will handle acceptance and the full evaluation unless the strategy has already done those and encoded them in metadata.
    """

    def propose(self, state: GEPAState[Any, DataId]) -> CandidateProposal | None: ...
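
For context, a strategy satisfies the ProposeNewCandidate protocol simply by providing a matching propose method. Below is a rough sketch, assuming the wheel above is installed so the import path resolves; FixedCandidateProposer is invented for illustration and is not part of the package:

from mantisdk.algorithm.gepa.lib.proposer.base import CandidateProposal


class FixedCandidateProposer:
    """Toy strategy: always proposes the same candidate text (illustrative only)."""

    def __init__(self, candidate: dict[str, str], parent_idx: int) -> None:
        self.candidate = candidate
        self.parent_idx = parent_idx

    def propose(self, state) -> CandidateProposal | None:
        # A real strategy would inspect `state` (a GEPAState) to pick parents,
        # run subsample evaluations, and fill the subsample_* fields.
        return CandidateProposal(
            candidate=dict(self.candidate),
            parent_program_ids=[self.parent_idx],
            tag="fixed",
            metadata={"note": "toy example"},
        )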

mantisdk/algorithm/gepa/lib/proposer/merge.py
@@ -0,0 +1,357 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

import math
import random
from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy

from mantisdk.algorithm.gepa.lib.core.adapter import Candidate, DataInst, RolloutOutput
from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader
from mantisdk.algorithm.gepa.lib.core.state import GEPAState, ObjectiveScores, ProgramIdx
from mantisdk.algorithm.gepa.lib.gepa_utils import find_dominator_programs
from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol
from mantisdk.algorithm.gepa.lib.proposer.base import CandidateProposal, ProposeNewCandidate

AncestorLog = tuple[int, int, int]
MergeDescription = tuple[int, int, tuple[int, ...]]
MergeAttempt = tuple[Candidate, ProgramIdx, ProgramIdx, ProgramIdx] | None


def does_triplet_have_desirable_predictors(
    program_candidates: Sequence[Candidate],
    ancestor: ProgramIdx,
    id1: ProgramIdx,
    id2: ProgramIdx,
) -> bool:
    found_predictors: list[tuple[int, int]] = []
    pred_names = list(program_candidates[ancestor].keys())
    for pred_idx, pred_name in enumerate(pred_names):
        pred_anc = program_candidates[ancestor][pred_name]
        pred_id1 = program_candidates[id1][pred_name]
        pred_id2 = program_candidates[id2][pred_name]
        if (pred_anc == pred_id1 or pred_anc == pred_id2) and pred_id1 != pred_id2:
            same_as_ancestor_id = 1 if pred_anc == pred_id1 else 2
            found_predictors.append((pred_idx, same_as_ancestor_id))

    return len(found_predictors) > 0


def filter_ancestors(
    i: ProgramIdx,
    j: ProgramIdx,
    common_ancestors: Iterable[ProgramIdx],
    merges_performed: tuple[list[AncestorLog], list[MergeDescription]],
    agg_scores: Sequence[float],
    program_candidates: Sequence[Candidate],
) -> list[ProgramIdx]:
    filtered_ancestors: list[ProgramIdx] = []
    for ancestor in common_ancestors:
        if (i, j, ancestor) in merges_performed[0]:
            continue

        if agg_scores[ancestor] > agg_scores[i] or agg_scores[ancestor] > agg_scores[j]:
            continue

        if not does_triplet_have_desirable_predictors(program_candidates, ancestor, i, j):
            continue

        filtered_ancestors.append(ancestor)
    return filtered_ancestors


def find_common_ancestor_pair(
    rng: random.Random,
    parent_list: Sequence[Sequence[int | None]],
    program_indexes: Sequence[int],
    merges_performed: tuple[list[AncestorLog], list[MergeDescription]],
    agg_scores: Sequence[float],
    program_candidates: Sequence[Candidate],
    max_attempts: int = 10,
) -> tuple[int, int, int] | None:
    def get_ancestors(node: int, ancestors_found: set[int]) -> list[int]:
        parents = parent_list[node]
        for parent in parents:
            if parent is not None and parent not in ancestors_found:
                ancestors_found.add(parent)
                get_ancestors(parent, ancestors_found)

        return list(ancestors_found)

    for _ in range(max_attempts):
        if len(program_indexes) < 2:
            return None
        i, j = rng.sample(list(program_indexes), 2)
        if i == j:
            continue

        if j < i:
            i, j = j, i

        ancestors_i = get_ancestors(i, set())
        ancestors_j = get_ancestors(j, set())

        if j in ancestors_i or i in ancestors_j:
            # If one is an ancestor of the other, we cannot merge them
            continue

        common_ancestors = set(ancestors_i) & set(ancestors_j)
        common_ancestors = filter_ancestors(i, j, common_ancestors, merges_performed, agg_scores, program_candidates)
        if common_ancestors:
            # Select a random common ancestor
            common_ancestor = rng.choices(
                list(common_ancestors),
                k=1,
                weights=[agg_scores[ancestor] for ancestor in common_ancestors],
            )[0]
            return (i, j, common_ancestor)

    return None


def sample_and_attempt_merge_programs_by_common_predictors(
    agg_scores: Sequence[float],
    rng: random.Random,
    merge_candidates: Sequence[int],
    merges_performed: tuple[list[AncestorLog], list[MergeDescription]],
    program_candidates: Sequence[Candidate],
    parent_program_for_candidate: Sequence[Sequence[int | None]],
    has_val_support_overlap: Callable[[ProgramIdx, ProgramIdx], bool] | None = None,
    max_attempts: int = 10,
) -> MergeAttempt:
    if len(merge_candidates) < 2:
        return None
    if len(parent_program_for_candidate) < 3:
        return None

    for _ in range(max_attempts):
        ids_to_merge = find_common_ancestor_pair(
            rng,
            parent_program_for_candidate,
            list(merge_candidates),
            merges_performed=merges_performed,
            agg_scores=agg_scores,
            program_candidates=program_candidates,
            max_attempts=max_attempts,
        )
        if ids_to_merge is None:
            continue
        id1, id2, ancestor = ids_to_merge

        if (id1, id2, ancestor) in merges_performed[0]:
            continue
        assert agg_scores[ancestor] <= agg_scores[id1], "Ancestor should not be better than its descendants"
        assert agg_scores[ancestor] <= agg_scores[id2], "Ancestor should not be better than its descendants"
        assert id1 != id2, "Cannot merge the same program"

        # Now we have a common ancestor, which is outperformed by both its descendants

        new_program: Candidate = deepcopy(program_candidates[ancestor])

        new_prog_desc: tuple[ProgramIdx, ...] = ()

        pred_names = set(program_candidates[ancestor].keys())
        assert pred_names == set(program_candidates[id1].keys()) == set(program_candidates[id2].keys()), (
            "Predictors should be the same across all programs"
        )
        for pred_name in pred_names:
            pred_anc = program_candidates[ancestor][pred_name]
            pred_id1 = program_candidates[id1][pred_name]
            pred_id2 = program_candidates[id2][pred_name]
            if (pred_anc == pred_id1 or pred_anc == pred_id2) and pred_id1 != pred_id2:
                # We have a predictor that is the same as one of its ancestors, so we can update it with the other
                same_as_ancestor_id = 1 if pred_anc == pred_id1 else 2
                new_value_idx = id2 if same_as_ancestor_id == 1 else id1
                new_program[pred_name] = program_candidates[new_value_idx][pred_name]
                new_prog_desc = (*new_prog_desc, new_value_idx)
            elif pred_anc != pred_id1 and pred_anc != pred_id2:
                # Both predictors are different from the ancestor, and it is difficult to decide which one gives the benefits
                # We randomly select one of the descendants to update the predictor
                # The probability of selecting is proportional to the agg_scores of the descendants
                # prog_to_get_instruction_from = id1 if (rng.random() < (agg_scores[id1] / (agg_scores[id1] + agg_scores[id2]))) else id2
                prog_to_get_instruction_from = (
                    id1
                    if agg_scores[id1] > agg_scores[id2]
                    else (id2 if agg_scores[id2] > agg_scores[id1] else rng.choice([id1, id2]))
                )
                new_program[pred_name] = program_candidates[prog_to_get_instruction_from][pred_name]
                new_prog_desc = (*new_prog_desc, prog_to_get_instruction_from)
            elif pred_id1 == pred_id2:
                # Either both predictors are the same, or both are different from the ancestor
                # If both are different from the ancestor, we should use the new predictor, so selecting either one of the descendants is fine
                # If both are same as the ancestor, again selecting any one of the descendants is fine
                # So let's select id1
                new_program[pred_name] = program_candidates[id1][pred_name]
                new_prog_desc = (*new_prog_desc, id1)
            else:  # pragma: no cover - defensive
                raise AssertionError("Unexpected case in predictor merging logic")

        if (id1, id2, new_prog_desc) in merges_performed[1]:
            # This triplet has already been merged, so we skip it
            continue

        if has_val_support_overlap and not has_val_support_overlap(id1, id2):
            # Not enough overlapping validation support for candidates
            continue

        merges_performed[1].append((id1, id2, new_prog_desc))

        return new_program, id1, id2, ancestor

    return None


class MergeProposer(ProposeNewCandidate[DataId]):
    """
    Implements the merge flow that combines compatible descendants of a common ancestor.

    - Find merge candidates among Pareto front dominators
    - Attempt a merge via sample_and_attempt_merge_programs_by_common_predictors
    - Subsample eval on valset-driven selected indices
    - Return proposal if merge's subsample score >= max(parents)
    The engine handles full eval + adding to state.
    """

    def __init__(
        self,
        logger: LoggerProtocol,
        valset: DataLoader[DataId, DataInst],
        evaluator: Callable[
            [list[DataInst], dict[str, str]],
            tuple[list[RolloutOutput], list[float], Sequence[ObjectiveScores] | None],
        ],
        use_merge: bool,
        max_merge_invocations: int,
        val_overlap_floor: int = 5,
        rng: random.Random | None = None,
    ):
        self.logger = logger
        self.valset = valset
        self.evaluator = evaluator
        self.use_merge = use_merge
        self.max_merge_invocations = max_merge_invocations
        self.rng = rng if rng is not None else random.Random(0)

        if val_overlap_floor <= 0:
            raise ValueError("val_overlap_floor should be a positive integer")
        self.val_overlap_floor = val_overlap_floor
        # Internal counters matching original behavior
        self.merges_due = 0
        self.total_merges_tested = 0
        self.merges_performed: tuple[list[AncestorLog], list[MergeDescription]] = ([], [])

        # Toggle controlled by engine: set True when last iter found new program
        self.last_iter_found_new_program = False

    def schedule_if_needed(self) -> None:
        if self.use_merge and self.total_merges_tested < self.max_merge_invocations:
            self.merges_due += 1

    def select_eval_subsample_for_merged_program(
        self,
        scores1: dict[DataId, float],
        scores2: dict[DataId, float],
        num_subsample_ids: int = 5,
    ) -> list[DataId]:
        common_ids = list(set(scores1.keys()) & set(scores2.keys()))

        p1 = [idx for idx in common_ids if scores1[idx] > scores2[idx]]
        p2 = [idx for idx in common_ids if scores2[idx] > scores1[idx]]
        p3 = [idx for idx in common_ids if idx not in p1 and idx not in p2]

        n_each = max(1, math.ceil(num_subsample_ids / 3))
        selected: list[DataId] = []
        for bucket in (p1, p2, p3):
            if len(selected) >= num_subsample_ids:
                break
            available = [idx for idx in bucket if idx not in selected]
            take = min(len(available), n_each, num_subsample_ids - len(selected))
            if take > 0:
                selected += self.rng.sample(available, k=take)

        remaining = num_subsample_ids - len(selected)
        if remaining > 0:
            unused = [idx for idx in common_ids if idx not in selected]
            if len(unused) >= remaining:
                selected += self.rng.sample(unused, k=remaining)
            elif common_ids:
                selected += self.rng.choices(common_ids, k=remaining)

        return selected[:num_subsample_ids]

    def propose(self, state: GEPAState[RolloutOutput, DataId]) -> CandidateProposal[DataId] | None:
        i = state.i + 1
        state.full_program_trace[-1]["invoked_merge"] = True

        # Only attempt when scheduled by engine and after a new program in last iteration
        if not (self.use_merge and self.last_iter_found_new_program and self.merges_due > 0):
            self.logger.log(f"Iteration {i}: No merge candidates scheduled")
            return None

        pareto_front_programs = state.get_pareto_front_mapping()

        tracked_scores: Sequence[float] = getattr(
            state, "per_program_tracked_scores", state.program_full_scores_val_set
        )
        merge_candidates = find_dominator_programs(pareto_front_programs, list(tracked_scores))

        def has_val_support_overlap(id1: ProgramIdx, id2: ProgramIdx) -> bool:
            common_ids = set(state.prog_candidate_val_subscores[id1].keys()) & set(
                state.prog_candidate_val_subscores[id2].keys()
            )
            return len(common_ids) >= self.val_overlap_floor

        merge_output = sample_and_attempt_merge_programs_by_common_predictors(
            agg_scores=list(tracked_scores),
            rng=self.rng,
            merge_candidates=merge_candidates,
            merges_performed=self.merges_performed,
            program_candidates=state.program_candidates,
            parent_program_for_candidate=state.parent_program_for_candidate,
            has_val_support_overlap=has_val_support_overlap,
        )

        if merge_output is None:
            self.logger.log(f"Iteration {i}: No merge candidates found")
            return None

        new_program, id1, id2, ancestor = merge_output
        state.full_program_trace[-1]["merged"] = True
        state.full_program_trace[-1]["merged_entities"] = (id1, id2, ancestor)
        self.merges_performed[0].append((id1, id2, ancestor))
        self.logger.log(f"Iteration {i}: Merged programs {id1} and {id2} via ancestor {ancestor}")

        subsample_ids = self.select_eval_subsample_for_merged_program(
            state.prog_candidate_val_subscores[id1],
            state.prog_candidate_val_subscores[id2],
        )
        if not subsample_ids:
            self.logger.log(
                f"Iteration {i}: Skipping merge of {id1} and {id2} due to insufficient overlapping val coverage"
            )
            return None

        assert set(subsample_ids).issubset(state.prog_candidate_val_subscores[id1].keys())
        assert set(subsample_ids).issubset(state.prog_candidate_val_subscores[id2].keys())
        id1_sub_scores = [state.prog_candidate_val_subscores[id1][k] for k in subsample_ids]
        id2_sub_scores = [state.prog_candidate_val_subscores[id2][k] for k in subsample_ids]
        state.full_program_trace[-1]["subsample_ids"] = subsample_ids

        new_sub_scores, actual_evals_count = state.cached_evaluate(
            new_program, subsample_ids, self.valset.fetch, self.evaluator
        )
        state.full_program_trace[-1]["id1_subsample_scores"] = id1_sub_scores
        state.full_program_trace[-1]["id2_subsample_scores"] = id2_sub_scores
        state.full_program_trace[-1]["new_program_subsample_scores"] = new_sub_scores
        state.total_num_evals += actual_evals_count

        # Acceptance will be evaluated by engine (>= max(parents))
        return CandidateProposal(
            candidate=new_program,
            parent_program_ids=[id1, id2],
            subsample_indices=subsample_ids,
            subsample_scores_before=[sum(id1_sub_scores), sum(id2_sub_scores)],
            subsample_scores_after=new_sub_scores,
            tag="merge",
            metadata={"ancestor": ancestor},
        )
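
To make the merge rule concrete, here is a toy sketch, assuming the wheel above is installed so the merge module is importable at the path listed in the file list; all candidate texts and scores are invented for illustration:

import random

from mantisdk.algorithm.gepa.lib.proposer.merge import (
    sample_and_attempt_merge_programs_by_common_predictors,
)

# Three toy candidates over two prompt components ("predictors"): candidate 0 is the
# common ancestor, candidate 1 changes only the planner prompt, candidate 2 only the solver prompt.
program_candidates = [
    {"planner": "p0", "solver": "s0"},   # ancestor
    {"planner": "p1", "solver": "s0"},   # child of 0: new planner prompt
    {"planner": "p0", "solver": "s2"},   # child of 0: new solver prompt
]
parent_program_for_candidate = [[None], [0], [0]]
agg_scores = [0.3, 0.5, 0.6]             # the ancestor must not beat its descendants

result = sample_and_attempt_merge_programs_by_common_predictors(
    agg_scores=agg_scores,
    rng=random.Random(0),
    merge_candidates=[1, 2],
    merges_performed=([], []),
    program_candidates=program_candidates,
    parent_program_for_candidate=parent_program_for_candidate,
)
# Expected shape: (merged_candidate, id1, id2, ancestor); here the merged candidate combines
# the planner prompt from candidate 1 with the solver prompt from candidate 2.
print(result)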

mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py: File without changes.

mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py
@@ -0,0 +1,49 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

from dataclasses import dataclass
from typing import Any, ClassVar, Mapping, Protocol, runtime_checkable

from mantisdk.algorithm.gepa.lib.core.adapter import Trajectory
from mantisdk.algorithm.gepa.lib.core.state import GEPAState


@runtime_checkable
class CandidateSelector(Protocol):
    def select_candidate_idx(self, state: GEPAState) -> int: ...


class ReflectionComponentSelector(Protocol):
    def __call__(
        self,
        state: GEPAState,
        trajectories: list[Trajectory],
        subsample_scores: list[float],
        candidate_idx: int,
        candidate: dict[str, str],
    ) -> list[str]: ...


class LanguageModel(Protocol):
    def __call__(self, prompt: str) -> str: ...


@dataclass
class Signature:
    prompt_template: ClassVar[str]
    input_keys: ClassVar[list[str]]
    output_keys: ClassVar[list[str]]

    @classmethod
    def prompt_renderer(cls, input_dict: Mapping[str, Any]) -> str:
        raise NotImplementedError

    @classmethod
    def output_extractor(cls, lm_out: str) -> dict[str, str]:
        raise NotImplementedError

    @classmethod
    def run(cls, lm: LanguageModel, input_dict: Mapping[str, Any]) -> dict[str, str]:
        full_prompt = cls.prompt_renderer(input_dict)
        lm_out = lm(full_prompt).strip()
        return cls.output_extractor(lm_out)
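
A concrete Signature is defined by filling in the class-level template and keys and the two classmethods; any callable that takes a prompt string and returns a string satisfies LanguageModel. A minimal sketch, assuming the wheel above is installed and this module sits at the reflective_mutation/base.py path listed in the file list (EchoSignature and stub_lm are invented for illustration):

from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import Signature


class EchoSignature(Signature):
    # Hypothetical one-field signature used only for illustration.
    prompt_template = "Rewrite the instruction:\n{instruction}\n\nImproved instruction:"
    input_keys = ["instruction"]
    output_keys = ["improved_instruction"]

    @classmethod
    def prompt_renderer(cls, input_dict):
        return cls.prompt_template.format(**input_dict)

    @classmethod
    def output_extractor(cls, lm_out):
        return {"improved_instruction": lm_out}


def stub_lm(prompt: str) -> str:
    # Any str -> str callable satisfies the LanguageModel protocol.
    return "Be concise and show your reasoning step by step."


result = EchoSignature.run(stub_lm, {"instruction": "Answer the question."})
print(result["improved_instruction"])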