dspy 0.28.2 → 0.29.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -3
- data/lib/dspy/code_act.rb +14 -1
- data/lib/dspy/datasets/ade.rb +90 -0
- data/lib/dspy/datasets.rb +8 -0
- data/lib/dspy/lm.rb +4 -8
- data/lib/dspy/mixins/struct_builder.rb +17 -25
- data/lib/dspy/module.rb +12 -1
- data/lib/dspy/observability/async_span_processor.rb +67 -93
- data/lib/dspy/observability.rb +43 -1
- data/lib/dspy/predict.rb +10 -0
- data/lib/dspy/propose/dataset_summary_generator.rb +36 -3
- data/lib/dspy/propose/grounded_proposer.rb +118 -11
- data/lib/dspy/re_act.rb +13 -0
- data/lib/dspy/reflection_lm.rb +36 -0
- data/lib/dspy/teleprompt/gepa.rb +448 -2803
- data/lib/dspy/teleprompt/mipro_v2.rb +564 -65
- data/lib/dspy/teleprompt/utils.rb +8 -3
- data/lib/dspy/version.rb +2 -2
- data/lib/dspy.rb +3 -2
- data/lib/gepa/api.rb +61 -0
- data/lib/gepa/core/engine.rb +226 -0
- data/lib/gepa/core/evaluation_batch.rb +26 -0
- data/lib/gepa/core/result.rb +92 -0
- data/lib/gepa/core/state.rb +231 -0
- data/lib/gepa/logging/experiment_tracker.rb +54 -0
- data/lib/gepa/logging/logger.rb +57 -0
- data/lib/gepa/logging.rb +9 -0
- data/lib/gepa/proposer/base.rb +27 -0
- data/lib/gepa/proposer/merge_proposer.rb +424 -0
- data/lib/gepa/proposer/reflective_mutation/base.rb +48 -0
- data/lib/gepa/proposer/reflective_mutation/reflective_mutation.rb +188 -0
- data/lib/gepa/strategies/batch_sampler.rb +91 -0
- data/lib/gepa/strategies/candidate_selector.rb +97 -0
- data/lib/gepa/strategies/component_selector.rb +57 -0
- data/lib/gepa/strategies/instruction_proposal.rb +120 -0
- data/lib/gepa/telemetry.rb +122 -0
- data/lib/gepa/utils/pareto.rb +119 -0
- data/lib/gepa.rb +21 -0
- metadata +42 -4
- data/lib/dspy/teleprompt/simple_optimizer.rb +0 -503
@@ -306,8 +306,13 @@ module DSPy
|
|
306
306
|
demo_candidates = Hash.new { |h, k| h[k] = [] }
|
307
307
|
rng = seed ? Random.new(seed) : Random.new
|
308
308
|
|
309
|
-
#
|
310
|
-
num_predictors =
|
309
|
+
# Determine number of predictors exposed by the student module
|
310
|
+
num_predictors = if student.respond_to?(:predictors)
|
311
|
+
predictors = Array(student.predictors)
|
312
|
+
predictors.empty? ? 1 : predictors.size
|
313
|
+
else
|
314
|
+
1
|
315
|
+
end
|
311
316
|
|
312
317
|
# Adjust for 3 special seeds (-3, -2, -1)
|
313
318
|
adjusted_num_sets = num_candidate_sets - 3
|
@@ -706,4 +711,4 @@ module DSPy
|
|
706
711
|
end
|
707
712
|
end
|
708
713
|
end
|
709
|
-
end
|
714
|
+
end
|
data/lib/dspy/version.rb
CHANGED
data/lib/dspy.rb
CHANGED
@@ -12,6 +12,7 @@ require_relative 'dspy/observability/observation_type'
|
|
12
12
|
require_relative 'dspy/context'
|
13
13
|
require_relative 'dspy/events'
|
14
14
|
require_relative 'dspy/events/types'
|
15
|
+
require_relative 'dspy/reflection_lm'
|
15
16
|
|
16
17
|
module DSPy
|
17
18
|
extend Dry::Configurable
|
@@ -198,6 +199,7 @@ require_relative 'dspy/signature'
|
|
198
199
|
require_relative 'dspy/few_shot_example'
|
199
200
|
require_relative 'dspy/prompt'
|
200
201
|
require_relative 'dspy/example'
|
202
|
+
require_relative 'dspy/datasets'
|
201
203
|
require_relative 'dspy/lm'
|
202
204
|
require_relative 'dspy/image'
|
203
205
|
require_relative 'dspy/prediction'
|
@@ -211,10 +213,9 @@ require_relative 'dspy/evaluate'
|
|
211
213
|
require_relative 'dspy/teleprompt/teleprompter'
|
212
214
|
require_relative 'dspy/teleprompt/utils'
|
213
215
|
require_relative 'dspy/teleprompt/data_handler'
|
216
|
+
require_relative 'dspy/teleprompt/gepa'
|
214
217
|
require_relative 'dspy/propose/grounded_proposer'
|
215
|
-
require_relative 'dspy/teleprompt/simple_optimizer'
|
216
218
|
require_relative 'dspy/teleprompt/mipro_v2'
|
217
|
-
require_relative 'dspy/teleprompt/gepa'
|
218
219
|
require_relative 'dspy/tools'
|
219
220
|
require_relative 'dspy/memory'
|
220
221
|
require_relative 'dspy/storage/program_storage'
|
data/lib/gepa/api.rb
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
require_relative 'core/engine'
|
6
|
+
require_relative 'core/result'
|
7
|
+
|
8
|
+
module GEPA
|
9
|
+
extend T::Sig
|
10
|
+
module_function
|
11
|
+
|
12
|
+
sig do
|
13
|
+
params(
|
14
|
+
seed_candidate: T::Hash[String, String],
|
15
|
+
trainset: T::Array[T.untyped],
|
16
|
+
valset: T::Array[T.untyped],
|
17
|
+
adapter: T.untyped,
|
18
|
+
reflective_proposer: T.untyped,
|
19
|
+
merge_proposer: T.nilable(T.untyped),
|
20
|
+
logger: T.untyped,
|
21
|
+
experiment_tracker: T.untyped,
|
22
|
+
max_metric_calls: Integer,
|
23
|
+
telemetry: T.nilable(T.untyped)
|
24
|
+
).returns(GEPA::Core::Result)
|
25
|
+
end
|
26
|
+
# Runs a GEPA optimization and returns a snapshot of the final state.
#
# Wraps `adapter.evaluate` in an evaluator proc, builds a GEPA::Core::Engine
# with fixed settings (no run directory, seed 0, infinite perfect score, no
# best-output tracking, no progress bar, exceptions re-raised), runs it, and
# converts the resulting state with GEPA::Core::Result.from_state.
#
# `trainset` is accepted but not read by this method. `telemetry` falls back
# to GEPA::Telemetry when nil.
def optimize(
  seed_candidate:,
  trainset:,
  valset:,
  adapter:,
  reflective_proposer:,
  merge_proposer: nil,
  logger:,
  experiment_tracker:,
  max_metric_calls:,
  telemetry: nil
)
  score_with_adapter = proc { |dataset, candidate| adapter.evaluate(dataset, candidate) }

  engine_options = {
    run_dir: nil,
    evaluator: score_with_adapter,
    valset: valset,
    seed_candidate: seed_candidate,
    max_metric_calls: max_metric_calls,
    perfect_score: Float::INFINITY,
    seed: 0,
    reflective_proposer: reflective_proposer,
    merge_proposer: merge_proposer,
    logger: logger,
    experiment_tracker: experiment_tracker,
    telemetry: telemetry || GEPA::Telemetry,
    track_best_outputs: false,
    display_progress_bar: false,
    raise_on_exception: true
  }

  final_state = GEPA::Core::Engine.new(**engine_options).run
  GEPA::Core::Result.from_state(final_state)
