dspy 0.28.2 → 0.29.0

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -3
  3. data/lib/dspy/code_act.rb +14 -1
  4. data/lib/dspy/datasets/ade.rb +90 -0
  5. data/lib/dspy/datasets.rb +8 -0
  6. data/lib/dspy/lm.rb +4 -8
  7. data/lib/dspy/mixins/struct_builder.rb +17 -25
  8. data/lib/dspy/module.rb +12 -1
  9. data/lib/dspy/observability/async_span_processor.rb +67 -93
  10. data/lib/dspy/observability.rb +43 -1
  11. data/lib/dspy/predict.rb +10 -0
  12. data/lib/dspy/propose/dataset_summary_generator.rb +36 -3
  13. data/lib/dspy/propose/grounded_proposer.rb +118 -11
  14. data/lib/dspy/re_act.rb +13 -0
  15. data/lib/dspy/reflection_lm.rb +36 -0
  16. data/lib/dspy/teleprompt/gepa.rb +448 -2803
  17. data/lib/dspy/teleprompt/mipro_v2.rb +564 -65
  18. data/lib/dspy/teleprompt/utils.rb +8 -3
  19. data/lib/dspy/version.rb +2 -2
  20. data/lib/dspy.rb +3 -2
  21. data/lib/gepa/api.rb +61 -0
  22. data/lib/gepa/core/engine.rb +226 -0
  23. data/lib/gepa/core/evaluation_batch.rb +26 -0
  24. data/lib/gepa/core/result.rb +92 -0
  25. data/lib/gepa/core/state.rb +231 -0
  26. data/lib/gepa/logging/experiment_tracker.rb +54 -0
  27. data/lib/gepa/logging/logger.rb +57 -0
  28. data/lib/gepa/logging.rb +9 -0
  29. data/lib/gepa/proposer/base.rb +27 -0
  30. data/lib/gepa/proposer/merge_proposer.rb +424 -0
  31. data/lib/gepa/proposer/reflective_mutation/base.rb +48 -0
  32. data/lib/gepa/proposer/reflective_mutation/reflective_mutation.rb +188 -0
  33. data/lib/gepa/strategies/batch_sampler.rb +91 -0
  34. data/lib/gepa/strategies/candidate_selector.rb +97 -0
  35. data/lib/gepa/strategies/component_selector.rb +57 -0
  36. data/lib/gepa/strategies/instruction_proposal.rb +120 -0
  37. data/lib/gepa/telemetry.rb +122 -0
  38. data/lib/gepa/utils/pareto.rb +119 -0
  39. data/lib/gepa.rb +21 -0
  40. metadata +42 -4
  41. data/lib/dspy/teleprompt/simple_optimizer.rb +0 -503
data/lib/dspy/teleprompt/utils.rb CHANGED
@@ -306,8 +306,13 @@ module DSPy
  demo_candidates = Hash.new { |h, k| h[k] = [] }
  rng = seed ? Random.new(seed) : Random.new

- # Get number of predictors (simplified: assume single predictor)
- num_predictors = 1
+ # Determine number of predictors exposed by the student module
+ num_predictors = if student.respond_to?(:predictors)
+   predictors = Array(student.predictors)
+   predictors.empty? ? 1 : predictors.size
+ else
+   1
+ end

  # Adjust for 3 special seeds (-3, -2, -1)
  adjusted_num_sets = num_candidate_sets - 3
@@ -706,4 +711,4 @@ module DSPy
      end
    end
  end
- end
+ end
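The rewritten block sizes the demo-candidate bookkeeping by the student's actual predictor count instead of hardcoding 1. A minimal sketch of the new branch's behavior (the stand-in objects below are hypothetical; only `respond_to?(:predictors)` and the shape of `predictors` matter):

require 'ostruct'

def predictor_count(student)
  if student.respond_to?(:predictors)
    predictors = Array(student.predictors)
    predictors.empty? ? 1 : predictors.size
  else
    1
  end
end

predictor_count(OpenStruct.new(predictors: [:qa, :summarize])) #=> 2
predictor_count(OpenStruct.new(predictors: []))                #=> 1 (empty list falls back)
predictor_count(Object.new)                                    #=> 1 (no #predictors method)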
data/lib/dspy/version.rb CHANGED
@@ -1,5 +1,5 @@
  # frozen_string_literal: true

  module DSPy
-   VERSION = "0.28.2"
- end
+   VERSION = "0.29.0"
+ end
data/lib/dspy.rb CHANGED
@@ -12,6 +12,7 @@ require_relative 'dspy/observability/observation_type'
  require_relative 'dspy/context'
  require_relative 'dspy/events'
  require_relative 'dspy/events/types'
+ require_relative 'dspy/reflection_lm'

  module DSPy
    extend Dry::Configurable
@@ -198,6 +199,7 @@ require_relative 'dspy/signature'
  require_relative 'dspy/few_shot_example'
  require_relative 'dspy/prompt'
  require_relative 'dspy/example'
+ require_relative 'dspy/datasets'
  require_relative 'dspy/lm'
  require_relative 'dspy/image'
  require_relative 'dspy/prediction'
@@ -211,10 +213,9 @@ require_relative 'dspy/evaluate'
  require_relative 'dspy/teleprompt/teleprompter'
  require_relative 'dspy/teleprompt/utils'
  require_relative 'dspy/teleprompt/data_handler'
+ require_relative 'dspy/teleprompt/gepa'
  require_relative 'dspy/propose/grounded_proposer'
- require_relative 'dspy/teleprompt/simple_optimizer'
  require_relative 'dspy/teleprompt/mipro_v2'
- require_relative 'dspy/teleprompt/gepa'
  require_relative 'dspy/tools'
  require_relative 'dspy/memory'
  require_relative 'dspy/storage/program_storage'
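Taken together, the top-level requires now load the datasets helpers, the reflection LM, and the GEPA teleprompter by default, while the deleted simple_optimizer is gone. A rough sketch of the observable effect, assuming each required file defines the constant its path suggests:

require 'dspy'

DSPy::Teleprompt.const_defined?(:GEPA)            #=> true
DSPy::Teleprompt.const_defined?(:SimpleOptimizer) #=> false (removed in 0.29.0)
DSPy.const_defined?(:Datasets)                    #=> true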
data/lib/gepa/api.rb ADDED
@@ -0,0 +1,61 @@
+ # frozen_string_literal: true
+
+ require 'sorbet-runtime'
+
+ require_relative 'core/engine'
+ require_relative 'core/result'
+
+ module GEPA
+   extend T::Sig
+   module_function
+
+   sig do
+     params(
+       seed_candidate: T::Hash[String, String],
+       trainset: T::Array[T.untyped],
+       valset: T::Array[T.untyped],
+       adapter: T.untyped,
+       reflective_proposer: T.untyped,
+       merge_proposer: T.nilable(T.untyped),
+       logger: T.untyped,
+       experiment_tracker: T.untyped,
+       max_metric_calls: Integer,
+       telemetry: T.nilable(T.untyped)
+     ).returns(GEPA::Core::Result)
+   end
+   def optimize(
+     seed_candidate:,
+     trainset:,
+     valset:,
+     adapter:,
+     reflective_proposer:,
+     merge_proposer: nil,
+     logger:,
+     experiment_tracker:,
+     max_metric_calls:,
+     telemetry: nil
+   )
+     evaluator = proc { |dataset, candidate| adapter.evaluate(dataset, candidate) }
+
+     engine = GEPA::Core::Engine.new(
+       run_dir: nil,
+       evaluator: evaluator,
+       valset: valset,
+       seed_candidate: seed_candidate,
+       max_metric_calls: max_metric_calls,
+       perfect_score: Float::INFINITY,
+       seed: 0,
+       reflective_proposer: reflective_proposer,
+       merge_proposer: merge_proposer,
+       logger: logger,
+       experiment_tracker: experiment_tracker,
+       telemetry: telemetry || GEPA::Telemetry,
+       track_best_outputs: false,
+       display_progress_bar: false,
+       raise_on_exception: true
+     )
+
+     state = engine.run
+     GEPA::Core::Result.from_state(state)
+   end
+ end
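GEPA.optimize only duck-types its collaborators: the adapter needs #evaluate(dataset, candidate) returning [outputs, scores], the proposer #propose(state), the logger #log, and the tracker #log_metrics(hash, step:). A hedged usage sketch with hypothetical stand-ins, assuming the gem's lib directory is on the load path so require 'gepa' resolves to the new data/lib/gepa.rb:

require 'gepa'

# Hypothetical adapter; any #evaluate returning [outputs, scores] works.
class RandomAdapter
  def evaluate(dataset, candidate)
    outputs = dataset.map { |_ex| candidate['instruction'] }
    scores  = dataset.map { |_ex| rand }
    [outputs, scores]
  end
end

logger   = Object.new.tap { |o| o.define_singleton_method(:log) { |msg| puts(msg) } }
tracker  = Object.new.tap { |o| o.define_singleton_method(:log_metrics) { |_m, step:| } }
proposer = Object.new.tap { |o| o.define_singleton_method(:propose) { |_state| nil } } # always declines

result = GEPA.optimize(
  seed_candidate: { 'instruction' => 'Answer concisely.' },
  trainset: [],
  valset: [{ question: 'q1' }, { question: 'q2' }],
  adapter: RandomAdapter.new,
  reflective_proposer: proposer,
  logger: logger,
  experiment_tracker: tracker,
  max_metric_calls: 10
)
result.best_candidate #=> { 'instruction' => 'Answer concisely.' }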
data/lib/gepa/core/engine.rb ADDED
@@ -0,0 +1,226 @@
+ # frozen_string_literal: true
+
+ require 'sorbet-runtime'
+
+ require_relative 'state'
+ require_relative 'result'
+ require_relative '../telemetry'
+
+ module GEPA
+   module Core
+     class Engine
+       extend T::Sig
+
+       sig do
+         params(
+           evaluator: T.proc.params(dataset: T::Array[T.untyped], candidate: T::Hash[String, String])
+                       .returns([T::Array[T.untyped], T::Array[Float]]),
+           valset: T::Array[T.untyped],
+           seed_candidate: T::Hash[String, String],
+           max_metric_calls: Integer,
+           perfect_score: Float,
+           seed: Integer,
+           reflective_proposer: T.untyped,
+           logger: T.untyped,
+           experiment_tracker: T.untyped,
+           merge_proposer: T.nilable(T.untyped),
+           run_dir: T.nilable(String),
+           track_best_outputs: T::Boolean,
+           display_progress_bar: T::Boolean,
+           telemetry: T.nilable(T.untyped),
+           raise_on_exception: T::Boolean
+         ).void
+       end
+       def initialize(
+         evaluator:,
+         valset:,
+         seed_candidate:,
+         max_metric_calls:,
+         perfect_score:,
+         seed:, # rubocop:disable Lint/UnusedMethodArgument -- kept for parity and future use
+         reflective_proposer:,
+         logger:,
+         experiment_tracker:,
+         merge_proposer: nil,
+         run_dir: nil,
+         track_best_outputs: false,
+         display_progress_bar: false,
+         telemetry: nil,
+         raise_on_exception: true
+       )
+         @run_dir = run_dir
+         @evaluator = evaluator
+         @valset = valset
+         @seed_candidate = seed_candidate
+         @max_metric_calls = max_metric_calls
+         @perfect_score = perfect_score
+         @reflective_proposer = reflective_proposer
+         @merge_proposer = merge_proposer
+         @logger = logger
+         @experiment_tracker = experiment_tracker
+         @track_best_outputs = track_best_outputs
+         @display_progress_bar = display_progress_bar
+         @telemetry = telemetry || GEPA::Telemetry
+         @raise_on_exception = raise_on_exception
+       end
+
+       sig { returns(GEPA::Core::State) }
+       def run
+         with_span('gepa.engine.run', max_metric_calls: @max_metric_calls) do
+           state = GEPA::Core::State.initialize_gepa_state(
+             run_dir: @run_dir,
+             logger: @logger,
+             seed_candidate: @seed_candidate,
+             valset_evaluator: ->(candidate) { full_evaluator(candidate) },
+             track_best_outputs: @track_best_outputs
+           )
+
+           @experiment_tracker.log_metrics({ base_program_full_valset_score: state.program_full_scores_val_set.first }, step: 0)
+
+           if @merge_proposer
+             @merge_proposer.last_iter_found_new_program = false
+           end
+
+           while state.total_num_evals < @max_metric_calls
+             break unless iteration_step(state)
+           end
+
+           state.save(@run_dir)
+           state
+         end
+       end
+
+       private
+
+       sig { params(state: GEPA::Core::State).returns(T::Boolean) }
+       def iteration_step(state)
+         state.i += 1
+         trace_entry = { iteration: state.i }
+         state.full_program_trace << trace_entry
+
+         progress = false
+
+         with_span('gepa.engine.iteration', iteration: state.i) do
+           merge_result = process_merge_iteration(state)
+           case merge_result
+           when :accepted
+             return true
+           when :attempted
+             return false
+           end
+
+           reflective_result = process_reflective_iteration(state)
+           return false if reflective_result == :no_candidate
+           progress = true if reflective_result == :accepted
+         end
+
+         progress
+       rescue StandardError => e
+         @logger.log("Iteration #{state.i}: Exception during optimization: #{e}")
+         @logger.log(e.backtrace&.join("\n"))
+         raise e if @raise_on_exception
+         true
+       end
+
+       sig { params(state: GEPA::Core::State).returns(Symbol) }
+       def process_merge_iteration(state)
+         return :skipped unless @merge_proposer && @merge_proposer.use_merge
+
+         if @merge_proposer.merges_due.positive? && @merge_proposer.last_iter_found_new_program
+           proposal = @merge_proposer.propose(state)
+           @merge_proposer.last_iter_found_new_program = false
+
+           if proposal&.tag == 'merge'
+             parent_sums = Array(proposal.subsample_scores_before).map(&:to_f)
+             new_sum = Array(proposal.subsample_scores_after).map(&:to_f).sum
+
+             if parent_sums.empty?
+               @logger.log("Iteration #{state.i}: Missing parent subscores for merge proposal, skipping")
+               return :handled
+             end
+
+             if new_sum >= parent_sums.max
+               with_span('gepa.engine.full_evaluation', iteration: state.i) do
+                 run_full_evaluation(state, proposal.candidate, proposal.parent_program_ids)
+               end
+               @merge_proposer.merges_due -= 1
+               @merge_proposer.total_merges_tested += 1
+               return :accepted
+             else
+               @logger.log(
+                 "Iteration #{state.i}: Merge subsample score #{new_sum.round(4)} "\
+                 "did not beat parents #{parent_sums.map { |v| v.round(4) }}, skipping"
+               )
+               return :attempted
+             end
+           end
+         end
+
+         @merge_proposer.last_iter_found_new_program = false
+         :skipped
+       end
+
+       sig { params(state: GEPA::Core::State).returns(Symbol) }
+       def process_reflective_iteration(state)
+         proposal = @reflective_proposer.propose(state)
+         unless proposal
+           @logger.log("Iteration #{state.i}: Reflective mutation did not propose a new candidate")
+           return :no_candidate
+         end
+
+         before = Array(proposal.subsample_scores_before).map(&:to_f)
+         after = Array(proposal.subsample_scores_after).map(&:to_f)
+         if after.empty? || after.sum <= before.sum
+           @logger.log("Iteration #{state.i}: New subsample score is not better, skipping")
+           return :skipped
+         end
+
+         with_span('gepa.engine.full_evaluation', iteration: state.i) do
+           run_full_evaluation(state, proposal.candidate, proposal.parent_program_ids)
+         end
+
+         if @merge_proposer&.use_merge
+           @merge_proposer.last_iter_found_new_program = true
+           @merge_proposer.schedule_if_needed
+         end
+
+         :accepted
+       end
+
+       sig do
+         params(state: GEPA::Core::State, new_program: T::Hash[String, String], parents: T::Array[Integer]).void
+       end
+       def run_full_evaluation(state, new_program, parents)
+         outputs, scores = full_evaluator(new_program)
+         avg_score = scores.sum / scores.length.to_f
+
+         state.num_full_ds_evals += 1
+         state.total_num_evals += scores.length
+
+         state.update_state_with_new_program(
+           parents,
+           new_program,
+           avg_score,
+           outputs,
+           scores,
+           @run_dir,
+           state.total_num_evals
+         )
+
+         @experiment_tracker.log_metrics({ new_program_full_score: avg_score }, step: state.i)
+       end
+
+       sig { params(candidate: T::Hash[String, String]).returns([T::Array[T.untyped], T::Array[Float]]) }
+       def full_evaluator(candidate)
+         @evaluator.call(@valset, candidate)
+       end
+
+       sig do
+         params(operation: String, attrs: T::Hash[Symbol, T.untyped], block: T.proc.returns(T.untyped)).returns(T.untyped)
+       end
+       def with_span(operation, attrs = {}, &block)
+         @telemetry.with_span(operation, attrs, &block)
+       end
+     end
+   end
+ end
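Both proposer paths gate the expensive full-valset evaluation on a cheap subsample comparison; only the arithmetic differs. In terms of the code above:

# Reflective path: strict improvement of the subsample sum is required.
before = [0.2, 0.5, 0.3]
after  = [0.4, 0.5, 0.3]
after.sum > before.sum     #=> true, so run_full_evaluation fires

# Merge path: the merged candidate must match or beat the best parent sum.
parent_sums = [1.1, 0.9]
new_sum     = 1.1
new_sum >= parent_sums.max #=> true, merge is accepted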
data/lib/gepa/core/evaluation_batch.rb ADDED
@@ -0,0 +1,26 @@
+ # frozen_string_literal: true
+
+ require 'sorbet-runtime'
+
+ module GEPA
+   module Core
+     # Container for evaluating a candidate on a batch.
+     class EvaluationBatch < T::Struct
+       extend T::Sig
+
+       const :outputs, T::Array[T.untyped]
+       const :scores, T::Array[Float]
+       const :trajectories, T.nilable(T::Array[T.untyped])
+
+       sig { override.params(args: T.untyped, kwargs: T.untyped).void }
+       def initialize(*args, **kwargs)
+         super
+         raise ArgumentError, 'outputs and scores length mismatch' unless outputs.length == scores.length
+
+         if trajectories
+           raise ArgumentError, 'trajectories length mismatch' unless trajectories.length == outputs.length
+         end
+       end
+     end
+   end
+ end
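The guard in initialize makes malformed batches fail fast. With the gepa files loaded:

batch = GEPA::Core::EvaluationBatch.new(outputs: %w[a b], scores: [1.0, 0.5], trajectories: nil)
batch.scores.sum / batch.scores.length #=> 0.75

GEPA::Core::EvaluationBatch.new(outputs: %w[a b], scores: [1.0])
# raises ArgumentError: outputs and scores length mismatch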
data/lib/gepa/core/result.rb ADDED
@@ -0,0 +1,92 @@
+ # frozen_string_literal: true
+
+ require 'json'
+ require 'set'
+ require 'sorbet-runtime'
+
+ module GEPA
+   module Core
+     # Snapshot of GEPA optimization output with helpers for common queries.
+     class Result < T::Struct
+       extend T::Sig
+
+       const :candidates, T::Array[T::Hash[String, String]]
+       const :parents, T::Array[T::Array[T.nilable(Integer)]]
+       const :val_aggregate_scores, T::Array[Float]
+       const :val_subscores, T::Array[T::Array[Float]]
+       const :per_val_instance_best_candidates, T::Array[T::Array[Integer]]
+       const :discovery_eval_counts, T::Array[Integer]
+       const :best_outputs_valset, T.nilable(T::Array[T::Array[T::Array[T.untyped]]]), default: nil
+       const :total_metric_calls, T.nilable(Integer), default: nil
+       const :num_full_val_evals, T.nilable(Integer), default: nil
+       const :run_dir, T.nilable(String), default: nil
+       const :seed, T.nilable(Integer), default: nil
+
+       sig { returns(Integer) }
+       def num_candidates
+         candidates.length
+       end
+
+       sig { returns(Integer) }
+       def num_val_instances
+         per_val_instance_best_candidates.length
+       end
+
+       sig { returns(Integer) }
+       def best_idx
+         val_aggregate_scores.each_with_index.max_by { |score, _i| score }&.last || 0
+       end
+
+       sig { returns(T::Hash[String, String]) }
+       def best_candidate
+         candidates.fetch(best_idx)
+       end
+
+       sig { returns(T::Hash[Symbol, T.untyped]) }
+       def to_h
+         {
+           candidates: candidates.map(&:dup),
+           parents: parents.map(&:dup),
+           val_aggregate_scores: val_aggregate_scores.dup,
+           val_subscores: val_subscores.map(&:dup),
+           best_outputs_valset: best_outputs_valset&.map { |arr| arr.map(&:dup) },
+           per_val_instance_best_candidates: per_val_instance_best_candidates.map(&:dup),
+           discovery_eval_counts: discovery_eval_counts.dup,
+           total_metric_calls: total_metric_calls,
+           num_full_val_evals: num_full_val_evals,
+           run_dir: run_dir,
+           seed: seed,
+           best_idx: best_idx
+         }
+       end
+
+       sig { returns(String) }
+       def to_json(*_args)
+         JSON.pretty_generate(to_h)
+       end
+
+       sig do
+         params(
+           state: T.untyped,
+           run_dir: T.nilable(String),
+           seed: T.nilable(Integer)
+         ).returns(Result)
+       end
+       def self.from_state(state, run_dir: nil, seed: nil)
+         new(
+           candidates: state.program_candidates.map(&:dup),
+           parents: state.parent_program_for_candidate.map(&:dup),
+           val_aggregate_scores: state.program_full_scores_val_set.map(&:to_f),
+           best_outputs_valset: state.respond_to?(:best_outputs_valset) ? state.best_outputs_valset&.map(&:dup) : nil,
+           val_subscores: state.prog_candidate_val_subscores.map { |scores| scores.map(&:to_f) },
+           per_val_instance_best_candidates: state.program_at_pareto_front_valset.map { |set| set.to_a },
+           discovery_eval_counts: state.num_metric_calls_by_discovery.map(&:to_i),
+           total_metric_calls: state.respond_to?(:total_num_evals) ? state.total_num_evals : nil,
+           num_full_val_evals: state.respond_to?(:num_full_ds_evals) ? state.num_full_ds_evals : nil,
+           run_dir: run_dir,
+           seed: seed
+         )
+       end
+     end
+   end
+ end
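Result is a plain value object, so its helpers can be exercised directly. A small example built from hand-rolled scores:

result = GEPA::Core::Result.new(
  candidates: [{ 'instruction' => 'v0' }, { 'instruction' => 'v1' }],
  parents: [[nil], [0]],
  val_aggregate_scores: [0.41, 0.67],
  val_subscores: [[0.40, 0.42], [0.70, 0.64]],
  per_val_instance_best_candidates: [[1], [1]],
  discovery_eval_counts: [0, 2]
)
result.best_idx       #=> 1
result.best_candidate #=> { 'instruction' => 'v1' }
result.num_candidates #=> 2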
data/lib/gepa/core/state.rb ADDED
@@ -0,0 +1,231 @@
+ # frozen_string_literal: true
+
+ require 'fileutils'
+ require 'json'
+ require 'set'
+ require 'sorbet-runtime'
+
+ require_relative '../utils/pareto'
+ require_relative '../telemetry'
+
+ module GEPA
+   module Core
+     class State
+       extend T::Sig
+
+       attr_accessor :i, :num_full_ds_evals, :total_num_evals
+       attr_reader :program_candidates,
+                   :parent_program_for_candidate,
+                   :program_full_scores_val_set,
+                   :program_at_pareto_front_valset,
+                   :prog_candidate_val_subscores,
+                   :list_of_named_predictors,
+                   :named_predictor_id_to_update_next_for_program_candidate,
+                   :num_metric_calls_by_discovery,
+                   :full_program_trace,
+                   :per_program_tracked_scores,
+                   :pareto_front_valset,
+                   :best_outputs_valset
+
+       sig do
+         params(
+           seed_candidate: T::Hash[String, String],
+           base_valset_eval_output: [T::Array[T.untyped], T::Array[Float]],
+           track_best_outputs: T::Boolean
+         ).void
+       end
+       def initialize(seed_candidate, base_valset_eval_output, track_best_outputs: false)
+         outputs, scores = base_valset_eval_output
+         raise ArgumentError, 'validation scores must not be empty' if scores.empty?
+
+         valset_base_score = scores.sum / scores.length.to_f
+
+         @program_candidates = [seed_candidate.dup]
+         @program_full_scores_val_set = [valset_base_score]
+         @per_program_tracked_scores = [valset_base_score]
+
+         @pareto_front_valset = scores.dup
+         @parent_program_for_candidate = [[nil]]
+         @program_at_pareto_front_valset = Array.new(scores.length) { Set.new([0]) }
+
+         @list_of_named_predictors = seed_candidate.keys
+         @named_predictor_id_to_update_next_for_program_candidate = [0]
+
+         @prog_candidate_val_subscores = [scores.dup]
+         @num_metric_calls_by_discovery = [0]
+
+         @best_outputs_valset = if track_best_outputs
+                                  outputs.map { |output| [[0, output]] }
+                                end
+
+         @full_program_trace = []
+         @i = -1
+         @num_full_ds_evals = 0
+         @total_num_evals = 0
+       end
+
+       sig { returns(T::Boolean) }
+       def consistent?
+         size = @program_candidates.length
+         raise 'program_full_scores_val_set mismatch' unless @program_full_scores_val_set.length == size
+         raise 'per_program_tracked_scores mismatch' unless @per_program_tracked_scores.length == size
+         raise 'parent_program_for_candidate mismatch' unless @parent_program_for_candidate.length == size
+         raise 'named_predictor_id_to_update mismatch' unless @named_predictor_id_to_update_next_for_program_candidate.length == size
+         raise 'prog_candidate_val_subscores mismatch' unless @prog_candidate_val_subscores.length == size
+         raise 'num_metric_calls mismatch' unless @num_metric_calls_by_discovery.length == size
+         raise 'pareto fronts length mismatch' unless @pareto_front_valset.length == @program_at_pareto_front_valset.length
+
+         @program_at_pareto_front_valset.each do |front|
+           front.each do |idx|
+             raise 'pareto index out of range' unless idx < size
+           end
+         end
+         true
+       end
+
+       sig { params(run_dir: T.nilable(String)).void }
+       def save(run_dir)
+         return if run_dir.nil?
+
+         FileUtils.mkdir_p(run_dir)
+         File.open(File.join(run_dir, 'gepa_state.bin'), 'wb') do |file|
+           data = instance_variables.each_with_object({}) do |ivar, acc|
+             acc[ivar.to_s.delete('@')] = instance_variable_get(ivar)
+           end
+           Marshal.dump(data, file)
+         end
+       end
+
+       sig { params(run_dir: String).returns(State) }
+       def self.load(run_dir)
+         File.open(File.join(run_dir, 'gepa_state.bin'), 'rb') do |file|
+           data = Marshal.load(file)
+           state = allocate
+           data.each { |key, value| state.instance_variable_set("@#{key}", value) }
+           state.consistent?
+           state
+         end
+       end
+
+ sig do
111
+ params(
112
+ parent_program_idx: T::Array[Integer],
113
+ new_program: T::Hash[String, String],
114
+ valset_score: Float,
115
+ valset_outputs: T::Array[T.untyped],
116
+ valset_subscores: T::Array[Float],
117
+ run_dir: T.nilable(String),
118
+ num_metric_calls: Integer
119
+ ).returns([Integer, Integer])
120
+ end
121
+ def update_state_with_new_program(
122
+ parent_program_idx,
123
+ new_program,
124
+ valset_score,
125
+ valset_outputs,
126
+ valset_subscores,
127
+ run_dir,
128
+ num_metric_calls
129
+ )
130
+ new_program_idx = @program_candidates.length
131
+ @program_candidates << new_program.dup
132
+ @num_metric_calls_by_discovery << num_metric_calls
133
+
134
+ max_predictor_id = parent_program_idx.map { |idx| @named_predictor_id_to_update_next_for_program_candidate[idx] }.compact.max
135
+ @named_predictor_id_to_update_next_for_program_candidate << (max_predictor_id || 0)
136
+ @parent_program_for_candidate << parent_program_idx.dup
137
+
138
+ @prog_candidate_val_subscores << valset_subscores.dup
139
+ @program_full_scores_val_set << valset_score.to_f
140
+
141
+ valset_subscores.each_with_index do |new_score, task_idx|
142
+ old_score = @pareto_front_valset[task_idx]
143
+ if new_score > old_score
144
+ @pareto_front_valset[task_idx] = new_score
145
+ @program_at_pareto_front_valset[task_idx] = Set.new([new_program_idx])
146
+ if @best_outputs_valset
147
+ @best_outputs_valset[task_idx] = [[new_program_idx, valset_outputs[task_idx]]]
148
+ end
149
+ write_best_output(run_dir, task_idx, new_program_idx, valset_outputs[task_idx])
150
+ elsif new_score == old_score
151
+ @program_at_pareto_front_valset[task_idx].add(new_program_idx)
152
+ if @best_outputs_valset
153
+ @best_outputs_valset[task_idx] << [new_program_idx, valset_outputs[task_idx]]
154
+ end
155
+ end
156
+ end
157
+
158
+ raise 'valset subscores length mismatch' unless valset_subscores.length == @program_at_pareto_front_valset.length
159
+
160
+ @per_program_tracked_scores = @program_full_scores_val_set.dup
161
+ linear_idx = GEPA::Utils::Pareto.idxmax(@per_program_tracked_scores)
162
+
163
+ [new_program_idx, linear_idx]
164
+ end
165
+
166
+ sig do
167
+ params(
168
+ eval_output: [T::Array[T.untyped], T::Array[Float]],
169
+ output_dir: String
170
+ ).void
171
+ end
172
+ def self.write_eval_output_to_directory(eval_output, output_dir)
173
+ _, scores = eval_output
174
+ scores.each_with_index do |_score, task_idx|
175
+ dir = File.join(output_dir, "task_#{task_idx}")
176
+ FileUtils.mkdir_p(dir)
177
+ path = File.join(dir, 'iter_0_prog_0.json')
178
+ File.write(path, JSON.pretty_generate(scores[task_idx]))
179
+ end
180
+ end
181
+
182
+ sig do
183
+ params(
184
+ run_dir: T.nilable(String),
185
+ logger: T.untyped,
186
+ seed_candidate: T::Hash[String, String],
187
+ valset_evaluator: T.proc.params(arg0: T::Hash[String, String]).returns([T::Array[T.untyped], T::Array[Float]]),
188
+ track_best_outputs: T::Boolean
189
+ ).returns(State)
190
+ end
191
+ def self.initialize_gepa_state(run_dir:, logger:, seed_candidate:, valset_evaluator:, track_best_outputs: false)
192
+ if run_dir && File.exist?(File.join(run_dir, 'gepa_state.bin')) && File.exist?(File.join(run_dir, 'prog_candidates'))
193
+ logger.log('Loading gepa state from run dir')
194
+ return load(run_dir)
195
+ end
196
+
197
+ valset_out = valset_evaluator.call(seed_candidate)
198
+ if run_dir
199
+ write_eval_output_to_directory(valset_out, File.join(run_dir, 'generated_best_outputs_valset'))
200
+ end
201
+
202
+ state = new(seed_candidate, valset_out, track_best_outputs: track_best_outputs)
203
+ state.num_full_ds_evals = 1
204
+ state.total_num_evals = valset_out.last.length
205
+ state
206
+ end
207
+
208
+ private
209
+
210
+ sig do
211
+ params(run_dir: T.nilable(String), task_idx: Integer, program_idx: Integer, output: T.untyped).void
212
+ end
213
+ def write_best_output(run_dir, task_idx, program_idx, output)
214
+ return if run_dir.nil?
215
+
216
+ dir = File.join(run_dir, 'generated_best_outputs_valset', "task_#{task_idx}")
217
+ FileUtils.mkdir_p(dir)
218
+ payload = ensure_jsonable(output)
219
+ File.write(File.join(dir, "iter_#{@i + 1}_prog_#{program_idx}.json"), JSON.pretty_generate(payload))
220
+ end
221
+
222
+ sig { params(value: T.untyped).returns(T.untyped) }
223
+ def ensure_jsonable(value)
224
+ JSON.parse(JSON.generate(value))
225
+ rescue StandardError
226
+ GEPA::Utils::Pareto.json_default(value)
227
+ end
228
+ end
229
+ end
230
+ end
231
+
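The Pareto bookkeeping in update_state_with_new_program replaces a per-task front entry on strict improvement and widens it on a tie. A worked example against the code above (assuming the gepa files, including GEPA::Utils::Pareto, are loaded):

state = GEPA::Core::State.new({ 'instruction' => 'seed' }, [%w[o0 o1], [0.5, 0.8]])
state.update_state_with_new_program(
  [0], { 'instruction' => 'v1' }, 0.85, %w[n0 n1], [0.9, 0.8], nil, 2
)
state.pareto_front_valset            #=> [0.9, 0.8]  (task 0 strictly improved)
state.program_at_pareto_front_valset #=> [#<Set: {1}>, #<Set: {0, 1}>]  (tie on task 1)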