mantisdk-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/gepa/lib/core/engine.py
@@ -0,0 +1,356 @@
+# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+# https://github.com/gepa-ai/gepa
+
+import traceback
+from collections.abc import Sequence
+from typing import Generic
+
+from mantisdk.algorithm.gepa.lib.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
+from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader, ensure_loader
+from mantisdk.algorithm.gepa.lib.core.state import EvaluationCache, FrontierType, GEPAState, ValsetEvaluation, initialize_gepa_state
+from mantisdk.algorithm.gepa.lib.logging.experiment_tracker import ExperimentTracker
+from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol
+from mantisdk.algorithm.gepa.lib.logging.utils import log_detailed_metrics_after_discovering_new_program
+from mantisdk.algorithm.gepa.lib.proposer.merge import MergeProposer
+from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.reflective_mutation import (
+    ReflectiveMutationProposer,
+)
+from mantisdk.algorithm.gepa.lib.strategies.eval_policy import EvaluationPolicy, FullEvaluationPolicy
+from mantisdk.algorithm.gepa.lib.utils import StopperProtocol
+
+# Import tqdm for progress bar functionality
+try:
+    from tqdm import tqdm
+except ImportError:
+    tqdm = None
+
+
+class GEPAEngine(Generic[DataId, DataInst, Trajectory, RolloutOutput]):
+    """Orchestrates the optimization loop using pluggable candidate proposers."""
+
+    def __init__(
+        self,
+        adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput],
+        run_dir: str | None,
+        valset: list[DataInst] | DataLoader[DataId, DataInst] | None,
+        seed_candidate: dict[str, str],
+        # Controls
+        perfect_score: float,
+        seed: int,
+        # Strategies and helpers
+        reflective_proposer: ReflectiveMutationProposer,
+        merge_proposer: MergeProposer | None,
+        frontier_type: FrontierType,
+        # Logging
+        logger: LoggerProtocol,
+        experiment_tracker: ExperimentTracker,
+        # Optional parameters
+        track_best_outputs: bool = False,
+        display_progress_bar: bool = False,
+        raise_on_exception: bool = True,
+        use_cloudpickle: bool = False,
+        # Budget and Stop Condition
+        stop_callback: StopperProtocol | None = None,
+        val_evaluation_policy: EvaluationPolicy[DataId, DataInst] | None = None,
+        # Evaluation caching (stored in state, passed here for initialization)
+        evaluation_cache: EvaluationCache[RolloutOutput, DataId] | None = None,
+    ):
+        self.logger = logger
+        self.run_dir = run_dir
+
+        # Graceful stopping mechanism
+        self._stop_requested = False
+
+        # Set up stopping mechanism
+        self.stop_callback = stop_callback
+        self.adapter = adapter
+
+        # Store cache reference for state initialization (actual cache lives in GEPAState)
+        self._initial_evaluation_cache = evaluation_cache
+
+        def evaluator(
+            batch: list[DataInst], program: dict[str, str]
+        ) -> tuple[list[RolloutOutput], list[float], Sequence[dict[str, float]] | None]:
+            eval_result = adapter.evaluate(batch, program, capture_traces=False)
+            return eval_result.outputs, eval_result.scores, eval_result.objective_scores
+
+        self.evaluator = evaluator
+
+        self.valset = ensure_loader(valset) if valset is not None else None
+        self.seed_candidate = seed_candidate
+
+        self.perfect_score = perfect_score
+        self.seed = seed
+        self.experiment_tracker = experiment_tracker
+
+        self.reflective_proposer = reflective_proposer
+        self.merge_proposer = merge_proposer
+        self.frontier_type: FrontierType = frontier_type
+
+        # Merge scheduling flags (mirroring previous behavior)
+        if self.merge_proposer is not None:
+            self.merge_proposer.last_iter_found_new_program = False
+
+        self.track_best_outputs = track_best_outputs
+        self.display_progress_bar = display_progress_bar
+        self.use_cloudpickle = use_cloudpickle
+
+        self.raise_on_exception = raise_on_exception
+        self.val_evaluation_policy: EvaluationPolicy[DataId, DataInst] = (
+            val_evaluation_policy if val_evaluation_policy is not None else FullEvaluationPolicy()
+        )
+
+    def _evaluate_on_valset(
+        self,
+        program: dict[str, str],
+        state: GEPAState[RolloutOutput, DataId],
+    ) -> ValsetEvaluation[RolloutOutput, DataId]:
+        valset = self.valset
+        assert valset is not None
+
+        val_ids = self.val_evaluation_policy.get_eval_batch(valset, state)
+
+        outputs_by_val_idx, scores_by_val_idx, objective_by_val_idx, num_actual_evals = state.cached_evaluate_full(
+            program, list(val_ids), valset.fetch, self.evaluator
+        )
+        state.total_num_evals += num_actual_evals
+
+        return ValsetEvaluation(
+            outputs_by_val_id=outputs_by_val_idx,
+            scores_by_val_id=scores_by_val_idx,
+            objective_scores_by_val_id=objective_by_val_idx,
+        )
+
+    def _run_full_eval_and_add(
+        self,
+        new_program: dict[str, str],
+        state: GEPAState[RolloutOutput, DataId],
+        parent_program_idx: list[int],
+    ) -> tuple[int, int]:
+        num_metric_calls_by_discovery = state.total_num_evals
+        valset_evaluation = self._evaluate_on_valset(new_program, state)
+        state.num_full_ds_evals += 1
+
+        new_program_idx = state.update_state_with_new_program(
+            parent_program_idx=parent_program_idx,
+            new_program=new_program,
+            valset_evaluation=valset_evaluation,
+            run_dir=self.run_dir,
+            num_metric_calls_by_discovery_of_new_program=num_metric_calls_by_discovery,
+        )
+        state.full_program_trace[-1]["new_program_idx"] = new_program_idx
+        state.full_program_trace[-1]["evaluated_val_indices"] = sorted(valset_evaluation.scores_by_val_id.keys())
+
+        valset_score = self.val_evaluation_policy.get_valset_score(new_program_idx, state)
+
+        linear_pareto_front_program_idx = self.val_evaluation_policy.get_best_program(state)
+        if new_program_idx == linear_pareto_front_program_idx:
+            self.logger.log(f"Iteration {state.i + 1}: Found a better program on the valset with score {valset_score}.")
+
+        valset = self.valset
+        assert valset is not None
+
+        log_detailed_metrics_after_discovering_new_program(
+            logger=self.logger,
+            gepa_state=state,
+            new_program_idx=new_program_idx,
+            valset_evaluation=valset_evaluation,
+            objective_scores=state.prog_candidate_objective_scores[new_program_idx],
+            experiment_tracker=self.experiment_tracker,
+            linear_pareto_front_program_idx=linear_pareto_front_program_idx,
+            valset_size=len(valset),
+            val_evaluation_policy=self.val_evaluation_policy,
+        )
+        return new_program_idx, linear_pareto_front_program_idx
+
+    def run(self) -> GEPAState[RolloutOutput, DataId]:
+        # Check tqdm availability if progress bar is enabled
+        progress_bar = None
+        if self.display_progress_bar:
+            if tqdm is None:
+                raise ImportError("tqdm must be installed when display_progress_bar is enabled")
+
+            # Check if stop_callback contains MaxMetricCallsStopper
+            total_calls: int | None = None
+            stop_cb = self.stop_callback
+            if stop_cb is not None:
+                max_calls_attr = getattr(stop_cb, "max_metric_calls", None)
+                if isinstance(max_calls_attr, int):
+                    # Direct MaxMetricCallsStopper
+                    total_calls = max_calls_attr
+                else:
+                    stoppers = getattr(stop_cb, "stoppers", None)
+                    if stoppers is not None:
+                        # CompositeStopper - iterate to find MaxMetricCallsStopper
+                        for stopper in stoppers:
+                            stopper_max = getattr(stopper, "max_metric_calls", None)
+                            if isinstance(stopper_max, int):
+                                total_calls = stopper_max
+                                break
+
+            if total_calls is not None:
+                progress_bar = tqdm(total=total_calls, desc="GEPA Optimization", unit="rollouts")
+            else:
+                progress_bar = tqdm(desc="GEPA Optimization", unit="rollouts")
+            progress_bar.update(0)
+
+        # Prepare valset
+        valset = self.valset
+        if valset is None:
+            raise ValueError("valset must be provided to GEPAEngine.run()")
+
+        def valset_evaluator(
+            program: dict[str, str],
+        ) -> ValsetEvaluation[RolloutOutput, DataId]:
+            all_ids = list(valset.all_ids())
+            outputs, scores, objective_scores = self.evaluator(valset.fetch(all_ids), program)
+            outputs_dict = dict(zip(all_ids, outputs, strict=False))
+            scores_dict = dict(zip(all_ids, scores, strict=False))
+            objective_scores_dict = (
+                dict(zip(all_ids, objective_scores, strict=False)) if objective_scores is not None else None
+            )
+            return ValsetEvaluation(
+                outputs_by_val_id=outputs_dict,
+                scores_by_val_id=scores_dict,
+                objective_scores_by_val_id=objective_scores_dict,
+            )
+
+        # Initialize state
+        state = initialize_gepa_state(
+            run_dir=self.run_dir,
+            logger=self.logger,
+            seed_candidate=self.seed_candidate,
+            valset_evaluator=valset_evaluator,
+            track_best_outputs=self.track_best_outputs,
+            frontier_type=self.frontier_type,
+            evaluation_cache=self._initial_evaluation_cache,
+        )
+
+        # Log base program score
+        base_val_avg, base_val_coverage = state.get_program_average_val_subset(0)
+        self.experiment_tracker.log_metrics(
+            {
+                "base_program_full_valset_score": base_val_avg,
+                "base_program_val_coverage": base_val_coverage,
+                "iteration": state.i + 1,
+            },
+            step=state.i + 1,
+        )
+
+        self.logger.log(
+            f"Iteration {state.i + 1}: Base program full valset score: {base_val_avg} "
+            f"over {base_val_coverage} / {len(valset)} examples"
+        )
+
+        # Merge scheduling
+        if self.merge_proposer is not None:
+            self.merge_proposer.last_iter_found_new_program = False
+
+        # Main loop
+        last_pbar_val = 0
+        while not self._should_stop(state):
+            if self.display_progress_bar and progress_bar is not None:
+                delta = state.total_num_evals - last_pbar_val
+                progress_bar.update(delta)
+                last_pbar_val = state.total_num_evals
+
+            assert state.is_consistent()
+            try:
+                state.save(self.run_dir, use_cloudpickle=self.use_cloudpickle)
+                state.i += 1
+                state.full_program_trace.append({"i": state.i})
+
+                # 1) Attempt merge first if scheduled and last iter found new program
+                if self.merge_proposer is not None and self.merge_proposer.use_merge:
+                    if self.merge_proposer.merges_due > 0 and self.merge_proposer.last_iter_found_new_program:
+                        proposal = self.merge_proposer.propose(state)
+                        self.merge_proposer.last_iter_found_new_program = False  # old behavior
+
+                        if proposal is not None and proposal.tag == "merge":
+                            parent_sums = proposal.subsample_scores_before or [
+                                float("-inf"),
+                                float("-inf"),
+                            ]
+                            new_sum = sum(proposal.subsample_scores_after or [])
+
+                            if new_sum >= max(parent_sums):
+                                # ACCEPTED: consume one merge attempt and record it
+                                self._run_full_eval_and_add(
+                                    new_program=proposal.candidate,
+                                    state=state,
+                                    parent_program_idx=proposal.parent_program_ids,
+                                )
+                                self.merge_proposer.merges_due -= 1
+                                self.merge_proposer.total_merges_tested += 1
+                                continue  # skip reflective this iteration
+                            else:
+                                # REJECTED: do NOT consume merges_due or total_merges_tested
+                                self.logger.log(
+                                    f"Iteration {state.i + 1}: New program subsample score {new_sum} "
+                                    f"is worse than both parents {parent_sums}, skipping merge"
+                                )
+                                # Skip reflective this iteration (old behavior)
+                                continue
+
+                    # Old behavior: regardless of whether we attempted, clear the flag before reflective
+                    self.merge_proposer.last_iter_found_new_program = False
+
+                # 2) Reflective mutation proposer
+                proposal = self.reflective_proposer.propose(state)
+                if proposal is None:
+                    self.logger.log(f"Iteration {state.i + 1}: Reflective mutation did not propose a new candidate")
+                    continue
+
+                # Acceptance: require strict improvement on subsample
+                old_sum = sum(proposal.subsample_scores_before or [])
+                new_sum = sum(proposal.subsample_scores_after or [])
+                if new_sum <= old_sum:
+                    self.logger.log(
+                        f"Iteration {state.i + 1}: New subsample score {new_sum} is not better than old score {old_sum}, skipping"
+                    )
+                    continue
+                else:
+                    self.logger.log(
+                        f"Iteration {state.i + 1}: New subsample score {new_sum} is better than old score {old_sum}. Continue to full eval and add to candidate pool."
+                    )
+
+                # Accept: full eval + add
+                self._run_full_eval_and_add(
+                    new_program=proposal.candidate,
+                    state=state,
+                    parent_program_idx=proposal.parent_program_ids,
+                )
+
+                # Schedule merge attempts like original behavior
+                if self.merge_proposer is not None:
+                    self.merge_proposer.last_iter_found_new_program = True
+                    if self.merge_proposer.total_merges_tested < self.merge_proposer.max_merge_invocations:
+                        self.merge_proposer.merges_due += 1
+
+            except Exception as e:
+                self.logger.log(f"Iteration {state.i + 1}: Exception during optimization: {e}")
+                self.logger.log(traceback.format_exc())
+                if self.raise_on_exception:
+                    raise e
+                else:
+                    continue
+
+        # Close progress bar if it exists
+        if self.display_progress_bar and progress_bar is not None:
+            progress_bar.close()
+
+        state.save(self.run_dir)
+        return state
+
+    def _should_stop(self, state: GEPAState[RolloutOutput, DataId]) -> bool:
+        """Check if the optimization should stop."""
+        if self._stop_requested:
+            return True
+        if self.stop_callback and self.stop_callback(state):
+            return True
+        return False
+
+    def request_stop(self) -> None:
+        """Manually request the optimization to stop gracefully."""
+        self.logger.log("Stop requested manually. Initiating graceful shutdown...")
+        self._stop_requested = True
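
Note that the engine's budget handling above is duck-typed rather than class-based: run() sizes its progress bar by reading a max_metric_calls attribute off the stop callback directly, or off each entry of a stoppers attribute for composite callbacks, and _should_stop() simply invokes the callback with the current GEPAState. A minimal sketch of a compatible stopper follows; BudgetStopper is a hypothetical name used for illustration, and only the duck-typed surface (an integer max_metric_calls plus __call__(state) -> bool) is taken from the code above.

# Hypothetical stop callback compatible with GEPAEngine (sketch, not shipped API).
class BudgetStopper:
    def __init__(self, max_metric_calls: int) -> None:
        # run() reads this attribute to set the tqdm progress-bar total.
        self.max_metric_calls = max_metric_calls

    def __call__(self, state) -> bool:
        # _should_stop() calls the callback with the current GEPAState;
        # returning True ends the optimization loop gracefully.
        return state.total_num_evals >= self.max_metric_calls

Under these assumptions, passing stop_callback=BudgetStopper(2000) together with display_progress_bar=True would both bound the run at 2000 metric calls and give the progress bar a known total of 2000 rollouts.
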
mantisdk/algorithm/gepa/lib/core/result.py
@@ -0,0 +1,233 @@
+# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+# https://github.com/gepa-ai/gepa
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Generic
+
+from mantisdk.algorithm.gepa.lib.core.adapter import RolloutOutput
+from mantisdk.algorithm.gepa.lib.core.data_loader import DataId
+from mantisdk.algorithm.gepa.lib.core.state import ProgramIdx
+
+if TYPE_CHECKING:
+    from mantisdk.algorithm.gepa.lib.core.state import GEPAState
+
+
+@dataclass(frozen=True)
+class GEPAResult(Generic[RolloutOutput, DataId]):
+    """
+    Immutable snapshot of a GEPA run with convenience accessors.
+
+    - candidates: list of proposed candidates (component_name -> component_text)
+    - parents: lineage info; for each candidate i, parents[i] is a list of parent indices or None
+    - val_aggregate_scores: per-candidate aggregate score on the validation set (higher is better)
+    - val_subscores: per-candidate mapping from validation id to score on the validation set (sparse dict)
+    - val_aggregate_subscores: optional per-candidate aggregate subscores across objectives
+    - per_val_instance_best_candidates: for each val instance t, a set of candidate indices achieving the current best score on t
+    - per_objective_best_candidates: optional per-objective set of candidate indices achieving best aggregate subscore
+    - discovery_eval_counts: number of metric calls accumulated up to the discovery of each candidate
+
+    Optional fields:
+    - best_outputs_valset: per-task best outputs on the validation set. [task_idx -> [(program_idx_1, output_1), (program_idx_2, output_2), ...]]
+
+    Run-level metadata:
+    - total_metric_calls: total number of metric calls made across the run
+    - num_full_val_evals: number of full validation evaluations performed
+    - run_dir: where artifacts were written (if any)
+    - seed: RNG seed for reproducibility (if known)
+    - tracked_scores: optional tracked aggregate scores (if different from val_aggregate_scores)
+
+    Convenience:
+    - best_idx: candidate index with the highest val_aggregate_scores
+    - best_candidate: the program text mapping for best_idx
+    - non_dominated_indices(): candidate indices that are not dominated across per-instance pareto fronts
+    - lineage(idx): parent chain from base to idx
+    - diff(parent_idx, child_idx, only_changed=True): component-wise diff between two candidates
+    - best_k(k): top-k candidates by aggregate val score
+    - instance_winners(t): set of candidates on the pareto front for val instance t
+    - to_dict(...), save_json(...): serialization helpers
+    """
+
+    # Core data
+    candidates: list[dict[str, str]]
+    parents: list[list[ProgramIdx | None]]
+    val_aggregate_scores: list[float]
+    val_subscores: list[dict[DataId, float]]
+    per_val_instance_best_candidates: dict[DataId, set[ProgramIdx]]
+    discovery_eval_counts: list[int]
+    val_aggregate_subscores: list[dict[str, float]] | None = None
+    per_objective_best_candidates: dict[str, set[ProgramIdx]] | None = None
+    objective_pareto_front: dict[str, float] | None = None
+
+    # Optional data
+    best_outputs_valset: dict[DataId, list[tuple[ProgramIdx, RolloutOutput]]] | None = None
+
+    # Run metadata (optional)
+    total_metric_calls: int | None = None
+    num_full_val_evals: int | None = None
+    run_dir: str | None = None
+    seed: int | None = None
+
+    _VALIDATION_SCHEMA_VERSION: ClassVar[int] = 2
+
+    # -------- Convenience properties --------
+    @property
+    def num_candidates(self) -> int:
+        return len(self.candidates)
+
+    @property
+    def num_val_instances(self) -> int:
+        return len(self.per_val_instance_best_candidates)
+
+    @property
+    def best_idx(self) -> int:
+        scores = self.val_aggregate_scores
+        return max(range(len(scores)), key=lambda i: scores[i])
+
+    @property
+    def best_candidate(self) -> dict[str, str]:
+        return self.candidates[self.best_idx]
+
+    def to_dict(self) -> dict[str, Any]:
+        cands = [dict(cand.items()) for cand in self.candidates]
+
+        return {
+            "candidates": cands,
+            "parents": self.parents,
+            "val_aggregate_scores": self.val_aggregate_scores,
+            "val_subscores": self.val_subscores,
+            "best_outputs_valset": self.best_outputs_valset,
+            "per_val_instance_best_candidates": {
+                val_id: list(front) for val_id, front in self.per_val_instance_best_candidates.items()
+            },
+            "val_aggregate_subscores": self.val_aggregate_subscores,
+            "per_objective_best_candidates": (
+                {k: list(v) for k, v in self.per_objective_best_candidates.items()}
+                if self.per_objective_best_candidates is not None
+                else None
+            ),
+            "objective_pareto_front": self.objective_pareto_front,
+            "discovery_eval_counts": self.discovery_eval_counts,
+            "total_metric_calls": self.total_metric_calls,
+            "num_full_val_evals": self.num_full_val_evals,
+            "run_dir": self.run_dir,
+            "seed": self.seed,
+            "best_idx": self.best_idx,
+            "validation_schema_version": GEPAResult._VALIDATION_SCHEMA_VERSION,
+        }
+
+    @staticmethod
+    def from_dict(d: dict[str, Any]) -> "GEPAResult[RolloutOutput, DataId]":
+        version = d.get("validation_schema_version") or 0
+        if version > GEPAResult._VALIDATION_SCHEMA_VERSION:
+            raise ValueError(
+                f"Unsupported GEPAResult validation schema version {version}; "
+                f"max supported is {GEPAResult._VALIDATION_SCHEMA_VERSION}"
+            )
+
+        if version <= 1:
+            return GEPAResult._migrate_from_dict_v0(d)
+
+        return GEPAResult._from_dict_v2(d)
+
+    @staticmethod
+    def _common_kwargs_from_dict(d: dict[str, Any]) -> dict[str, Any]:
+        return {
+            "candidates": [dict(candidate) for candidate in d.get("candidates", [])],
+            "parents": [list(parent_row) for parent_row in d.get("parents", [])],
+            "val_aggregate_scores": list(d.get("val_aggregate_scores", [])),
+            "discovery_eval_counts": list(d.get("discovery_eval_counts", [])),
+            "total_metric_calls": d.get("total_metric_calls"),
+            "num_full_val_evals": d.get("num_full_val_evals"),
+            "run_dir": d.get("run_dir"),
+            "seed": d.get("seed"),
+        }
+
+    @staticmethod
+    def _migrate_from_dict_v0(d: dict[str, Any]) -> "GEPAResult[RolloutOutput, DataId]":
+        kwargs = GEPAResult._common_kwargs_from_dict(d)
+        kwargs["val_subscores"] = [
+            {idx: score for idx, score in enumerate(scores)} for scores in d.get("val_subscores", [])
+        ]
+        kwargs["per_val_instance_best_candidates"] = {
+            idx: set(front) for idx, front in enumerate(d.get("per_val_instance_best_candidates", []))
+        }
+
+        best_outputs_valset = d.get("best_outputs_valset")
+        if best_outputs_valset is not None:
+            kwargs["best_outputs_valset"] = {
+                idx: [(program_idx, output) for program_idx, output in outputs]
+                for idx, outputs in enumerate(best_outputs_valset)
+            }
+        else:
+            kwargs["best_outputs_valset"] = None
+        return GEPAResult(**kwargs)
+
+    @staticmethod
+    def _from_dict_v2(d: dict[str, Any]) -> "GEPAResult[RolloutOutput, DataId]":
+        kwargs = GEPAResult._common_kwargs_from_dict(d)
+        kwargs["val_subscores"] = [dict(scores) for scores in d.get("val_subscores", [])]
+        per_val_instance_best_candidates_data = d.get("per_val_instance_best_candidates", {})
+        kwargs["per_val_instance_best_candidates"] = {
+            val_id: set(candidates_on_front)
+            for val_id, candidates_on_front in per_val_instance_best_candidates_data.items()
+        }
+
+        best_outputs_valset = d.get("best_outputs_valset")
+        if best_outputs_valset is not None:
+            kwargs["best_outputs_valset"] = {
+                val_id: [(program_idx, output) for program_idx, output in outputs]
+                for val_id, outputs in best_outputs_valset.items()
+            }
+        else:
+            kwargs["best_outputs_valset"] = None
+
+        val_aggregate_subscores = d.get("val_aggregate_subscores")
+        kwargs["val_aggregate_subscores"] = (
+            [dict(scores) for scores in val_aggregate_subscores] if val_aggregate_subscores is not None else None
+        )
+
+        per_objective_best_candidates = d.get("per_objective_best_candidates")
+        if per_objective_best_candidates is not None:
+            kwargs["per_objective_best_candidates"] = {
+                objective: set(program_indices) for objective, program_indices in per_objective_best_candidates.items()
+            }
+        else:
+            kwargs["per_objective_best_candidates"] = None
+
+        objective_pareto_front = d.get("objective_pareto_front")
+        kwargs["objective_pareto_front"] = dict(objective_pareto_front) if objective_pareto_front is not None else None
+
+        return GEPAResult(**kwargs)
+
+    @staticmethod
+    def from_state(
+        state: "GEPAState[RolloutOutput, DataId]",
+        run_dir: str | None = None,
+        seed: int | None = None,
+    ) -> "GEPAResult[RolloutOutput, DataId]":
+        """Build a GEPAResult from a GEPAState."""
+        objective_scores_list = [dict(scores) for scores in state.prog_candidate_objective_scores]
+        has_objective_scores = any(obj for obj in objective_scores_list)
+        per_objective_best = {
+            objective: set(front) for objective, front in state.program_at_pareto_front_objectives.items()
+        }
+        objective_front = dict(state.objective_pareto_front)
+
+        return GEPAResult(
+            candidates=list(state.program_candidates),
+            parents=list(state.parent_program_for_candidate),
+            val_aggregate_scores=list(state.program_full_scores_val_set),
+            best_outputs_valset=getattr(state, "best_outputs_valset", None),
+            val_subscores=[dict(scores) for scores in state.prog_candidate_val_subscores],
+            per_val_instance_best_candidates={
+                val_id: set(front) for val_id, front in state.program_at_pareto_front_valset.items()
+            },
+            val_aggregate_subscores=(objective_scores_list if has_objective_scores else None),
+            per_objective_best_candidates=(per_objective_best if per_objective_best else None),
+            objective_pareto_front=objective_front if objective_front else None,
+            discovery_eval_counts=list(state.num_metric_calls_by_discovery),
+            total_metric_calls=getattr(state, "total_num_evals", None),
+            num_full_val_evals=getattr(state, "num_full_ds_evals", None),
+            run_dir=run_dir,
+            seed=seed,
+        )
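
GEPAResult.to_dict() stamps every payload with validation_schema_version 2, and from_dict() routes versions <= 1 through _migrate_from_dict_v0, which rebuilds the id-keyed dicts (val_subscores, per_val_instance_best_candidates, best_outputs_valset) from the older list layouts by enumeration. A small round-trip sketch follows; the values are toy data for illustration, not shipped example data.

# Round-trip sketch for GEPAResult using toy values.
from mantisdk.algorithm.gepa.lib.core.result import GEPAResult

result = GEPAResult(
    candidates=[{"system_prompt": "You are a helpful assistant."}],
    parents=[[None]],                     # the seed candidate has no parent
    val_aggregate_scores=[0.5],
    val_subscores=[{0: 0.5}],             # sparse: val id -> score
    per_val_instance_best_candidates={0: {0}},
    discovery_eval_counts=[0],
)

payload = result.to_dict()                # carries validation_schema_version == 2
restored = GEPAResult.from_dict(payload)  # dispatches to _from_dict_v2
assert restored.best_idx == 0
assert restored.best_candidate == result.best_candidate

Note that dumping payload through json.dumps would stringify the integer val ids used here, so a JSON round-trip relies on DataId being string-typed or on re-casting keys when loading.
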