end
|
61
|
+
end
|
@@ -0,0 +1,226 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
require_relative 'state'
|
6
|
+
require_relative 'result'
|
7
|
+
require_relative '../telemetry'
|
8
|
+
|
9
|
+
module GEPA
|
10
|
+
module Core
|
11
|
+
class Engine
|
12
|
+
extend T::Sig
|
13
|
+
|
14
|
+
sig do
|
15
|
+
params(
|
16
|
+
evaluator: T.proc.params(dataset: T::Array[T.untyped], candidate: T::Hash[String, String])
|
17
|
+
.returns([T::Array[T.untyped], T::Array[Float]]),
|
18
|
+
valset: T::Array[T.untyped],
|
19
|
+
seed_candidate: T::Hash[String, String],
|
20
|
+
max_metric_calls: Integer,
|
21
|
+
perfect_score: Float,
|
22
|
+
seed: Integer,
|
23
|
+
reflective_proposer: T.untyped,
|
24
|
+
logger: T.untyped,
|
25
|
+
experiment_tracker: T.untyped,
|
26
|
+
merge_proposer: T.nilable(T.untyped),
|
27
|
+
run_dir: T.nilable(String),
|
28
|
+
track_best_outputs: T::Boolean,
|
29
|
+
display_progress_bar: T::Boolean,
|
30
|
+
telemetry: T.nilable(T.untyped),
|
31
|
+
raise_on_exception: T::Boolean
|
32
|
+
).void
|
33
|
+
end
|
34
|
+
# Stores the collaborators and settings for one optimization run.
#
# `seed` is accepted but never stored or read here. `telemetry` falls back
# to GEPA::Telemetry when nil; every other argument is kept as given in an
# instance variable of the same name.
def initialize(
  evaluator:,
  valset:,
  seed_candidate:,
  max_metric_calls:,
  perfect_score:,
  seed:, # rubocop:disable Lint/UnusedMethodArgument -- kept for parity and future use
  reflective_proposer:,
  logger:,
  experiment_tracker:,
  merge_proposer: nil,
  run_dir: nil,
  track_best_outputs: false,
  display_progress_bar: false,
  telemetry: nil,
  raise_on_exception: true
)
  # Where state is persisted and how candidates are scored.
  @run_dir = run_dir
  @evaluator = evaluator
  @valset = valset
  @seed_candidate = seed_candidate

  # Budget and target score.
  @max_metric_calls = max_metric_calls
  @perfect_score = perfect_score

  # Candidate proposers.
  @reflective_proposer = reflective_proposer
  @merge_proposer = merge_proposer

  # Reporting.
  @logger = logger
  @experiment_tracker = experiment_tracker
  @track_best_outputs = track_best_outputs
  @display_progress_bar = display_progress_bar
  @telemetry = telemetry.nil? || telemetry == false ? GEPA::Telemetry : telemetry

  # Whether iteration errors propagate to the caller.
  @raise_on_exception = raise_on_exception
