mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/gepa/lib/core/state.py
@@ -0,0 +1,636 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

import hashlib
import json
import os
from collections import defaultdict
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, ClassVar, Generic, Literal, TypeAlias

from mantisdk.algorithm.gepa.lib.core.adapter import RolloutOutput
from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
from mantisdk.algorithm.gepa.lib.gepa_utils import json_default
from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol

# Types for GEPAState
ProgramIdx = int

# Type aliases
ObjectiveScores: TypeAlias = dict[str, float]
FrontierType: TypeAlias = Literal["instance", "objective", "hybrid", "cartesian"]
"""Strategy for tracking Pareto frontiers: 'instance' (per validation example), 'objective' (per objective metric), 'hybrid' (both), or 'cartesian' (per example × objective)."""
FrontierKey: TypeAlias = DataId | str | tuple[str, DataId] | tuple[str, DataId, str]
"""Key type for frontier mappings depending on frontier_type."""

CandidateHash: TypeAlias = str
CacheKey: TypeAlias = tuple[CandidateHash, DataId]


def _candidate_hash(candidate: dict[str, str]) -> CandidateHash:
    """Compute a deterministic hash of a candidate dictionary."""
    return hashlib.sha256(json.dumps(sorted(candidate.items())).encode()).hexdigest()


@dataclass
class CachedEvaluation(Generic[RolloutOutput]):
    """Cached evaluation result for a (candidate, example) pair."""

    output: RolloutOutput
    score: float
    objective_scores: ObjectiveScores | None


@dataclass
class EvaluationCache(Generic[RolloutOutput, DataId]):
    """Cache for storing evaluation results of (candidate, example) pairs."""

    _cache: dict[CacheKey, CachedEvaluation[RolloutOutput]] = field(default_factory=dict)

    def get(self, candidate: dict[str, str], example_id: DataId) -> CachedEvaluation[RolloutOutput] | None:
        """Retrieve cached evaluation result if it exists."""
        return self._cache.get((_candidate_hash(candidate), example_id))

    def put(
        self,
        candidate: dict[str, str],
        example_id: DataId,
        output: RolloutOutput,
        score: float,
        objective_scores: ObjectiveScores | None = None,
    ) -> None:
        """Store an evaluation result in the cache."""
        self._cache[(_candidate_hash(candidate), example_id)] = CachedEvaluation(output, score, objective_scores)

    def get_batch(
        self, candidate: dict[str, str], example_ids: list[DataId]
    ) -> tuple[dict[DataId, CachedEvaluation[RolloutOutput]], list[DataId]]:
        """Look up cached results for a batch. Returns (cached_results, uncached_ids)."""
        h = _candidate_hash(candidate)
        cached, uncached = {}, []
        for eid in example_ids:
            if entry := self._cache.get((h, eid)):
                cached[eid] = entry
            else:
                uncached.append(eid)
        return cached, uncached

    def put_batch(
        self,
        candidate: dict[str, str],
        example_ids: list[DataId],
        outputs: list[RolloutOutput],
        scores: list[float],
        objective_scores_list: Sequence[ObjectiveScores] | None = None,
    ) -> None:
        """Store evaluation results for a batch of examples."""
        h = _candidate_hash(candidate)
        for i, eid in enumerate(example_ids):
            self._cache[(h, eid)] = CachedEvaluation(
                outputs[i], scores[i], objective_scores_list[i] if objective_scores_list else None
            )

    def evaluate_with_cache_full(
        self,
        candidate: dict[str, str],
        example_ids: list[DataId],
        fetcher: Callable[[list[DataId]], Any],
        evaluator: Callable[[Any, dict[str, str]], tuple[Any, list[float], Sequence[ObjectiveScores] | None]],
    ) -> tuple[dict[DataId, RolloutOutput], dict[DataId, float], dict[DataId, ObjectiveScores] | None, int]:
        """
        Evaluate using cache, returning full results.

        Returns (outputs_by_id, scores_by_id, objective_scores_by_id, num_actual_evals).
        """
        cached, uncached_ids = self.get_batch(candidate, example_ids)

        outputs_by_id: dict[DataId, RolloutOutput] = {eid: c.output for eid, c in cached.items()}
        scores_by_id: dict[DataId, float] = {eid: c.score for eid, c in cached.items()}
        objective_by_id: dict[DataId, ObjectiveScores] | None = None

        # Populate objective scores from cache
        for eid, c in cached.items():
            if c.objective_scores is not None:
                objective_by_id = objective_by_id or {}
                objective_by_id[eid] = c.objective_scores

        # Evaluate uncached examples
        if uncached_ids:
            batch = fetcher(uncached_ids)
            outputs, scores, obj_scores = evaluator(batch, candidate)
            for idx, eid in enumerate(uncached_ids):
                outputs_by_id[eid] = outputs[idx]
                scores_by_id[eid] = scores[idx]
                if obj_scores is not None:
                    objective_by_id = objective_by_id or {}
                    objective_by_id[eid] = obj_scores[idx]
            self.put_batch(candidate, uncached_ids, outputs, scores, obj_scores)

        return outputs_by_id, scores_by_id, objective_by_id, len(uncached_ids)


@dataclass(slots=True)
class ValsetEvaluation(Generic[RolloutOutput, DataId]):
    """Container for evaluation results on a validation set batch."""

    outputs_by_val_id: dict[DataId, RolloutOutput]
    scores_by_val_id: dict[DataId, float]
    objective_scores_by_val_id: dict[DataId, ObjectiveScores] | None = None


class GEPAState(Generic[RolloutOutput, DataId]):
    """Persistent optimizer state tracking candidates, sparse validation coverage, and objective frontiers."""

    _VALIDATION_SCHEMA_VERSION: ClassVar[int] = 4

    program_candidates: list[dict[str, str]]
    parent_program_for_candidate: list[list[ProgramIdx | None]]
    prog_candidate_val_subscores: list[dict[DataId, float]]
    prog_candidate_objective_scores: list[ObjectiveScores]

    pareto_front_valset: dict[DataId, float]
    program_at_pareto_front_valset: dict[DataId, set[ProgramIdx]]
    objective_pareto_front: ObjectiveScores
    program_at_pareto_front_objectives: dict[str, set[ProgramIdx]]
    pareto_front_cartesian: dict[tuple[DataId, str], float]
    program_at_pareto_front_cartesian: dict[tuple[DataId, str], set[ProgramIdx]]

    list_of_named_predictors: list[str]
    named_predictor_id_to_update_next_for_program_candidate: list[int]

    i: int
    num_full_ds_evals: int

    total_num_evals: int

    num_metric_calls_by_discovery: list[int]

    full_program_trace: list[dict[str, Any]]
    best_outputs_valset: dict[DataId, list[tuple[ProgramIdx, RolloutOutput]]] | None

    validation_schema_version: int

    # Optional evaluation cache for (candidate, example) pairs
    evaluation_cache: "EvaluationCache[RolloutOutput, DataId] | None"

    def __init__(
        self,
        seed_candidate: dict[str, str],
        base_evaluation: ValsetEvaluation[RolloutOutput, DataId],
        track_best_outputs: bool = False,
        frontier_type: FrontierType = "instance",
        evaluation_cache: "EvaluationCache[RolloutOutput, DataId] | None" = None,
    ):
        self.program_candidates = [dict(seed_candidate)]
        self.prog_candidate_val_subscores = [dict(base_evaluation.scores_by_val_id)]

        base_objective_aggregates = self._aggregate_objective_scores(base_evaluation.objective_scores_by_val_id)
        self.prog_candidate_objective_scores = [base_objective_aggregates]

        self.parent_program_for_candidate = [[None]]

        self.frontier_type: FrontierType = frontier_type
        self.pareto_front_valset = {val_id: score for val_id, score in base_evaluation.scores_by_val_id.items()}
        self.program_at_pareto_front_valset = {val_id: {0} for val_id in base_evaluation.scores_by_val_id.keys()}
        self.objective_pareto_front = dict(base_objective_aggregates)
        self.program_at_pareto_front_objectives = {objective: {0} for objective in base_objective_aggregates.keys()}

        # Validate that objective scores are provided for frontier types that require them
        if frontier_type in ("objective", "hybrid", "cartesian"):
            if not base_evaluation.objective_scores_by_val_id:
                raise ValueError(
                    f"frontier_type='{frontier_type}' requires objective_scores to be provided by the evaluator, "
                    f"but none were found. Use an evaluator that returns objective_scores or use frontier_type='instance'."
                )

        # Cartesian frontier will be base_evaluation.objective_scores_by_val_id
        if frontier_type == "cartesian":
            assert base_evaluation.objective_scores_by_val_id is not None  # Already validated above
            self.pareto_front_cartesian = {
                (val_id, objective): objective_score
                for val_id, objective_scores in base_evaluation.objective_scores_by_val_id.items()
                for objective, objective_score in objective_scores.items()
            }
            self.program_at_pareto_front_cartesian = {
                (val_id, objective): {0}
                for val_id, objective_scores in base_evaluation.objective_scores_by_val_id.items()
                for objective in objective_scores.keys()
            }
        else:
            self.pareto_front_cartesian = {}
            self.program_at_pareto_front_cartesian = {}

        self.list_of_named_predictors = list(seed_candidate.keys())
        self.named_predictor_id_to_update_next_for_program_candidate = [0]
        self.i = -1

        self.num_metric_calls_by_discovery = [0]

        if track_best_outputs:
            self.best_outputs_valset = {
                val_id: [(0, output)] for val_id, output in base_evaluation.outputs_by_val_id.items()
            }
        else:
            self.best_outputs_valset = None

        self.full_program_trace = []
        self.validation_schema_version = self._VALIDATION_SCHEMA_VERSION
        self.evaluation_cache = evaluation_cache

    def is_consistent(self) -> bool:
        assert len(self.program_candidates) == len(self.parent_program_for_candidate)
        assert len(self.program_candidates) == len(self.named_predictor_id_to_update_next_for_program_candidate)
        assert len(self.program_candidates) == len(self.prog_candidate_val_subscores)
        assert len(self.program_candidates) == len(self.prog_candidate_objective_scores)
        assert len(self.program_candidates) == len(self.num_metric_calls_by_discovery)

        assert len(self.pareto_front_valset) == len(self.program_at_pareto_front_valset)
        assert set(self.pareto_front_valset.keys()) == set(self.program_at_pareto_front_valset.keys())
        assert set(self.objective_pareto_front.keys()) == set(self.program_at_pareto_front_objectives.keys())

        for front in self.program_at_pareto_front_valset.values():
            for prog_idx in front:
                assert prog_idx < len(self.program_candidates), (
                    "Program index in valset pareto front exceeds number of program candidates"
                )

        return True

    def save(self, run_dir: str | None, *, use_cloudpickle: bool = False) -> None:
        if run_dir is None:
            return
        with open(os.path.join(run_dir, "gepa_state.bin"), "wb") as f:
            if use_cloudpickle:
                import cloudpickle as pickle  # type: ignore[import-not-found]
            else:
                import pickle
            serialized = dict(self.__dict__.items())
            serialized["validation_schema_version"] = GEPAState._VALIDATION_SCHEMA_VERSION
            pickle.dump(serialized, f)

    @staticmethod
    def load(run_dir: str) -> "GEPAState[RolloutOutput, DataId]":
        with open(os.path.join(run_dir, "gepa_state.bin"), "rb") as f:
            import pickle

            data = pickle.load(f)

            # handle schema migration
            version = data.get("validation_schema_version")
            if version is None or version < 2:
                GEPAState._migrate_from_legacy_state_v0(data)
                version = data.get("validation_schema_version")
            if version is None or version < GEPAState._VALIDATION_SCHEMA_VERSION:
                GEPAState._upgrade_state_dict(data)

            state = GEPAState.__new__(GEPAState)
            state.__dict__.update(data)

        state.validation_schema_version = GEPAState._VALIDATION_SCHEMA_VERSION
        assert len(state.program_candidates) == len(state.prog_candidate_val_subscores)
        assert len(state.program_candidates) == len(state.prog_candidate_objective_scores)
        assert len(state.program_candidates) == len(state.num_metric_calls_by_discovery)
        assert len(state.program_candidates) == len(state.parent_program_for_candidate)
        assert len(state.program_candidates) == len(state.named_predictor_id_to_update_next_for_program_candidate)
        assert len(state.pareto_front_valset) == len(state.program_at_pareto_front_valset)
        assert set(state.pareto_front_valset.keys()) == set(state.program_at_pareto_front_valset.keys())
        assert set(state.objective_pareto_front.keys()) == set(state.program_at_pareto_front_objectives.keys())
        return state

    @staticmethod
    def _migrate_from_legacy_state_v0(d: dict[str, Any]) -> None:
        assert isinstance(d, dict)
        assert "prog_candidate_val_subscores" in d
        assert isinstance(d["prog_candidate_val_subscores"], list)
        assert all(isinstance(scores, list) for scores in d["prog_candidate_val_subscores"])
        legacy_scores: list[list[float]] = d.pop("prog_candidate_val_subscores", [])
        d["prog_candidate_val_subscores"] = [
            {idx: score for idx, score in enumerate(scores)} for scores in legacy_scores
        ]

        pareto_front = d.get("pareto_front_valset")
        if isinstance(pareto_front, list):
            d["pareto_front_valset"] = {idx: score for idx, score in enumerate(pareto_front)}

        program_at_front = d.get("program_at_pareto_front_valset")
        if isinstance(program_at_front, list):
            d["program_at_pareto_front_valset"] = {idx: set(front) for idx, front in enumerate(program_at_front)}

        best_outputs = d.get("best_outputs_valset")
        if isinstance(best_outputs, list):
            d["best_outputs_valset"] = {idx: list(outputs) for idx, outputs in enumerate(best_outputs)}

        d["validation_schema_version"] = 2

    @staticmethod
    def _upgrade_state_dict(d: dict[str, Any]) -> None:
        num_candidates = len(d.get("program_candidates", []))
        if "prog_candidate_objective_scores" not in d:
            d["prog_candidate_objective_scores"] = [{} for _ in range(num_candidates)]
        if "objective_pareto_front" not in d:
            d["objective_pareto_front"] = {}
        if "program_at_pareto_front_objectives" not in d:
            d["program_at_pareto_front_objectives"] = {}
        if "frontier_type" not in d:
            d["frontier_type"] = "instance"
            # Since frontier_type instance does not require "pareto_front_cartesian" and "program_at_pareto_front_cartesian", we can safely set them to empty dicts.
            d["pareto_front_cartesian"] = {}
            d["program_at_pareto_front_cartesian"] = {}
        # evaluation_cache is not persisted across runs by default; initialize to None if missing
        if "evaluation_cache" not in d:
            d["evaluation_cache"] = None
        d["validation_schema_version"] = GEPAState._VALIDATION_SCHEMA_VERSION

    @staticmethod
    def _aggregate_objective_scores(
        val_objective_scores: dict[DataId, ObjectiveScores] | None,
    ) -> ObjectiveScores:
        if not val_objective_scores:
            return {}
        totals: dict[str, float] = {}
        counts: dict[str, int] = {}
        for objective_dict in val_objective_scores.values():
            for objective, score in objective_dict.items():
                totals[objective] = totals.get(objective, 0.0) + score
                counts[objective] = counts.get(objective, 0) + 1
        return {
            objective: totals[objective] / counts[objective] for objective in totals.keys() if counts[objective] > 0
        }

    def get_program_average_val_subset(self, program_idx: int) -> tuple[float, int]:
        # TODO: This should be only used/handled by the val_evaluation_policy, and never used directly.
        scores = self.prog_candidate_val_subscores[program_idx]
        if not scores:
            return float("-inf"), 0
        num_samples = len(scores)
        avg = sum(scores.values()) / num_samples
        return avg, num_samples

    @property
    def valset_evaluations(self) -> dict[DataId, list[ProgramIdx]]:
        """
        Valset examples by id and programs that have evaluated them. Keys include only validation
        ids that have been scored at least once.
        """
        result: dict[DataId, list[ProgramIdx]] = defaultdict(list)
        for program_idx, val_scores in enumerate(self.prog_candidate_val_subscores):
            for val_id in val_scores.keys():
                result[val_id].append(program_idx)
        return result

    @property
    def program_full_scores_val_set(self) -> list[float]:
        # TODO: This should be using the val_evaluation_policy instead of the get_program_average_val_subset method to calculate the scores.
        return [
            self.get_program_average_val_subset(program_idx)[0]
            for program_idx in range(len(self.prog_candidate_val_subscores))
        ]

    @property
    def per_program_tracked_scores(self) -> list[float]:
        return [
            self.get_program_average_val_subset(program_idx)[0]
            for program_idx in range(len(self.prog_candidate_val_subscores))
        ]

    def _update_objective_pareto_front(self, objective_scores: ObjectiveScores, program_idx: ProgramIdx) -> None:
        if not objective_scores:
            return
        for objective, score in objective_scores.items():
            prev_score = self.objective_pareto_front.get(objective, float("-inf"))
            if score > prev_score:
                self.objective_pareto_front[objective] = score
                self.program_at_pareto_front_objectives[objective] = {program_idx}
            elif score == prev_score:
                front = self.program_at_pareto_front_objectives.setdefault(objective, set())
                front.add(program_idx)

    def _update_pareto_front_for_val_id(
        self,
        val_id: DataId,
        score: float,
        program_idx: ProgramIdx,
        output: RolloutOutput | None,
        run_dir: str | None,
        iteration: int,
    ) -> None:
        prev_score = self.pareto_front_valset.get(val_id, float("-inf"))
        if score > prev_score:
            self.pareto_front_valset[val_id] = score
            self.program_at_pareto_front_valset[val_id] = {program_idx}
            if self.best_outputs_valset is not None and output is not None:
                self.best_outputs_valset[val_id] = [(program_idx, output)]
                if run_dir is not None:
                    task_dir = os.path.join(run_dir, "generated_best_outputs_valset", f"task_{val_id}")
                    os.makedirs(task_dir, exist_ok=True)
                    with open(os.path.join(task_dir, f"iter_{iteration}_prog_{program_idx}.json"), "w") as fout:
                        json.dump(output, fout, indent=4, default=json_default)
        elif score == prev_score:
            pareto_front = self.program_at_pareto_front_valset.setdefault(val_id, set())
            pareto_front.add(program_idx)
            if self.best_outputs_valset is not None and output is not None:
                self.best_outputs_valset[val_id].append((program_idx, output))

    def _update_pareto_front_for_cartesian(
        self,
        val_id: DataId,
        objective: str,
        objective_score: float,
        program_idx: ProgramIdx,
    ) -> None:
        prev_score = self.pareto_front_cartesian.get((val_id, objective), float("-inf"))
        if objective_score > prev_score:
            self.pareto_front_cartesian[(val_id, objective)] = objective_score
            self.program_at_pareto_front_cartesian[(val_id, objective)] = {program_idx}
        elif objective_score == prev_score:
            front = self.program_at_pareto_front_cartesian.setdefault((val_id, objective), set())
            front.add(program_idx)

    def update_state_with_new_program(
        self,
        parent_program_idx: list[ProgramIdx],
        new_program: dict[str, str],
        valset_evaluation: ValsetEvaluation,
        run_dir: str | None,
        num_metric_calls_by_discovery_of_new_program: int,
    ) -> ProgramIdx:
        new_program_idx = len(self.program_candidates)
        self.program_candidates.append(dict(new_program))
        self.num_metric_calls_by_discovery.append(num_metric_calls_by_discovery_of_new_program)

        max_predictor_id = max(
            [self.named_predictor_id_to_update_next_for_program_candidate[p] for p in parent_program_idx],
            default=0,
        )
        self.named_predictor_id_to_update_next_for_program_candidate.append(max_predictor_id)
        self.parent_program_for_candidate.append(list(parent_program_idx))

        valset_scores = dict(valset_evaluation.scores_by_val_id)
        self.prog_candidate_val_subscores.append(valset_scores)
        objective_scores = self._aggregate_objective_scores(valset_evaluation.objective_scores_by_val_id)
        self.prog_candidate_objective_scores.append(objective_scores)

        for val_id, score in valset_scores.items():
            output = valset_evaluation.outputs_by_val_id.get(val_id) if valset_evaluation.outputs_by_val_id else None
            self._update_pareto_front_for_val_id(
                val_id,
                score,
                new_program_idx,
                output,
                run_dir,
                self.i + 1,
            )

        self._update_objective_pareto_front(objective_scores, new_program_idx)

        if self.frontier_type in ("objective", "hybrid", "cartesian"):
            if not valset_evaluation.objective_scores_by_val_id:
                raise ValueError(
                    f"frontier_type='{self.frontier_type}' requires objective_scores to be provided by the evaluator, "
                    f"but none were found in the evaluation result."
                )

            if self.frontier_type == "cartesian":
                assert valset_evaluation.objective_scores_by_val_id is not None  # Validated above
                for val_id, objective_scores in valset_evaluation.objective_scores_by_val_id.items():
                    for objective, objective_score in objective_scores.items():
                        self._update_pareto_front_for_cartesian(
                            val_id,
                            objective,
                            objective_score,
                            new_program_idx,
                        )

        return new_program_idx

    def _get_pareto_front_mapping(self, frontier_type: FrontierType) -> dict[FrontierKey, set[ProgramIdx]]:
        if frontier_type == "instance":
            return {val_id: set(front) for val_id, front in self.program_at_pareto_front_valset.items()}
        if frontier_type == "objective":
            return {objective: set(front) for objective, front in self.program_at_pareto_front_objectives.items()}
        if frontier_type == "hybrid":
            combined: dict[FrontierKey, set[ProgramIdx]] = {
                ("val_id", val_id): set(front) for val_id, front in self.program_at_pareto_front_valset.items()
            }
            for objective, front in self.program_at_pareto_front_objectives.items():
                combined[("objective", objective)] = set(front)
            return combined
        if frontier_type == "cartesian":
            return {
                ("cartesian", val_id, objective): set(front)
                for (val_id, objective), front in self.program_at_pareto_front_cartesian.items()
            }
        raise ValueError(f"Unknown frontier_type: {frontier_type}")

    def get_pareto_front_mapping(self) -> dict[FrontierKey, set[ProgramIdx]]:
        """Return frontier key to best-program-indices mapping based on configured frontier_type."""
        return self._get_pareto_front_mapping(self.frontier_type)

    def cached_evaluate(
        self,
        candidate: dict[str, str],
        example_ids: list[DataId],
        fetcher: Callable[[list[DataId]], Any],
        evaluator: Callable[[Any, dict[str, str]], tuple[Any, list[float], Sequence[ObjectiveScores] | None]],
    ) -> tuple[list[float], int]:
        """Evaluate with optional caching. Returns (scores, num_actual_evals)."""
        _, scores_by_id, _, num_actual_evals = self.cached_evaluate_full(candidate, example_ids, fetcher, evaluator)
        return [scores_by_id[eid] for eid in example_ids], num_actual_evals

    def cached_evaluate_full(
        self,
        candidate: dict[str, str],
        example_ids: list[DataId],
        fetcher: Callable[[list[DataId]], Any],
        evaluator: Callable[[Any, dict[str, str]], tuple[Any, list[float], Sequence[ObjectiveScores] | None]],
    ) -> tuple[dict[DataId, RolloutOutput], dict[DataId, float], dict[DataId, ObjectiveScores] | None, int]:
        """Evaluate with optional caching, returning full results."""
        if self.evaluation_cache is not None:
            return self.evaluation_cache.evaluate_with_cache_full(candidate, example_ids, fetcher, evaluator)
        batch = fetcher(example_ids)
        outputs, scores, objective_scores = evaluator(batch, candidate)
        outputs_by_id = dict(zip(example_ids, outputs, strict=False))
        scores_by_id = dict(zip(example_ids, scores, strict=False))
        objective_by_id = dict(zip(example_ids, objective_scores, strict=False)) if objective_scores else None
        return outputs_by_id, scores_by_id, objective_by_id, len(example_ids)


def write_eval_scores_to_directory(scores: dict[DataId, float], output_dir: str) -> None:
    for val_id, score in scores.items():
        task_dir = os.path.join(output_dir, f"task_{val_id}")
        os.makedirs(task_dir, exist_ok=True)
        with open(os.path.join(task_dir, f"iter_{0}_prog_0.json"), "w") as f:
            json.dump(score, f, indent=4, default=json_default)


def write_eval_outputs_to_directory(outputs, output_dir: str) -> None:
    """
    Write generated rollout outputs (not scalar scores) to disk.

    Structure:
        {output_dir}/task_{val_id}/iter_0_prog_0.json

    This directory is used to store best outputs for inspection/reuse.
    """
    for val_id, output in outputs.items():
        task_dir = os.path.join(output_dir, f"task_{val_id}")
        os.makedirs(task_dir, exist_ok=True)
        with open(os.path.join(task_dir, "iter_0_prog_0.json"), "w") as f:
            json.dump(output, f, indent=4, default=json_default)


def initialize_gepa_state(
    run_dir: str | None,
    logger: LoggerProtocol,
    seed_candidate: dict[str, str],
    valset_evaluator: Callable[
        [dict[str, str]],
        ValsetEvaluation[RolloutOutput, DataId],
    ],
    track_best_outputs: bool = False,
    frontier_type: FrontierType = "instance",
    evaluation_cache: "EvaluationCache[RolloutOutput, DataId] | None" = None,
) -> GEPAState[RolloutOutput, DataId]:
    if run_dir is not None and os.path.exists(os.path.join(run_dir, "gepa_state.bin")):
        logger.log("Loading gepa state from run dir")
        gepa_state = GEPAState.load(run_dir)
        if gepa_state.frontier_type != frontier_type:
            raise ValueError(
                f"Frontier type mismatch: requested '{frontier_type}' but loaded state has '{gepa_state.frontier_type}'. "
                f"Use a different run_dir or match the frontier_type parameter."
            )
        # Sync cache with current run's cache_evaluation setting:
        # - If caching is disabled (evaluation_cache is None), clear any loaded cache
        #   to respect the current run's cache_evaluation=False setting
        # - If caching is enabled and the loaded state has a cache, preserve it
        #   (allows resuming with cached results from previous run)
        # - If caching is enabled but no cache exists in loaded state, use the new empty cache
        if evaluation_cache is None:
            gepa_state.evaluation_cache = None
        elif gepa_state.evaluation_cache is None:
            gepa_state.evaluation_cache = evaluation_cache
        # else: keep the loaded cache (gepa_state.evaluation_cache is already set)
    else:
        num_evals_run = 0

        eval_result = valset_evaluator(seed_candidate)
        if run_dir is not None:
            write_eval_outputs_to_directory(
                eval_result.outputs_by_val_id, os.path.join(run_dir, "generated_best_outputs_valset")
            )

        num_evals_run += len(eval_result.scores_by_val_id)

        gepa_state = GEPAState(
            seed_candidate,
            eval_result,
            track_best_outputs=track_best_outputs,
            frontier_type=frontier_type,
            evaluation_cache=evaluation_cache,
        )

        gepa_state.num_full_ds_evals = 1
        gepa_state.total_num_evals = num_evals_run

    return gepa_state
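
Reviewer's note (not part of the wheel): the module above is the bookkeeping core of the GEPA integration. EvaluationCache deduplicates (candidate, example) evaluations, ValsetEvaluation carries per-example results, and GEPAState records every candidate together with its Pareto frontiers (per validation instance, per objective, hybrid, or cartesian). The following is a minimal usage sketch assuming the import path implied by the file layout; the toy candidate texts, example ids, and scores are invented for illustration only.

# Illustrative sketch, not shipped in mantisdk 0.1.0.
from mantisdk.algorithm.gepa.lib.core.state import (
    EvaluationCache,
    GEPAState,
    ValsetEvaluation,
)

# Seed candidate: one named predictor mapped to its prompt text (hypothetical values).
seed = {"system_prompt": "Answer the math question step by step."}

# Pretend the seed was already evaluated on two validation examples (ids 0 and 1).
base_eval = ValsetEvaluation(
    outputs_by_val_id={0: "4", 1: "9"},
    scores_by_val_id={0: 1.0, 1: 0.0},
    objective_scores_by_val_id={0: {"accuracy": 1.0}, 1: {"accuracy": 0.0}},
)

state = GEPAState(
    seed_candidate=seed,
    base_evaluation=base_eval,
    track_best_outputs=True,
    frontier_type="instance",
    evaluation_cache=EvaluationCache(),
)

# A mutated candidate that fixes example 1: it ties on example 0 and wins on example 1.
new_eval = ValsetEvaluation(
    outputs_by_val_id={0: "4", 1: "16"},
    scores_by_val_id={0: 1.0, 1: 1.0},
    objective_scores_by_val_id={0: {"accuracy": 1.0}, 1: {"accuracy": 1.0}},
)
new_idx = state.update_state_with_new_program(
    parent_program_idx=[0],
    new_program={"system_prompt": "Think carefully, then answer with the final number."},
    valset_evaluation=new_eval,
    run_dir=None,
    num_metric_calls_by_discovery_of_new_program=2,
)

print(state.get_pareto_front_mapping())  # {0: {0, 1}, 1: {1}}
print(state.per_program_tracked_scores)  # [0.5, 1.0]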
mantisdk/algorithm/gepa/lib/examples/__init__.py: file without changes
mantisdk/algorithm/gepa/lib/examples/aime.py
@@ -0,0 +1,24 @@
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa


def init_dataset():
    import random

    from datasets import load_dataset

    train_split = [
        {"input": x["problem"], "additional_context": {"solution": x["solution"]}, "answer": "### " + str(x["answer"])}
        for x in load_dataset("AI-MO/aimo-validation-aime")["train"]
    ]
    random.Random(0).shuffle(train_split)
    test_split = [
        {"input": x["problem"], "answer": "### " + str(x["answer"])}
        for x in load_dataset("MathArena/aime_2025")["train"]
    ]

    trainset = train_split[: len(train_split) // 2]
    valset = train_split[len(train_split) // 2 :]
    testset = test_split * 5

    return trainset, valset, testset
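
Reviewer's note (not part of the wheel): init_dataset builds AIME train/validation splits from AI-MO/aimo-validation-aime and a test split from MathArena/aime_2025 repeated five times, with the gold answer prefixed by "### ". A small consumption sketch, assuming the `datasets` dependency is installed and the Hugging Face datasets above are reachable:

# Illustrative sketch, not shipped in mantisdk 0.1.0.
from mantisdk.algorithm.gepa.lib.examples.aime import init_dataset

trainset, valset, testset = init_dataset()
print(len(trainset), len(valset), len(testset))  # sizes depend on the upstream datasets

example = trainset[0]
print(sorted(example.keys()))  # ['additional_context', 'answer', 'input']
print(example["answer"][:4])   # '### '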