dspy 0.28.1 → 0.29.0
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- checksums.yaml +4 -4
- data/README.md +2 -3
- data/lib/dspy/callbacks.rb +222 -0
- data/lib/dspy/chain_of_thought.rb +2 -1
- data/lib/dspy/code_act.rb +14 -1
- data/lib/dspy/datasets/ade.rb +90 -0
- data/lib/dspy/datasets.rb +8 -0
- data/lib/dspy/lm.rb +9 -12
- data/lib/dspy/mixins/struct_builder.rb +17 -25
- data/lib/dspy/module.rb +45 -1
- data/lib/dspy/observability/async_span_processor.rb +67 -93
- data/lib/dspy/observability.rb +43 -1
- data/lib/dspy/predict.rb +17 -0
- data/lib/dspy/prompt.rb +90 -20
- data/lib/dspy/propose/dataset_summary_generator.rb +210 -0
- data/lib/dspy/propose/grounded_proposer.rb +320 -66
- data/lib/dspy/re_act.rb +13 -0
- data/lib/dspy/reflection_lm.rb +36 -0
- data/lib/dspy/teleprompt/bootstrap_strategy.rb +26 -0
- data/lib/dspy/teleprompt/gepa.rb +448 -2803
- data/lib/dspy/teleprompt/mipro_v2.rb +624 -100
- data/lib/dspy/teleprompt/utils.rb +349 -42
- data/lib/dspy/version.rb +2 -2
- data/lib/dspy.rb +4 -2
- data/lib/gepa/api.rb +61 -0
- data/lib/gepa/core/engine.rb +226 -0
- data/lib/gepa/core/evaluation_batch.rb +26 -0
- data/lib/gepa/core/result.rb +92 -0
- data/lib/gepa/core/state.rb +231 -0
- data/lib/gepa/logging/experiment_tracker.rb +54 -0
- data/lib/gepa/logging/logger.rb +57 -0
- data/lib/gepa/logging.rb +9 -0
- data/lib/gepa/proposer/base.rb +27 -0
- data/lib/gepa/proposer/merge_proposer.rb +424 -0
- data/lib/gepa/proposer/reflective_mutation/base.rb +48 -0
- data/lib/gepa/proposer/reflective_mutation/reflective_mutation.rb +188 -0
- data/lib/gepa/strategies/batch_sampler.rb +91 -0
- data/lib/gepa/strategies/candidate_selector.rb +97 -0
- data/lib/gepa/strategies/component_selector.rb +57 -0
- data/lib/gepa/strategies/instruction_proposal.rb +120 -0
- data/lib/gepa/telemetry.rb +122 -0
- data/lib/gepa/utils/pareto.rb +119 -0
- data/lib/gepa.rb +21 -0
- metadata +59 -4
- data/lib/dspy/teleprompt/simple_optimizer.rb +0 -497
data/lib/gepa/core/engine.rb
@@ -0,0 +1,226 @@
# frozen_string_literal: true

require 'sorbet-runtime'

require_relative 'state'
require_relative 'result'
require_relative '../telemetry'

module GEPA
  module Core
    class Engine
      extend T::Sig

      sig do
        params(
          evaluator: T.proc.params(dataset: T::Array[T.untyped], candidate: T::Hash[String, String])
            .returns([T::Array[T.untyped], T::Array[Float]]),
          valset: T::Array[T.untyped],
          seed_candidate: T::Hash[String, String],
          max_metric_calls: Integer,
          perfect_score: Float,
          seed: Integer,
          reflective_proposer: T.untyped,
          logger: T.untyped,
          experiment_tracker: T.untyped,
          merge_proposer: T.nilable(T.untyped),
          run_dir: T.nilable(String),
          track_best_outputs: T::Boolean,
          display_progress_bar: T::Boolean,
          telemetry: T.nilable(T.untyped),
          raise_on_exception: T::Boolean
        ).void
      end
      def initialize(
        evaluator:,
        valset:,
        seed_candidate:,
        max_metric_calls:,
        perfect_score:,
        seed:, # rubocop:disable Lint/UnusedMethodArgument -- kept for parity and future use
        reflective_proposer:,
        logger:,
        experiment_tracker:,
        merge_proposer: nil,
        run_dir: nil,
        track_best_outputs: false,
        display_progress_bar: false,
        telemetry: nil,
        raise_on_exception: true
      )
        @run_dir = run_dir
        @evaluator = evaluator
        @valset = valset
        @seed_candidate = seed_candidate
        @max_metric_calls = max_metric_calls
        @perfect_score = perfect_score
        @reflective_proposer = reflective_proposer
        @merge_proposer = merge_proposer
        @logger = logger
        @experiment_tracker = experiment_tracker
        @track_best_outputs = track_best_outputs
        @display_progress_bar = display_progress_bar
        @telemetry = telemetry || GEPA::Telemetry
        @raise_on_exception = raise_on_exception
      end

      sig { returns(GEPA::Core::State) }
      def run
        with_span('gepa.engine.run', max_metric_calls: @max_metric_calls) do
          state = GEPA::Core::State.initialize_gepa_state(
            run_dir: @run_dir,
            logger: @logger,
            seed_candidate: @seed_candidate,
            valset_evaluator: ->(candidate) { full_evaluator(candidate) },
            track_best_outputs: @track_best_outputs
          )

          @experiment_tracker.log_metrics({ base_program_full_valset_score: state.program_full_scores_val_set.first }, step: 0)

          if @merge_proposer
            @merge_proposer.last_iter_found_new_program = false
          end

          while state.total_num_evals < @max_metric_calls
            break unless iteration_step(state)
          end

          state.save(@run_dir)
          state
        end
      end

      private

      sig { params(state: GEPA::Core::State).returns(T::Boolean) }
      def iteration_step(state)
        state.i += 1
        trace_entry = { iteration: state.i }
        state.full_program_trace << trace_entry

        progress = false

        with_span('gepa.engine.iteration', iteration: state.i) do
          merge_result = process_merge_iteration(state)
          case merge_result
          when :accepted
            return true
          when :attempted
            return false
          end

          reflective_result = process_reflective_iteration(state)
          return false if reflective_result == :no_candidate
          progress = true if reflective_result == :accepted
        end

        progress
      rescue StandardError => e
        @logger.log("Iteration #{state.i}: Exception during optimization: #{e}")
        @logger.log(e.backtrace&.join("\n"))
        raise e if @raise_on_exception
        true
      end

      sig { params(state: GEPA::Core::State).returns(Symbol) }
      def process_merge_iteration(state)
        return :skipped unless @merge_proposer && @merge_proposer.use_merge

        if @merge_proposer.merges_due.positive? && @merge_proposer.last_iter_found_new_program
          proposal = @merge_proposer.propose(state)
          @merge_proposer.last_iter_found_new_program = false

          if proposal&.tag == 'merge'
            parent_sums = Array(proposal.subsample_scores_before).map(&:to_f)
            new_sum = Array(proposal.subsample_scores_after).map(&:to_f).sum

            if parent_sums.empty?
              @logger.log("Iteration #{state.i}: Missing parent subscores for merge proposal, skipping")
              return :handled
            end

            if new_sum >= parent_sums.max
              with_span('gepa.engine.full_evaluation', iteration: state.i) do
                run_full_evaluation(state, proposal.candidate, proposal.parent_program_ids)
              end
              @merge_proposer.merges_due -= 1
              @merge_proposer.total_merges_tested += 1
              return :accepted
            else
              @logger.log(
                "Iteration #{state.i}: Merge subsample score #{new_sum.round(4)} "\
                "did not beat parents #{parent_sums.map { |v| v.round(4) }}, skipping"
              )
              return :attempted
            end
          end
        end

        @merge_proposer.last_iter_found_new_program = false
        :skipped
      end

      sig { params(state: GEPA::Core::State).void }
      def process_reflective_iteration(state)
        proposal = @reflective_proposer.propose(state)
        unless proposal
          @logger.log("Iteration #{state.i}: Reflective mutation did not propose a new candidate")
          return :no_candidate
        end

        before = Array(proposal.subsample_scores_before).map(&:to_f)
        after = Array(proposal.subsample_scores_after).map(&:to_f)
        if after.empty? || after.sum <= before.sum
          @logger.log("Iteration #{state.i}: New subsample score is not better, skipping")
          return :skipped
        end

        with_span('gepa.engine.full_evaluation', iteration: state.i) do
          run_full_evaluation(state, proposal.candidate, proposal.parent_program_ids)
        end

        if @merge_proposer&.use_merge
          @merge_proposer.last_iter_found_new_program = true
          @merge_proposer.schedule_if_needed
        end

        :accepted
      end

      sig do
        params(state: GEPA::Core::State, new_program: T::Hash[String, String], parents: T::Array[Integer]).void
      end
      def run_full_evaluation(state, new_program, parents)
        outputs, scores = full_evaluator(new_program)
        avg_score = scores.sum / scores.length.to_f

        state.num_full_ds_evals += 1
        state.total_num_evals += scores.length

        state.update_state_with_new_program(
          parents,
          new_program,
          avg_score,
          outputs,
          scores,
          @run_dir,
          state.total_num_evals
        )

        @experiment_tracker.log_metrics({ new_program_full_score: avg_score }, step: state.i)
      end

      sig { params(candidate: T::Hash[String, String]).returns([T::Array[T.untyped], T::Array[Float]]) }
      def full_evaluator(candidate)
        @evaluator.call(@valset, candidate)
      end

      sig do
        params(operation: String, attrs: T::Hash[Symbol, T.untyped], block: T.proc.returns(T.untyped)).returns(T.untyped)
      end
      def with_span(operation, attrs = {}, &block)
        @telemetry.with_span(operation, attrs, &block)
      end
    end
  end
end
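For context, a minimal sketch of how this Engine might be driven. Everything below other than the GEPA classes shown in this diff is a placeholder: the evaluator lambda, the toy validation set, and the reflective proposer are hypothetical stand-ins, and it is assumed that requiring 'gepa' (added in this release) loads these classes.

require 'gepa'

# Hypothetical evaluator: runs each validation example against a candidate's
# instruction text and scores the output. Replace with a real program + metric.
evaluator = lambda do |dataset, candidate|
  outputs = dataset.map { |example| "#{candidate['predictor']} -> #{example}" }
  scores  = outputs.map { |output| output.empty? ? 0.0 : 1.0 }
  [outputs, scores]
end

engine = GEPA::Core::Engine.new(
  evaluator: evaluator,
  valset: ['example 1', 'example 2'],                         # toy validation set
  seed_candidate: { 'predictor' => 'Answer the question concisely.' },
  max_metric_calls: 200,
  perfect_score: 1.0,
  seed: 0,
  reflective_proposer: reflective_proposer,                   # placeholder: a GEPA reflective mutation proposer
  logger: GEPA::Logging::Logger.new,
  experiment_tracker: GEPA::Logging::ExperimentTracker.new
)
state = engine.run   # returns a GEPA::Core::State after the metric budget is spent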
data/lib/gepa/core/evaluation_batch.rb
@@ -0,0 +1,26 @@
# frozen_string_literal: true

require 'sorbet-runtime'

module GEPA
  module Core
    # Container for evaluating a candidate on a batch.
    class EvaluationBatch < T::Struct
      extend T::Sig

      const :outputs, T::Array[T.untyped]
      const :scores, T::Array[Float]
      const :trajectories, T.nilable(T::Array[T.untyped])

      sig { override.params(args: T.untyped, kwargs: T.untyped).void }
      def initialize(*args, **kwargs)
        super
        raise ArgumentError, 'outputs and scores length mismatch' unless outputs.length == scores.length

        if trajectories
          raise ArgumentError, 'trajectories length mismatch' unless trajectories.length == outputs.length
        end
      end
    end
  end
end
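For illustration, a small sketch of the struct's length checks, using made-up outputs and scores:

batch = GEPA::Core::EvaluationBatch.new(
  outputs: ['answer a', 'answer b'],
  scores: [1.0, 0.5],
  trajectories: nil
)
batch.scores.sum / batch.scores.length # => 0.75

# Mismatched lengths raise at construction time:
# GEPA::Core::EvaluationBatch.new(outputs: ['only one'], scores: [1.0, 0.0])
#   # => ArgumentError: outputs and scores length mismatch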
data/lib/gepa/core/result.rb
@@ -0,0 +1,92 @@
# frozen_string_literal: true

require 'json'
require 'set'
require 'sorbet-runtime'

module GEPA
  module Core
    # Snapshot of GEPA optimization output with helpers for common queries.
    class Result < T::Struct
      extend T::Sig

      const :candidates, T::Array[T::Hash[String, String]]
      const :parents, T::Array[T::Array[T.nilable(Integer)]]
      const :val_aggregate_scores, T::Array[Float]
      const :val_subscores, T::Array[T::Array[Float]]
      const :per_val_instance_best_candidates, T::Array[T::Array[Integer]]
      const :discovery_eval_counts, T::Array[Integer]
      const :best_outputs_valset, T.nilable(T::Array[T::Array[T::Array[T.untyped]]]), default: nil
      const :total_metric_calls, T.nilable(Integer), default: nil
      const :num_full_val_evals, T.nilable(Integer), default: nil
      const :run_dir, T.nilable(String), default: nil
      const :seed, T.nilable(Integer), default: nil

      sig { returns(Integer) }
      def num_candidates
        candidates.length
      end

      sig { returns(Integer) }
      def num_val_instances
        per_val_instance_best_candidates.length
      end

      sig { returns(Integer) }
      def best_idx
        val_aggregate_scores.each_with_index.max_by { |score, _i| score }&.last || 0
      end

      sig { returns(T::Hash[String, String]) }
      def best_candidate
        candidates.fetch(best_idx)
      end

      sig { returns(T::Hash[Symbol, T.untyped]) }
      def to_h
        {
          candidates: candidates.map(&:dup),
          parents: parents.map(&:dup),
          val_aggregate_scores: val_aggregate_scores.dup,
          val_subscores: val_subscores.map(&:dup),
          best_outputs_valset: best_outputs_valset&.map { |arr| arr.map(&:dup) },
          per_val_instance_best_candidates: per_val_instance_best_candidates.map(&:dup),
          discovery_eval_counts: discovery_eval_counts.dup,
          total_metric_calls: total_metric_calls,
          num_full_val_evals: num_full_val_evals,
          run_dir: run_dir,
          seed: seed,
          best_idx: best_idx
        }
      end

      sig { returns(String) }
      def to_json(*_args)
        JSON.pretty_generate(to_h)
      end

      sig do
        params(
          state: T.untyped,
          run_dir: T.nilable(String),
          seed: T.nilable(Integer)
        ).returns(Result)
      end
      def self.from_state(state, run_dir: nil, seed: nil)
        new(
          candidates: state.program_candidates.map(&:dup),
          parents: state.parent_program_for_candidate.map(&:dup),
          val_aggregate_scores: state.program_full_scores_val_set.map(&:to_f),
          best_outputs_valset: state.respond_to?(:best_outputs_valset) ? state.best_outputs_valset&.map(&:dup) : nil,
          val_subscores: state.prog_candidate_val_subscores.map { |scores| scores.map(&:to_f) },
          per_val_instance_best_candidates: state.program_at_pareto_front_valset.map { |set| set.to_a },
          discovery_eval_counts: state.num_metric_calls_by_discovery.map(&:to_i),
          total_metric_calls: state.respond_to?(:total_num_evals) ? state.total_num_evals : nil,
          num_full_val_evals: state.respond_to?(:num_full_ds_evals) ? state.num_full_ds_evals : nil,
          run_dir: run_dir,
          seed: seed
        )
      end
    end
  end
end
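Downstream code can snapshot the optimizer into a Result. A brief sketch, assuming state is the GEPA::Core::State returned by Engine#run and run_dir is the same directory passed to the engine:

result = GEPA::Core::Result.from_state(state, run_dir: run_dir, seed: 0)
result.num_candidates   # number of candidate programs discovered
result.best_idx         # index of the highest aggregate validation score
result.best_candidate   # the instruction hash at that index
File.write(File.join(run_dir, 'gepa_result.json'), result.to_json) if run_dir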
data/lib/gepa/core/state.rb
@@ -0,0 +1,231 @@
# frozen_string_literal: true

require 'fileutils'
require 'json'
require 'set'
require 'sorbet-runtime'

require_relative '../utils/pareto'
require_relative '../telemetry'

module GEPA
  module Core
    class State
      extend T::Sig

      attr_accessor :i, :num_full_ds_evals, :total_num_evals
      attr_reader :program_candidates,
                  :parent_program_for_candidate,
                  :program_full_scores_val_set,
                  :program_at_pareto_front_valset,
                  :prog_candidate_val_subscores,
                  :list_of_named_predictors,
                  :named_predictor_id_to_update_next_for_program_candidate,
                  :num_metric_calls_by_discovery,
                  :full_program_trace,
                  :per_program_tracked_scores,
                  :pareto_front_valset,
                  :best_outputs_valset

      sig do
        params(
          seed_candidate: T::Hash[String, String],
          base_valset_eval_output: [T::Array[T.untyped], T::Array[Float]],
          track_best_outputs: T::Boolean
        ).void
      end
      def initialize(seed_candidate, base_valset_eval_output, track_best_outputs: false)
        outputs, scores = base_valset_eval_output
        raise ArgumentError, 'validation scores must not be empty' if scores.empty?

        valset_base_score = scores.sum / scores.length.to_f

        @program_candidates = [seed_candidate.dup]
        @program_full_scores_val_set = [valset_base_score]
        @per_program_tracked_scores = [valset_base_score]

        @pareto_front_valset = scores.dup
        @parent_program_for_candidate = [[nil]]
        @program_at_pareto_front_valset = Array.new(scores.length) { Set.new([0]) }

        @list_of_named_predictors = seed_candidate.keys
        @named_predictor_id_to_update_next_for_program_candidate = [0]

        @prog_candidate_val_subscores = [scores.dup]
        @num_metric_calls_by_discovery = [0]

        @best_outputs_valset = if track_best_outputs
          outputs.map { |output| [[0, output]] }
        end

        @full_program_trace = []
        @i = -1
        @num_full_ds_evals = 0
        @total_num_evals = 0
      end

      sig { returns(T::Boolean) }
      def consistent?
        size = @program_candidates.length
        raise 'program_full_scores_val_set mismatch' unless @program_full_scores_val_set.length == size
        raise 'per_program_tracked_scores mismatch' unless @per_program_tracked_scores.length == size
        raise 'parent_program_for_candidate mismatch' unless @parent_program_for_candidate.length == size
        raise 'named_predictor_id_to_update mismatch' unless @named_predictor_id_to_update_next_for_program_candidate.length == size
        raise 'prog_candidate_val_subscores mismatch' unless @prog_candidate_val_subscores.length == size
        raise 'num_metric_calls mismatch' unless @num_metric_calls_by_discovery.length == size
        raise 'pareto fronts length mismatch' unless @pareto_front_valset.length == @program_at_pareto_front_valset.length

        @program_at_pareto_front_valset.each do |front|
          front.each do |idx|
            raise 'pareto index out of range' unless idx < size
          end
        end
        true
      end

      sig { params(run_dir: T.nilable(String)).void }
      def save(run_dir)
        return if run_dir.nil?

        FileUtils.mkdir_p(run_dir)
        File.open(File.join(run_dir, 'gepa_state.bin'), 'wb') do |file|
          data = instance_variables.each_with_object({}) do |ivar, acc|
            acc[ivar.to_s.delete('@')] = instance_variable_get(ivar)
          end
          Marshal.dump(data, file)
        end
      end

      sig { params(run_dir: String).returns(State) }
      def self.load(run_dir)
        File.open(File.join(run_dir, 'gepa_state.bin'), 'rb') do |file|
          data = Marshal.load(file)
          state = allocate
          data.each { |key, value| state.instance_variable_set("@#{key}", value) }
          state.consistent?
          state
        end
      end

      sig do
        params(
          parent_program_idx: T::Array[Integer],
          new_program: T::Hash[String, String],
          valset_score: Float,
          valset_outputs: T::Array[T.untyped],
          valset_subscores: T::Array[Float],
          run_dir: T.nilable(String),
          num_metric_calls: Integer
        ).returns([Integer, Integer])
      end
      def update_state_with_new_program(
        parent_program_idx,
        new_program,
        valset_score,
        valset_outputs,
        valset_subscores,
        run_dir,
        num_metric_calls
      )
        new_program_idx = @program_candidates.length
        @program_candidates << new_program.dup
        @num_metric_calls_by_discovery << num_metric_calls

        max_predictor_id = parent_program_idx.map { |idx| @named_predictor_id_to_update_next_for_program_candidate[idx] }.compact.max
        @named_predictor_id_to_update_next_for_program_candidate << (max_predictor_id || 0)
        @parent_program_for_candidate << parent_program_idx.dup

        @prog_candidate_val_subscores << valset_subscores.dup
        @program_full_scores_val_set << valset_score.to_f

        valset_subscores.each_with_index do |new_score, task_idx|
          old_score = @pareto_front_valset[task_idx]
          if new_score > old_score
            @pareto_front_valset[task_idx] = new_score
            @program_at_pareto_front_valset[task_idx] = Set.new([new_program_idx])
            if @best_outputs_valset
              @best_outputs_valset[task_idx] = [[new_program_idx, valset_outputs[task_idx]]]
            end
            write_best_output(run_dir, task_idx, new_program_idx, valset_outputs[task_idx])
          elsif new_score == old_score
            @program_at_pareto_front_valset[task_idx].add(new_program_idx)
            if @best_outputs_valset
              @best_outputs_valset[task_idx] << [new_program_idx, valset_outputs[task_idx]]
            end
          end
        end

        raise 'valset subscores length mismatch' unless valset_subscores.length == @program_at_pareto_front_valset.length

        @per_program_tracked_scores = @program_full_scores_val_set.dup
        linear_idx = GEPA::Utils::Pareto.idxmax(@per_program_tracked_scores)

        [new_program_idx, linear_idx]
      end

      sig do
        params(
          eval_output: [T::Array[T.untyped], T::Array[Float]],
          output_dir: String
        ).void
      end
      def self.write_eval_output_to_directory(eval_output, output_dir)
        _, scores = eval_output
        scores.each_with_index do |_score, task_idx|
          dir = File.join(output_dir, "task_#{task_idx}")
          FileUtils.mkdir_p(dir)
          path = File.join(dir, 'iter_0_prog_0.json')
          File.write(path, JSON.pretty_generate(scores[task_idx]))
        end
      end

      sig do
        params(
          run_dir: T.nilable(String),
          logger: T.untyped,
          seed_candidate: T::Hash[String, String],
          valset_evaluator: T.proc.params(arg0: T::Hash[String, String]).returns([T::Array[T.untyped], T::Array[Float]]),
          track_best_outputs: T::Boolean
        ).returns(State)
      end
      def self.initialize_gepa_state(run_dir:, logger:, seed_candidate:, valset_evaluator:, track_best_outputs: false)
        if run_dir && File.exist?(File.join(run_dir, 'gepa_state.bin')) && File.exist?(File.join(run_dir, 'prog_candidates'))
          logger.log('Loading gepa state from run dir')
          return load(run_dir)
        end

        valset_out = valset_evaluator.call(seed_candidate)
        if run_dir
          write_eval_output_to_directory(valset_out, File.join(run_dir, 'generated_best_outputs_valset'))
        end

        state = new(seed_candidate, valset_out, track_best_outputs: track_best_outputs)
        state.num_full_ds_evals = 1
        state.total_num_evals = valset_out.last.length
        state
      end

      private

      sig do
        params(run_dir: T.nilable(String), task_idx: Integer, program_idx: Integer, output: T.untyped).void
      end
      def write_best_output(run_dir, task_idx, program_idx, output)
        return if run_dir.nil?

        dir = File.join(run_dir, 'generated_best_outputs_valset', "task_#{task_idx}")
        FileUtils.mkdir_p(dir)
        payload = ensure_jsonable(output)
        File.write(File.join(dir, "iter_#{@i + 1}_prog_#{program_idx}.json"), JSON.pretty_generate(payload))
      end

      sig { params(value: T.untyped).returns(T.untyped) }
      def ensure_jsonable(value)
        JSON.parse(JSON.generate(value))
      rescue StandardError
        GEPA::Utils::Pareto.json_default(value)
      end
    end
  end
end
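A minimal sketch of building state outside the engine, reusing the hypothetical evaluator and toy validation set from the Engine example above:

validation_examples = ['example 1', 'example 2']
state = GEPA::Core::State.initialize_gepa_state(
  run_dir: nil,
  logger: GEPA::Logging::Logger.new,
  seed_candidate: { 'predictor' => 'Answer the question concisely.' },
  valset_evaluator: ->(candidate) { evaluator.call(validation_examples, candidate) }, # evaluator as sketched earlier
  track_best_outputs: true
)
state.total_num_evals      # one metric call per validation example so far
state.pareto_front_valset  # per-task best scores, seeded by the first evaluation
state.save('runs/demo')    # persists runs/demo/gepa_state.bin via Marshal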
data/lib/gepa/logging/experiment_tracker.rb
@@ -0,0 +1,54 @@
# frozen_string_literal: true

module GEPA
  module Logging
    # Lightweight experiment tracker that records metrics locally and can fan out to user hooks.
    class ExperimentTracker
      attr_reader :events

      def initialize(subscribers: [])
        @subscribers = Array(subscribers)
        @events = []
      end

      def with_subscriber(proc = nil, &block)
        @subscribers << (proc || block)
        self
      end

      def initialize_backends; end

      def start_run; end

      def log_metrics(metrics, step: nil)
        entry = { metrics: symbolize_keys(metrics), step: step }
        @events << entry

        @subscribers.each do |subscriber|
          subscriber.call(entry)
        rescue StandardError => e
          DSPy.log('gepa.experiment_tracker.error', error: e.message)
        end
      end

      def end_run; end

      def active?
        !@events.empty?
      end

      def each_event(&block)
        @events.each(&block)
      end

      private

      def symbolize_keys(hash)
        hash.each_with_object({}) do |(k, v), memo|
          memo[k.to_sym] = v
        end
      end
    end
  end
end
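A short sketch of the tracker's subscriber hook, based only on the methods shown above; the metric name and values are made up:

tracker = GEPA::Logging::ExperimentTracker.new
tracker.with_subscriber { |entry| puts "step=#{entry[:step]} #{entry[:metrics].inspect}" }

tracker.log_metrics({ 'new_program_full_score' => 0.82 }, step: 3)
tracker.events.last # => { metrics: { new_program_full_score: 0.82 }, step: 3 }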
data/lib/gepa/logging/logger.rb
@@ -0,0 +1,57 @@
# frozen_string_literal: true

require 'forwardable'

module GEPA
  module Logging
    # Minimal logger interface used across GEPA components.
    class Logger
      extend Forwardable

      def initialize(io: $stdout)
        @io = io
      end

      def log(message)
        write(message)
      end

      private

      attr_reader :io

      def write(message)
        io.puts(message)
        io.flush if io.respond_to?(:flush)
      end
    end

    # Logger that fans out messages to multiple IO streams.
    class CompositeLogger < Logger
      def initialize(*ios)
        @ios = ios.flatten
      end

      def log(message)
        @ios.each do |io|
          io.puts(message)
          io.flush if io.respond_to?(:flush)
        end
      end
    end

    # Logger that captures messages into memory (handy for tests).
    class BufferingLogger < Logger
      attr_reader :messages

      def initialize
        @messages = []
      end

      def log(message)
        @messages << message
      end
    end
  end
end
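A quick sketch of the logger variants, using only the classes above; the file name and messages are arbitrary:

# Fan out to stdout and a log file.
log_file = File.open('gepa.log', 'a')
logger = GEPA::Logging::CompositeLogger.new($stdout, log_file)
logger.log('Iteration 1: starting reflective mutation')

# Capture messages in memory, e.g. in tests.
test_logger = GEPA::Logging::BufferingLogger.new
test_logger.log('proposal rejected')
test_logger.messages # => ['proposal rejected']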