end
|
66
|
+
|
67
|
+
sig { returns(GEPA::Core::State) }
|
68
|
+
# Executes the optimization loop inside a 'gepa.engine.run' span.
#
# Obtains the state from GEPA::Core::State.initialize_gepa_state, logs the
# first full-valset score at step 0, clears the merge proposer's
# "found new program" flag when a merge proposer is configured, then keeps
# calling iteration_step until either the metric-call budget is used up or
# iteration_step returns a falsy value. The state is saved to @run_dir and
# is the last expression of the span block.
def run
  with_span('gepa.engine.run', max_metric_calls: @max_metric_calls) do
    evaluate_on_valset = ->(candidate) { full_evaluator(candidate) }

    state = GEPA::Core::State.initialize_gepa_state(
      run_dir: @run_dir,
      logger: @logger,
      seed_candidate: @seed_candidate,
      valset_evaluator: evaluate_on_valset,
      track_best_outputs: @track_best_outputs
    )

    baseline_score = state.program_full_scores_val_set.first
    @experiment_tracker.log_metrics({ base_program_full_valset_score: baseline_score }, step: 0)

    @merge_proposer.last_iter_found_new_program = false if @merge_proposer

    until state.total_num_evals >= @max_metric_calls
      break unless iteration_step(state)
    end

    state.save(@run_dir)
    state
  end
end
|
92
|
+
|
93
|
+
private
|
94
|
+
|
95
|
+
sig { params(state: GEPA::Core::State).returns(T::Boolean) }
|
96
|
+
# Performs one optimization iteration and tells the caller whether to
# continue (true) or stop (false).
#
# Bumps the iteration counter and appends a trace entry, then inside a
# 'gepa.engine.iteration' span:
#   * merge outcome :accepted  -> true
#   * merge outcome :attempted -> false
#   * any other merge outcome  -> run the reflective iteration; true only
#     when it reports :accepted
#
# Any StandardError is logged with its backtrace; it is re-raised when
# @raise_on_exception is set, otherwise the method returns true.
def iteration_step(state)
  state.i += 1
  state.full_program_trace << { iteration: state.i }

  accepted = false

  with_span('gepa.engine.iteration', iteration: state.i) do
    merge_outcome = process_merge_iteration(state)
    return true if merge_outcome == :accepted
    return false if merge_outcome == :attempted

    reflective_outcome = process_reflective_iteration(state)
    return false if reflective_outcome == :no_candidate

    accepted = true if reflective_outcome == :accepted
  end

  accepted
rescue StandardError => e
  @logger.log("Iteration #{state.i}: Exception during optimization: #{e}")
  @logger.log(e.backtrace&.join("\n"))
  raise e if @raise_on_exception

  true
end
|
124
|
+
|
125
|
+
sig { params(state: GEPA::Core::State).returns(Symbol) }
|
126
|
+
# Tries a merge proposal when one is due and reports the outcome.
#
# Returns:
#   :skipped   - no merge proposer, merging disabled, no merge due, or the
#                proposer did not return a proposal tagged 'merge'
#   :handled   - a merge proposal arrived without parent subsample scores
#   :accepted  - the summed subsample score after the merge is at least the
#                largest parent score; the candidate was fully evaluated
#   :attempted - the merge did not reach the largest parent score
#
# The proposer's last_iter_found_new_program flag is cleared on every path
# that reaches the proposer.
def process_merge_iteration(state)
  merger = @merge_proposer
  return :skipped unless merger && merger.use_merge

  if merger.merges_due.positive? && merger.last_iter_found_new_program
    merge_candidate = merger.propose(state)
    merger.last_iter_found_new_program = false

    if merge_candidate&.tag == 'merge'
      parent_scores = Array(merge_candidate.subsample_scores_before).map(&:to_f)
      merged_total = Array(merge_candidate.subsample_scores_after).map(&:to_f).sum

      if parent_scores.empty?
        @logger.log("Iteration #{state.i}: Missing parent subscores for merge proposal, skipping")
        return :handled
      end

      unless merged_total >= parent_scores.max
        @logger.log(
          "Iteration #{state.i}: Merge subsample score #{merged_total.round(4)} "\
          "did not beat parents #{parent_scores.map { |v| v.round(4) }}, skipping"
        )
        return :attempted
      end

      with_span('gepa.engine.full_evaluation', iteration: state.i) do
        run_full_evaluation(state, merge_candidate.candidate, merge_candidate.parent_program_ids)
      end
      merger.merges_due -= 1
      merger.total_merges_tested += 1
      return :accepted
    end
  end

  merger.last_iter_found_new_program = false
  :skipped
end
|
162
|
+
|
163
|
+
# Runs one reflective-mutation attempt and reports what happened.
#
# Returns:
#   :no_candidate - the reflective proposer returned nothing
#   :skipped      - the proposal has no after-scores, or their sum does not
#                   exceed the sum of the before-scores
#   :accepted     - the proposal was fully evaluated and added to the state;
#                   when merging is enabled the merge proposer is told a new
#                   program was found and asked to schedule a merge
#
# The signature declares `returns(Symbol)` rather than `void` because
# iteration_step compares the result against these symbols. sorbet-runtime
# discards the return value of `void` methods and hands back a placeholder,
# which would make every one of those comparisons false.
sig { params(state: GEPA::Core::State).returns(Symbol) }
def process_reflective_iteration(state)
  proposal = @reflective_proposer.propose(state)
  unless proposal
    @logger.log("Iteration #{state.i}: Reflective mutation did not propose a new candidate")
    return :no_candidate
  end

  before = Array(proposal.subsample_scores_before).map(&:to_f)
  after = Array(proposal.subsample_scores_after).map(&:to_f)
  if after.empty? || after.sum <= before.sum
    @logger.log("Iteration #{state.i}: New subsample score is not better, skipping")
    return :skipped
  end

  with_span('gepa.engine.full_evaluation', iteration: state.i) do
    run_full_evaluation(state, proposal.candidate, proposal.parent_program_ids)
  end

  if @merge_proposer&.use_merge
    @merge_proposer.last_iter_found_new_program = true
    @merge_proposer.schedule_if_needed
  end

  :accepted
end
|
189
|
+
|
190
|
+
sig do
|
191
|
+
params(state: GEPA::Core::State, new_program: T::Hash[String, String], parents: T::Array[Integer]).void
|
192
|
+
end
|
193
|
+
# Scores `new_program` on the full validation set and records it.
#
# Increments the state's full-evaluation counter, adds one metric call per
# returned score to the state's total, registers the program (with its
# parents, mean score, outputs and per-example scores) in the state, and
# logs the mean score to the experiment tracker at the current iteration.
def run_full_evaluation(state, new_program, parents)
  valset_outputs, valset_scores = full_evaluator(new_program)
  mean_score = valset_scores.sum / valset_scores.length.to_f

  state.num_full_ds_evals += 1
  state.total_num_evals += valset_scores.length
  calls_so_far = state.total_num_evals

  state.update_state_with_new_program(
    parents, new_program, mean_score, valset_outputs, valset_scores, @run_dir, calls_so_far
  )

  @experiment_tracker.log_metrics({ new_program_full_score: mean_score }, step: state.i)
end
|
212
|
+
|
213
|
+
sig { params(candidate: T::Hash[String, String]).returns([T::Array[T.untyped], T::Array[Float]]) }
|
214
|
+
# Evaluates `candidate` on the whole validation set by invoking the
# evaluator callable with (@valset, candidate) and returning its result.
def full_evaluator(candidate)
  @evaluator.(@valset, candidate)
end
|
217
|
+
|
218
|
+
sig do
|
219
|
+
params(operation: String, attrs: T::Hash[Symbol, T.untyped], block: T.proc.returns(T.untyped)).returns(T.untyped)
|
220
|
+
end
|
221
|
+
# Delegates to the configured telemetry object's `with_span`, forwarding the
# operation name, the attribute hash and the caller's block unchanged.
def with_span(operation, attrs = {}, &block)
  tracer = @telemetry
  tracer.with_span(operation, attrs, &block)
end
|
224
|
+
end
|
225
|
+
end
|
226
|
+
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
module GEPA
|
6
|
+
module Core
|
7
|
+
# Container for evaluating a candidate on a batch.
|
8
|
+
class EvaluationBatch < T::Struct
  extend T::Sig

  # One entry per evaluated example.
  const :outputs, T::Array[T.untyped]
  const :scores, T::Array[Float]
  # Optional; when present it must have one entry per output.
  const :trajectories, T.nilable(T::Array[T.untyped])

  # Builds the struct, then rejects batches whose parallel arrays differ in
  # length. Raises ArgumentError on either mismatch.
  sig { override.params(args: T.untyped, kwargs: T.untyped).void }
  def initialize(*args, **kwargs)
    super
    raise ArgumentError, 'outputs and scores length mismatch' if outputs.length != scores.length
    return unless trajectories
    return if trajectories.length == outputs.length

    raise ArgumentError, 'trajectories length mismatch'
  end
end
|
25
|
+
end
|
26
|
+
end
|
@@ -0,0 +1,92 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'json'
|
4
|
+
require 'set'
|
5
|
+
require 'sorbet-runtime'
|
6
|
+
|
7
|
+
module GEPA
|
8
|
+
module Core
|
9
|
+
# Snapshot of GEPA optimization output with helpers for common queries.
|
10
|
+
class Result < T::Struct
|
11
|
+
extend T::Sig
|
12
|
+
|
13
|
+
const :candidates, T::Array[T::Hash[String, String]]
|
14
|
+
const :parents, T::Array[T::Array[T.nilable(Integer)]]
|
15
|
+
const :val_aggregate_scores, T::Array[Float]
|
16
|
+
const :val_subscores, T::Array[T::Array[Float]]
|
17
|
+
const :per_val_instance_best_candidates, T::Array[T::Array[Integer]]
|
18
|
+
const :discovery_eval_counts, T::Array[Integer]
|
19
|
+
const :best_outputs_valset, T.nilable(T::Array[T::Array[T::Array[T.untyped]]]), default: nil
|
20
|
+
const :total_metric_calls, T.nilable(Integer), default: nil
|
21
|
+
const :num_full_val_evals, T.nilable(Integer), default: nil
|
22
|
+
const :run_dir, T.nilable(String), default: nil
|
23
|
+
const :seed, T.nilable(Integer), default: nil
|
24
|
+
|
25
|
+
sig { returns(Integer) }
|
26
|
+
# Number of candidate programs recorded in this result.
def num_candidates
  candidates.size
end
|
29
|
+
|
30
|
+
sig { returns(Integer) }
|
31
|
+
# Number of validation instances, taken from the length of the
# per-instance best-candidate list.
def num_val_instances
  per_val_instance_best_candidates.size
end
|
34
|
+
|
35
|
+
sig { returns(Integer) }
|
36
|
+
# Index of the candidate with the highest aggregate validation score.
# Falls back to 0 when there are no scores.
def best_idx
  top_entry = val_aggregate_scores.each_with_index.max_by { |score, _position| score }
  top_entry ? top_entry.last : 0
end
|
39
|
+
|
40
|
+
sig { returns(T::Hash[String, String]) }
|
41
|
+
# The candidate at `best_idx`. Uses `fetch`, so an out-of-range index
# raises IndexError instead of returning nil.
def best_candidate
  winner_position = best_idx
  candidates.fetch(winner_position)
end
|
44
|
+
|
45
|
+
sig { returns(T::Hash[Symbol, T.untyped]) }
|
46
|
+
# Plain-Hash view of the result, keyed by symbol.
#
# Collections are copied one level deep (each inner element is `dup`ed) so
# callers can modify the returned containers without touching this struct.
# `best_idx` is included as a computed entry.
def to_h
  dup_each = ->(rows) { rows.map(&:dup) }

  {
    candidates: dup_each.(candidates),
    parents: dup_each.(parents),
    val_aggregate_scores: val_aggregate_scores.dup,
    val_subscores: dup_each.(val_subscores),
    best_outputs_valset: best_outputs_valset&.map { |per_task| dup_each.(per_task) },
    per_val_instance_best_candidates: dup_each.(per_val_instance_best_candidates),
    discovery_eval_counts: discovery_eval_counts.dup,
    total_metric_calls: total_metric_calls,
    num_full_val_evals: num_full_val_evals,
    run_dir: run_dir,
    seed: seed,
    best_idx: best_idx
  }
end
|
62
|
+
|
63
|
+
sig { returns(String) }
|
64
|
+
# Pretty-printed JSON of `to_h`. Any arguments (for example a generator
# state passed by the JSON library) are accepted and ignored.
def to_json(*_args)
  payload = to_h
  JSON.pretty_generate(payload)
end
|
67
|
+
|
68
|
+
sig do
|
69
|
+
params(
|
70
|
+
state: T.untyped,
|
71
|
+
run_dir: T.nilable(String),
|
72
|
+
seed: T.nilable(Integer)
|
73
|
+
).returns(Result)
|
74
|
+
end
|
75
|
+
# Builds a Result from an optimization state object.
#
# Lists are copied (inner elements `dup`ed), scores are coerced with `to_f`,
# Pareto-front sets become arrays and discovery counts are coerced with
# `to_i`. `best_outputs_valset`, `total_num_evals` and `num_full_ds_evals`
# are read only when the state responds to them; otherwise nil is stored.
# `run_dir` and `seed` are passed through as given.
def self.from_state(state, run_dir: nil, seed: nil)
  read_if_present = ->(reader) { state.respond_to?(reader) ? state.public_send(reader) : nil }

  new(
    candidates: state.program_candidates.map(&:dup),
    parents: state.parent_program_for_candidate.map(&:dup),
    val_aggregate_scores: state.program_full_scores_val_set.map(&:to_f),
    best_outputs_valset: read_if_present.(:best_outputs_valset)&.map(&:dup),
    val_subscores: state.prog_candidate_val_subscores.map { |row| row.map(&:to_f) },
    per_val_instance_best_candidates: state.program_at_pareto_front_valset.map(&:to_a),
    discovery_eval_counts: state.num_metric_calls_by_discovery.map(&:to_i),
    total_metric_calls: read_if_present.(:total_num_evals),
    num_full_val_evals: read_if_present.(:num_full_ds_evals),
    run_dir: run_dir,
    seed: seed
  )
end
|
90
|
+
end
|
91
|
+
end
|
92
|
+
end
|
@@ -0,0 +1,231 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'fileutils'
|
4
|
+
require 'json'
|
5
|
+
require 'set'
|
6
|
+
require 'sorbet-runtime'
|
7
|
+
|
8
|
+
require_relative '../utils/pareto'
|
9
|
+
require_relative '../telemetry'
|
10
|
+
|
11
|
+
module GEPA
|
12
|
+
module Core
|
13
|
+
class State
|
14
|
+
extend T::Sig
|
15
|
+
|
16
|
+
attr_accessor :i, :num_full_ds_evals, :total_num_evals
|
17
|
+
attr_reader :program_candidates,
|
18
|
+
:parent_program_for_candidate,
|
19
|
+
:program_full_scores_val_set,
|
20
|
+
:program_at_pareto_front_valset,
|
21
|
+
:prog_candidate_val_subscores,
|
22
|
+
:list_of_named_predictors,
|
23
|
+
:named_predictor_id_to_update_next_for_program_candidate,
|
24
|
+
:num_metric_calls_by_discovery,
|
25
|
+
:full_program_trace,
|
26
|
+
:per_program_tracked_scores,
|
27
|
+
:pareto_front_valset,
|
28
|
+
:best_outputs_valset
|
29
|
+
|
30
|
+
sig do
|
31
|
+
params(
|
32
|
+
seed_candidate: T::Hash[String, String],
|
33
|
+
base_valset_eval_output: [T::Array[T.untyped], T::Array[Float]],
|
34
|
+
track_best_outputs: T::Boolean
|
35
|
+
).void
|
36
|
+
end
|
37
|
+
# Creates the state for a run from the seed candidate and its validation
# evaluation (`[outputs, scores]`).
#
# The seed becomes program 0: its mean score seeds the aggregate score
# lists, its per-example scores seed the Pareto front, and program 0 is the
# sole front member for every validation example. Best outputs are tracked
# only when `track_best_outputs` is true (otherwise nil). The iteration
# counter starts at -1 and both evaluation counters at 0.
#
# Raises ArgumentError when `scores` is empty.
def initialize(seed_candidate, base_valset_eval_output, track_best_outputs: false)
  seed_outputs, seed_scores = base_valset_eval_output
  raise ArgumentError, 'validation scores must not be empty' if seed_scores.empty?

  example_count = seed_scores.length
  mean_seed_score = seed_scores.sum / example_count.to_f

  # Per-program bookkeeping (index 0 is the seed).
  @program_candidates = [seed_candidate.dup]
  @program_full_scores_val_set = [mean_seed_score]
  @per_program_tracked_scores = [mean_seed_score]

  # Per-example Pareto front: best score so far and the programs achieving it.
  @pareto_front_valset = seed_scores.dup
  @parent_program_for_candidate = [[nil]]
  @program_at_pareto_front_valset = Array.new(example_count) { Set.new([0]) }

  @list_of_named_predictors = seed_candidate.keys
  @named_predictor_id_to_update_next_for_program_candidate = [0]

  @prog_candidate_val_subscores = [seed_scores.dup]
  @num_metric_calls_by_discovery = [0]

  @best_outputs_valset = track_best_outputs ? seed_outputs.map { |output| [[0, output]] } : nil

  @full_program_trace = []
  @i = -1
  @num_full_ds_evals = 0
  @total_num_evals = 0
end
|
66
|
+
|
67
|
+
sig { returns(T::Boolean) }
|
68
|
+
# Verifies that the parallel per-program lists all have one entry per
# candidate, that the two Pareto-front lists have equal length, and that
# every program index stored in a front refers to an existing candidate.
#
# Returns true when everything lines up; raises a RuntimeError naming the
# first mismatch otherwise.
def consistent?
  candidate_count = @program_candidates.length

  per_program_lists = {
    'program_full_scores_val_set mismatch' => @program_full_scores_val_set,
    'per_program_tracked_scores mismatch' => @per_program_tracked_scores,
    'parent_program_for_candidate mismatch' => @parent_program_for_candidate,
    'named_predictor_id_to_update mismatch' => @named_predictor_id_to_update_next_for_program_candidate,
    'prog_candidate_val_subscores mismatch' => @prog_candidate_val_subscores,
    'num_metric_calls mismatch' => @num_metric_calls_by_discovery
  }
  per_program_lists.each do |complaint, list|
    raise complaint unless list.length == candidate_count
  end

  raise 'pareto fronts length mismatch' unless @pareto_front_valset.length == @program_at_pareto_front_valset.length

  @program_at_pareto_front_valset.each do |front|
    front.each { |program_idx| raise 'pareto index out of range' unless program_idx < candidate_count }
  end

  true
end
|
85
|
+
|
86
|
+
sig { params(run_dir: T.nilable(String)).void }
|
87
|
+
# Persists every instance variable to `<run_dir>/gepa_state.bin`.
#
# Does nothing when `run_dir` is nil. Otherwise creates the directory if
# needed and writes a Marshal dump of a Hash mapping each instance-variable
# name (with the `@` removed) to its value.
def save(run_dir)
  return if run_dir.nil?

  FileUtils.mkdir_p(run_dir)
  target = File.join(run_dir, 'gepa_state.bin')
  File.open(target, 'wb') do |io|
    snapshot = instance_variables.to_h do |ivar|
      [ivar.to_s.delete('@'), instance_variable_get(ivar)]
    end
    Marshal.dump(snapshot, io)
  end
end
|
98
|
+
|
99
|
+
sig { params(run_dir: String).returns(State) }
|
100
|
+
# Restores a State from `<run_dir>/gepa_state.bin`.
#
# The file is expected to hold a Marshal dump of a Hash mapping
# instance-variable names (without `@`) to values. A fresh instance is
# allocated without running `initialize`, each entry is assigned as an
# instance variable, and `consistent?` is run so a structurally broken
# snapshot raises instead of being returned.
#
# Raises TypeError when the file does not contain a Hash, and whatever
# `consistent?` raises for inconsistent contents.
#
# SECURITY: Marshal.load can instantiate arbitrary objects. Only load run
# directories that were written by a trusted process.
def self.load(run_dir)
  File.open(File.join(run_dir, 'gepa_state.bin'), 'rb') do |file|
    data = Marshal.load(file)
    # Fail with a clear message rather than an unrelated NoMethodError later.
    raise TypeError, "gepa_state.bin must contain a Hash, got #{data.class}" unless data.is_a?(Hash)

    state = allocate
    data.each { |key, value| state.instance_variable_set("@#{key}", value) }
    state.consistent?
    state
  end
end
|
109
|
+
|
110
|
+
sig do
|
111
|
+
params(
|
112
|
+
parent_program_idx: T::Array[Integer],
|
113
|
+
new_program: T::Hash[String, String],
|
114
|
+
valset_score: Float,
|
115
|
+
valset_outputs: T::Array[T.untyped],
|
116
|
+
valset_subscores: T::Array[Float],
|
117
|
+
run_dir: T.nilable(String),
|
118
|
+
num_metric_calls: Integer
|
119
|
+
).returns([Integer, Integer])
|
120
|
+
end
|
121
|
+
# Registers a newly evaluated program and folds its per-example scores into
# the Pareto front.
#
# For each validation example: a strictly better score replaces the front
# entry (and the best output is written to `run_dir` when one is given); an
# equal score adds the program to the existing front entry. Best outputs are
# kept in memory only when best-output tracking is enabled.
#
# The new program's "next predictor to update" is the largest value among
# its parents, or 0 when none is available.
#
# Returns `[new_program_idx, idx_of_best_tracked_score]`.
#
# Raises RuntimeError when `valset_subscores` does not have one entry per
# validation example. The check runs before any mutation so a bad call
# cannot leave the state half-updated.
def update_state_with_new_program(
  parent_program_idx,
  new_program,
  valset_score,
  valset_outputs,
  valset_subscores,
  run_dir,
  num_metric_calls
)
  raise 'valset subscores length mismatch' unless valset_subscores.length == @program_at_pareto_front_valset.length

  new_program_idx = @program_candidates.length
  @program_candidates << new_program.dup
  @num_metric_calls_by_discovery << num_metric_calls

  max_predictor_id = parent_program_idx.map { |idx| @named_predictor_id_to_update_next_for_program_candidate[idx] }.compact.max
  @named_predictor_id_to_update_next_for_program_candidate << (max_predictor_id || 0)
  @parent_program_for_candidate << parent_program_idx.dup

  @prog_candidate_val_subscores << valset_subscores.dup
  @program_full_scores_val_set << valset_score.to_f

  valset_subscores.each_with_index do |new_score, task_idx|
    old_score = @pareto_front_valset[task_idx]
    if new_score > old_score
      # Strict improvement: this program becomes the only front member.
      @pareto_front_valset[task_idx] = new_score
      @program_at_pareto_front_valset[task_idx] = Set.new([new_program_idx])
      @best_outputs_valset[task_idx] = [[new_program_idx, valset_outputs[task_idx]]] if @best_outputs_valset
      write_best_output(run_dir, task_idx, new_program_idx, valset_outputs[task_idx])
    elsif new_score == old_score
      # Tie: the program joins the existing front members.
      @program_at_pareto_front_valset[task_idx].add(new_program_idx)
      @best_outputs_valset[task_idx] << [new_program_idx, valset_outputs[task_idx]] if @best_outputs_valset
    end
  end

  @per_program_tracked_scores = @program_full_scores_val_set.dup
  linear_idx = GEPA::Utils::Pareto.idxmax(@per_program_tracked_scores)

  [new_program_idx, linear_idx]
end
|
165
|
+
|
166
|
+
sig do
|
167
|
+
params(
|
168
|
+
eval_output: [T::Array[T.untyped], T::Array[Float]],
|
169
|
+
output_dir: String
|
170
|
+
).void
|
171
|
+
end
|
172
|
+
# Writes one JSON file per validation example under
# `<output_dir>/task_<n>/iter_0_prog_0.json`.
#
# Only the scores (second element of `eval_output`) are written; the
# outputs element is not read.
def self.write_eval_output_to_directory(eval_output, output_dir)
  _outputs, per_task_scores = eval_output

  per_task_scores.each_with_index do |task_score, task_idx|
    task_dir = File.join(output_dir, "task_#{task_idx}")
    FileUtils.mkdir_p(task_dir)
    File.write(File.join(task_dir, 'iter_0_prog_0.json'), JSON.pretty_generate(task_score))
  end
end
|
181
|
+
|
182
|
+
sig do
|
183
|
+
params(
|
184
|
+
run_dir: T.nilable(String),
|
185
|
+
logger: T.untyped,
|
186
|
+
seed_candidate: T::Hash[String, String],
|
187
|
+
valset_evaluator: T.proc.params(arg0: T::Hash[String, String]).returns([T::Array[T.untyped], T::Array[Float]]),
|
188
|
+
track_best_outputs: T::Boolean
|
189
|
+
).returns(State)
|
190
|
+
end
|
191
|
+
# Returns the state to start (or resume) a run from.
#
# When `run_dir` contains both `gepa_state.bin` and `prog_candidates`, the
# saved state is loaded and returned. Otherwise the seed candidate is
# evaluated with `valset_evaluator`, the scores are written under
# `<run_dir>/generated_best_outputs_valset` when a run directory is given,
# and a fresh state is built with one full evaluation counted and one metric
# call per validation score.
def self.initialize_gepa_state(run_dir:, logger:, seed_candidate:, valset_evaluator:, track_best_outputs: false)
  if run_dir
    snapshot_path = File.join(run_dir, 'gepa_state.bin')
    candidates_path = File.join(run_dir, 'prog_candidates')
    if File.exist?(snapshot_path) && File.exist?(candidates_path)
      logger.log('Loading gepa state from run dir')
      return load(run_dir)
    end
  end

  seed_eval = valset_evaluator.call(seed_candidate)
  write_eval_output_to_directory(seed_eval, File.join(run_dir, 'generated_best_outputs_valset')) if run_dir

  fresh = new(seed_candidate, seed_eval, track_best_outputs: track_best_outputs)
  fresh.num_full_ds_evals = 1
  fresh.total_num_evals = seed_eval.last.length
  fresh
end
|
207
|
+
|
208
|
+
private
|
209
|
+
|
210
|
+
sig do
|
211
|
+
params(run_dir: T.nilable(String), task_idx: Integer, program_idx: Integer, output: T.untyped).void
|
212
|
+
end
|
213
|
+
# Writes `output` as pretty-printed JSON to
# `<run_dir>/generated_best_outputs_valset/task_<task_idx>/iter_<i+1>_prog_<program_idx>.json`,
# where `i` is the current iteration counter. Does nothing when `run_dir`
# is nil. The output is passed through `ensure_jsonable` first.
def write_best_output(run_dir, task_idx, program_idx, output)
  return if run_dir.nil?

  task_dir = File.join(run_dir, 'generated_best_outputs_valset', "task_#{task_idx}")
  FileUtils.mkdir_p(task_dir)

  file_name = "iter_#{@i + 1}_prog_#{program_idx}.json"
  serializable = ensure_jsonable(output)
  File.write(File.join(task_dir, file_name), JSON.pretty_generate(serializable))
end
|
221
|
+
|
222
|
+
sig { params(value: T.untyped).returns(T.untyped) }
|
223
|
+
# Returns a JSON-safe version of `value` by round-tripping it through
# JSON.generate / JSON.parse. When that raises any StandardError, the value
# is handed to GEPA::Utils::Pareto.json_default instead.
def ensure_jsonable(value)
  begin
    encoded = JSON.generate(value)
    JSON.parse(encoded)
  rescue StandardError
    GEPA::Utils::Pareto.json_default(value)
  end
end
|
228
|
+
end
|
229
|
+
end
|
230
|
+
end
|
231
|
+
|