dspy 0.28.1 → 0.29.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -3
- data/lib/dspy/callbacks.rb +222 -0
- data/lib/dspy/chain_of_thought.rb +2 -1
- data/lib/dspy/code_act.rb +14 -1
- data/lib/dspy/datasets/ade.rb +90 -0
- data/lib/dspy/datasets.rb +8 -0
- data/lib/dspy/lm.rb +9 -12
- data/lib/dspy/mixins/struct_builder.rb +17 -25
- data/lib/dspy/module.rb +45 -1
- data/lib/dspy/observability/async_span_processor.rb +67 -93
- data/lib/dspy/observability.rb +43 -1
- data/lib/dspy/predict.rb +17 -0
- data/lib/dspy/prompt.rb +90 -20
- data/lib/dspy/propose/dataset_summary_generator.rb +210 -0
- data/lib/dspy/propose/grounded_proposer.rb +320 -66
- data/lib/dspy/re_act.rb +13 -0
- data/lib/dspy/reflection_lm.rb +36 -0
- data/lib/dspy/teleprompt/bootstrap_strategy.rb +26 -0
- data/lib/dspy/teleprompt/gepa.rb +448 -2803
- data/lib/dspy/teleprompt/mipro_v2.rb +624 -100
- data/lib/dspy/teleprompt/utils.rb +349 -42
- data/lib/dspy/version.rb +2 -2
- data/lib/dspy.rb +4 -2
- data/lib/gepa/api.rb +61 -0
- data/lib/gepa/core/engine.rb +226 -0
- data/lib/gepa/core/evaluation_batch.rb +26 -0
- data/lib/gepa/core/result.rb +92 -0
- data/lib/gepa/core/state.rb +231 -0
- data/lib/gepa/logging/experiment_tracker.rb +54 -0
- data/lib/gepa/logging/logger.rb +57 -0
- data/lib/gepa/logging.rb +9 -0
- data/lib/gepa/proposer/base.rb +27 -0
- data/lib/gepa/proposer/merge_proposer.rb +424 -0
- data/lib/gepa/proposer/reflective_mutation/base.rb +48 -0
- data/lib/gepa/proposer/reflective_mutation/reflective_mutation.rb +188 -0
- data/lib/gepa/strategies/batch_sampler.rb +91 -0
- data/lib/gepa/strategies/candidate_selector.rb +97 -0
- data/lib/gepa/strategies/component_selector.rb +57 -0
- data/lib/gepa/strategies/instruction_proposal.rb +120 -0
- data/lib/gepa/telemetry.rb +122 -0
- data/lib/gepa/utils/pareto.rb +119 -0
- data/lib/gepa.rb +21 -0
- metadata +59 -4
- data/lib/dspy/teleprompt/simple_optimizer.rb +0 -497
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
module GEPA
|
6
|
+
module Proposer
|
7
|
+
# Immutable record describing one candidate program emitted by a proposer,
# plus the subsample evidence the proposer gathered for it.
class CandidateProposal < T::Struct
  extend T::Sig

  # Component name => instruction text of the proposed program.
  const :candidate, T::Hash[String, String]
  # Indices of the program(s) the candidate was derived from.
  const :parent_program_ids, T::Array[Integer]
  # Dataset indices the candidate was scored on (nil when not recorded).
  const :subsample_indices, T.nilable(T::Array[Integer]), default: nil
  # Parent score(s) on that subsample, as reported by the proposer.
  const :subsample_scores_before, T.nilable(T::Array[Float]), default: nil
  # Scores of the new candidate on that subsample.
  const :subsample_scores_after, T.nilable(T::Array[Float]), default: nil
  # Label naming the strategy that produced the candidate.
  const :tag, String, default: 'reflective_mutation'
  # Free-form extra data attached by the proposer.
  const :metadata, T::Hash[Symbol, T.untyped], default: {}
end
|
18
|
+
|
19
|
+
# Contract mixed into every proposer: given the optimizer state, return a
# CandidateProposal, or nil when nothing is proposed.
# NOTE(review): the sig is `abstract` but the module never calls
# `abstract!`/`interface!` — confirm this is intentional.
module ProposeNewCandidate
  extend T::Sig

  sig { abstract.params(state: GEPA::Core::State).returns(T.nilable(CandidateProposal)) }
  def propose(state); end
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
|
@@ -0,0 +1,424 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'set'
|
4
|
+
require 'sorbet-runtime'
|
5
|
+
|
6
|
+
require_relative 'base'
|
7
|
+
require_relative '../utils/pareto'
|
8
|
+
require_relative '../telemetry'
|
9
|
+
|
10
|
+
module GEPA
|
11
|
+
module Proposer
|
12
|
+
# Port of the Python GEPA merge proposer. It fuses two descendants that share
|
13
|
+
# a common ancestor by recombining their component instructions and then
|
14
|
+
# evaluates the merged program on a Pareto-informed subsample.
|
15
|
+
class MergeProposer
|
16
|
+
extend T::Sig
|
17
|
+
include ProposeNewCandidate
|
18
|
+
|
19
|
+
CandidateTriplet = T.type_alias { [Integer, Integer, Integer] }
|
20
|
+
MergeAttempt = T.type_alias { [Integer, Integer, T::Array[Integer]] }
|
21
|
+
|
22
|
+
sig do
|
23
|
+
params(
|
24
|
+
logger: T.untyped,
|
25
|
+
valset: T::Array[T.untyped],
|
26
|
+
evaluator: T.proc.params(dataset: T::Array[T.untyped], candidate: T::Hash[String, String])
|
27
|
+
.returns([T::Array[T.untyped], T::Array[Float]]),
|
28
|
+
use_merge: T::Boolean,
|
29
|
+
max_merge_invocations: Integer,
|
30
|
+
rng: T.nilable(Random),
|
31
|
+
telemetry: T.nilable(T.untyped)
|
32
|
+
).void
|
33
|
+
end
|
34
|
+
# Builds a merge proposer.
#
# logger                - object responding to #log(message)
# valset                - validation examples, indexed by subsample ids
# evaluator             - callable (dataset, candidate) -> [outputs, scores]
# use_merge             - master switch for merge proposals
# max_merge_invocations - cap consulted by #schedule_if_needed
# rng                   - random source; Random.new(0) when omitted
# telemetry             - telemetry sink; GEPA::Telemetry when omitted
def initialize(logger:, valset:, evaluator:, use_merge:, max_merge_invocations:, rng: nil, telemetry: nil)
  # Collaborators and configuration handed in by the caller.
  @logger, @valset, @evaluator = logger, valset, evaluator
  @use_merge, @max_merge_invocations = use_merge, max_merge_invocations
  @rng = rng ? rng : Random.new(0)
  @telemetry = telemetry ? telemetry : GEPA::Telemetry

  # Scheduling counters and flags.
  @merges_due = @total_merges_tested = 0
  @last_iter_found_new_program = false
  # Slot 0 holds [id1, id2, ancestor] triplets, slot 1 holds
  # [id1, id2, descriptor] entries of merges already produced.
  @merges_performed = Array.new(2) { [] }
end
|
48
|
+
|
49
|
+
# Counter incremented by #schedule_if_needed and checked before proposing.
sig { returns(Integer) }
attr_accessor :merges_due

# Counter compared against max_merge_invocations when scheduling.
sig { returns(Integer) }
attr_accessor :total_merges_tested

# Flag that must be true for a merge proposal to be attempted.
sig { returns(T::Boolean) }
attr_accessor :last_iter_found_new_program

# Upper bound used by #schedule_if_needed.
sig { returns(Integer) }
attr_reader :max_merge_invocations

# Whether merging is enabled for this proposer.
sig { returns(T::Boolean) }
attr_reader :use_merge
|
63
|
+
|
64
|
+
sig { void }
|
65
|
+
# Queues one more merge attempt when merging is enabled and the number of
# merges tested so far is still below the configured cap.
def schedule_if_needed
  @merges_due += 1 if @use_merge && @total_merges_tested < @max_merge_invocations
end
|
71
|
+
|
72
|
+
sig do
|
73
|
+
params(
|
74
|
+
scores1: T::Array[Float],
|
75
|
+
scores2: T::Array[Float],
|
76
|
+
num_subsample_ids: Integer
|
77
|
+
).returns(T::Array[Integer])
|
78
|
+
end
|
79
|
+
# Chooses the validation indices on which a merged program is spot-checked.
#
# Indices are split into three pools: where scores1 wins, where scores2
# wins, and ties. Up to ceil(num_subsample_ids / 3) indices are drawn from
# each "winner" pool, ties fill the remaining slots, and any shortfall is
# topped up from unused indices (or, failing that, with replacement).
#
# scores1, scores2  - per-example scores of the two parents; only the
#                     overlapping prefix of both arrays is considered
# num_subsample_ids - number of indices requested
#
# Returns an Array of at most num_subsample_ids indices; [] when the two
# score arrays share no index.
def select_eval_subsample_for_merged_program(scores1, scores2, num_subsample_ids: 5)
  shared_length = [scores1.length, scores2.length].min
  # With nothing to sample from, the top-up step below would try to draw
  # with replacement from an empty list, which makes Random#rand raise.
  return [] if shared_length.zero?

  all_indices = (0...shared_length).to_a
  first_wins = []
  second_wins = []
  ties = []

  all_indices.each do |index|
    s1 = scores1[index]
    s2 = scores2[index]
    if s1 > s2
      first_wins << index
    elsif s2 > s1
      second_wins << index
    else
      ties << index
    end
  end

  per_side = (num_subsample_ids / 3.0).ceil
  selected = []
  selected.concat(sample_from(first_wins, [per_side, first_wins.length].min))
  selected.concat(sample_from(second_wins, [per_side, second_wins.length].min))

  remaining_slots = num_subsample_ids - selected.length
  selected.concat(sample_from(ties, [remaining_slots, ties.length].min))

  remaining_slots = num_subsample_ids - selected.length
  if remaining_slots.positive?
    unused = all_indices - selected
    if unused.length >= remaining_slots
      selected.concat(sample_from(unused, remaining_slots))
    else
      # Fewer distinct indices than requested: allow repeats.
      selected.concat(sample_with_replacement(all_indices, remaining_slots))
    end
  end

  selected.take(num_subsample_ids)
end
|
117
|
+
|
118
|
+
sig { override.params(state: GEPA::Core::State).returns(T.nilable(CandidateProposal)) }
|
119
|
+
# Runs one merge attempt for the current iteration.
#
# Flags the latest trace entry as a merge invocation, then returns nil
# (after logging) when no merge is scheduled or no mergeable pair is found.
# Otherwise it records the merge in the trace, scores the merged program on
# a subsample of the validation set through the evaluator, adds the
# subsample size to state.total_num_evals and returns a CandidateProposal
# tagged 'merge'.
def propose(state)
  iteration = state.i + 1
  ensure_trace_slot(state)
  state.full_program_trace.last[:invoked_merge] = true

  unless eligible_for_proposal?
    @logger.log("Iteration #{iteration}: No merge candidates scheduled")
    return nil
  end

  pareto_front = state.program_at_pareto_front_valset
  score_by_program = {}
  state.per_program_tracked_scores.each_with_index { |score, idx| score_by_program[idx] = score }
  dominators = GEPA::Utils::Pareto.find_dominator_programs(pareto_front, score_by_program)

  merged, new_program, id1, id2, ancestor =
    sample_and_attempt_merge_programs_by_common_predictors(state, dominators)

  unless merged
    @logger.log("Iteration #{iteration}: No merge candidates found")
    return nil
  end

  state.full_program_trace.last.update(merged: true, merged_entities: [id1, id2, ancestor])
  @merges_performed[0] << [id1, id2, ancestor]

  @logger.log("Iteration #{iteration}: Merged programs #{id1} and #{id2} via ancestor #{ancestor}")

  scores1 = state.prog_candidate_val_subscores[id1]
  scores2 = state.prog_candidate_val_subscores[id2]
  subsample_ids = select_eval_subsample_for_merged_program(scores1, scores2)

  mini_valset = subsample_ids.map { |idx| @valset[idx] }
  id1_sub_scores = scores1.values_at(*subsample_ids)
  id2_sub_scores = scores2.values_at(*subsample_ids)

  state.full_program_trace.last.update(
    subsample_ids: subsample_ids,
    id1_subsample_scores: id1_sub_scores,
    id2_subsample_scores: id2_sub_scores
  )

  # Only the scores are used; the evaluator's first return value is ignored.
  _outputs, new_sub_scores = @evaluator.call(mini_valset, new_program)
  state.full_program_trace.last[:new_program_subsample_scores] = new_sub_scores

  state.total_num_evals += subsample_ids.length

  CandidateProposal.new(
    candidate: new_program,
    parent_program_ids: [id1, id2],
    subsample_indices: subsample_ids,
    # One aggregate (sum) per parent, not per-example scores.
    subsample_scores_before: [id1_sub_scores.sum, id2_sub_scores.sum],
    subsample_scores_after: new_sub_scores,
    tag: 'merge',
    metadata: { ancestor: ancestor }
  )
end
|
178
|
+
|
179
|
+
private
|
180
|
+
|
181
|
+
attr_reader :logger
|
182
|
+
|
183
|
+
sig { returns(T::Boolean) }
|
184
|
+
# True only when merging is enabled, the previous iteration found a new
# program, and at least one merge is still scheduled.
def eligible_for_proposal?
  return false unless @use_merge && @last_iter_found_new_program

  @merges_due.positive?
end
|
187
|
+
|
188
|
+
sig do
|
189
|
+
params(state: GEPA::Core::State, merge_candidates: T::Array[Integer])
|
190
|
+
.returns([T::Boolean, T.nilable(T::Hash[String, String]), T.nilable(Integer), T.nilable(Integer), T.nilable(Integer)])
|
191
|
+
end
|
192
|
+
# Tries up to ten times to pick two candidates with a usable common
# ancestor and build a merged program that has not been produced before.
#
# Returns [true, program, id1, id2, ancestor] on success and
# [false, nil, nil, nil, nil] otherwise. Successful merges are remembered
# in @merges_performed[1] so the same combination is not returned twice.
def sample_and_attempt_merge_programs_by_common_predictors(state, merge_candidates)
  no_merge = [false, nil, nil, nil, nil]
  return no_merge if merge_candidates.length < 2
  return no_merge if state.parent_program_for_candidate.length < 3

  10.times do
    triplet = find_common_ancestor_pair(
      state.parent_program_for_candidate,
      merge_candidates,
      state.per_program_tracked_scores,
      state.program_candidates
    )
    next unless triplet

    id1, id2, ancestor = triplet
    return no_merge unless [id1, id2, ancestor].all?

    merged_program, descriptor = build_merged_program(
      state.program_candidates,
      id1,
      id2,
      ancestor,
      state.per_program_tracked_scores
    )
    next unless merged_program

    # The descriptor lists which parent supplied each component, so an equal
    # entry means this exact recombination was already produced.
    fingerprint = [id1, id2, descriptor]
    next if @merges_performed[1].include?(fingerprint)

    @merges_performed[1] << fingerprint
    return [true, merged_program, id1, id2, ancestor]
  end

  no_merge
end
|
228
|
+
|
229
|
+
sig do
|
230
|
+
params(
|
231
|
+
parent_list: T::Array[T::Array[T.nilable(Integer)]],
|
232
|
+
merge_candidates: T::Array[Integer],
|
233
|
+
agg_scores: T::Array[Float],
|
234
|
+
program_candidates: T::Array[T::Hash[String, String]]
|
235
|
+
).returns(T.nilable(CandidateTriplet))
|
236
|
+
end
|
237
|
+
# Samples (up to ten times) a pair of merge candidates that are not in an
# ancestor/descendant relation and share at least one acceptable common
# ancestor. The ancestor is drawn with probability weighted by its score.
#
# Returns [id1, id2, ancestor], or nil when no suitable triplet is found.
def find_common_ancestor_pair(parent_list, merge_candidates, agg_scores, program_candidates)
  return nil if merge_candidates.length < 2

  10.times do
    id1, id2 = sample_distinct_pair(merge_candidates)
    next unless id1 && id2

    lineage1 = collect_ancestors(parent_list, id1)
    lineage2 = collect_ancestors(parent_list, id2)

    # Skip pairs where one program descends from the other.
    next if lineage1.include?(id2) || lineage2.include?(id1)

    viable = filter_ancestors(id1, id2, lineage1 & lineage2, agg_scores, program_candidates)
    next if viable.empty?

    ancestor_weights = viable.map { |candidate| agg_scores[candidate] }
    return [id1, id2, sample_with_weights(viable, ancestor_weights)]
  end

  nil
end
|
266
|
+
|
267
|
+
sig do
|
268
|
+
params(
|
269
|
+
id1: Integer,
|
270
|
+
id2: Integer,
|
271
|
+
common_ancestors: T::Array[Integer],
|
272
|
+
agg_scores: T::Array[Float],
|
273
|
+
program_candidates: T::Array[T::Hash[String, String]]
|
274
|
+
).returns(T::Array[Integer])
|
275
|
+
end
|
276
|
+
# Keeps the common ancestors that are still worth merging through: not
# already used for this (id1, id2) pair, not scoring higher than either
# descendant, and offering at least one component to recombine.
def filter_ancestors(id1, id2, common_ancestors, agg_scores, program_candidates)
  common_ancestors.select do |ancestor|
    next false if @merges_performed[0].include?([id1, id2, ancestor])

    ancestor_score = agg_scores[ancestor]
    next false if ancestor_score > agg_scores[id1] || ancestor_score > agg_scores[id2]

    desirable_predictors_triplet?(program_candidates, ancestor, id1, id2)
  end
end
|
285
|
+
|
286
|
+
sig do
|
287
|
+
params(
|
288
|
+
program_candidates: T::Array[T::Hash[String, String]],
|
289
|
+
ancestor: Integer,
|
290
|
+
id1: Integer,
|
291
|
+
id2: Integer
|
292
|
+
).returns(T::Boolean)
|
293
|
+
end
|
294
|
+
# True when at least one component of the ancestor is unchanged in exactly
# one of the two descendants, i.e. the descendants differ on that component
# and one of them still carries the ancestor's text.
def desirable_predictors_triplet?(program_candidates, ancestor, id1, id2)
  base, left, right = program_candidates.values_at(ancestor, id1, id2)

  base.any? do |pred_name, base_text|
    left_text = left[pred_name]
    right_text = right[pred_name]

    left_text != right_text && (base_text == left_text || base_text == right_text)
  end
end
|
308
|
+
|
309
|
+
sig do
|
310
|
+
params(
|
311
|
+
program_candidates: T::Array[T::Hash[String, String]],
|
312
|
+
id1: Integer,
|
313
|
+
id2: Integer,
|
314
|
+
ancestor: Integer,
|
315
|
+
agg_scores: T::Array[Float]
|
316
|
+
).returns([T.nilable(T::Hash[String, String]), T::Array[Integer]])
|
317
|
+
end
|
318
|
+
# Recombines two descendants of a common ancestor, component by component.
#
# For every component of the ancestor:
# * changed in exactly one descendant -> take the changed version;
# * changed in both                   -> take it from the higher-scoring
#                                        descendant (random pick on a tie);
# * unchanged in both                 -> keep it, attributed to id1.
#
# Returns [merged_program, descriptor] where descriptor lists, in component
# order, the index of the program each component was taken from.
def build_merged_program(program_candidates, id1, id2, ancestor, agg_scores)
  base = program_candidates[ancestor]
  left = program_candidates[id1]
  right = program_candidates[id2]

  merged = base.dup
  descriptor = base.map do |pred_name, base_text|
    left_text = left[pred_name]
    right_text = right[pred_name]
    keeps_left = base_text == left_text
    keeps_right = base_text == right_text

    source =
      if (keeps_left || keeps_right) && left_text != right_text
        # The descendant that still matches the ancestor did not evolve this
        # component, so the other one supplies it.
        keeps_left ? id2 : id1
      elsif !keeps_left && !keeps_right
        case agg_scores[id1] <=> agg_scores[id2]
        when 1 then id1
        when -1 then id2
        else @rng.rand(2).zero? ? id1 : id2
        end
      elsif left_text == right_text
        id1
      else
        raise 'Unexpected predictor merge case'
      end

    merged[pred_name] = program_candidates[source][pred_name]
    source
  end

  [merged, descriptor]
end
|
355
|
+
|
356
|
+
sig { params(state: GEPA::Core::State).void }
|
357
|
+
# Guarantees the trace ends with a non-nil entry to write into, appending
# an empty Hash when the trace is empty or its last entry is nil.
def ensure_trace_slot(state)
  trace = state.full_program_trace
  # Array#last is nil for an empty array, so one check covers both cases.
  trace.push({}) if trace.last.nil?
end
|
360
|
+
|
361
|
+
sig { params(array: T::Array[Integer], count: Integer).returns(T::Array[Integer]) }
|
362
|
+
# Draws `count` distinct elements from `array` using the proposer's RNG.
# Returns [] for a non-positive count or an empty array, and an unshuffled
# copy of the whole array when fewer than `count` elements are available.
def sample_from(array, count)
  return [] unless count.positive? && !array.empty?
  return array.dup if array.length < count

  array.sample(count, random: @rng)
end
|
371
|
+
|
372
|
+
sig { params(array: T::Array[Integer], count: Integer).returns(T::Array[Integer]) }
|
373
|
+
# Draws `count` elements from `array` with replacement using the
# proposer's RNG.
#
# Returns [] when `array` is empty: there is nothing to draw, and
# Random#rand(0) would otherwise raise an ArgumentError.
def sample_with_replacement(array, count)
  return [] if array.empty?

  count.times.map { array[@rng.rand(array.length)] }
end
|
376
|
+
|
377
|
+
sig { params(options: T::Array[Integer], weights: T::Array[Float]).returns(Integer) }
|
378
|
+
# Picks one of `options` with probability proportional to its weight,
# using the proposer's RNG.
#
# Returns the first option when all weights sum to zero, and falls back to
# the last option if floating-point rounding leaves the draw beyond the
# accumulated total.
def sample_with_weights(options, weights)
  total = weights.sum
  return options.first if total.zero?

  pick = @rng.rand * total
  cumulative = 0.0
  options.zip(weights).each do |option, weight|
    cumulative += weight
    # Strict comparison: a draw landing exactly on a boundary (notably 0.0)
    # must not select an option whose weight is zero.
    return option if pick < cumulative
  end
  options.last
end
|
390
|
+
|
391
|
+
sig { params(parent_list: T::Array[T::Array[T.nilable(Integer)]], node: Integer).returns(T::Array[Integer]) }
|
392
|
+
# Returns every ancestor reachable from `node` through `parent_list`, in
# the order the traversal first encounters them.
def collect_ancestors(parent_list, node)
  Set.new.tap { |seen| traverse_ancestors(parent_list, node, seen) }.to_a
end
|
397
|
+
|
398
|
+
sig { params(parent_list: T::Array[T::Array[T.nilable(Integer)]], node: Integer, visited: Set).void }
|
399
|
+
# Depth-first walk over the parents of `node`, adding each newly seen
# ancestor to `visited` before descending into it. Nil parents are skipped.
def traverse_ancestors(parent_list, node, visited)
  parent_list[node].each do |parent|
    next if parent.nil?
    # Set#add? returns nil when the element was already present.
    next unless visited.add?(parent)

    traverse_ancestors(parent_list, parent, visited)
  end
end
|
407
|
+
|
408
|
+
sig { params(candidates: T::Array[Integer]).returns([T.nilable(Integer), T.nilable(Integer)]) }
|
409
|
+
# Draws two different values from `candidates` using the proposer's RNG
# and returns them in ascending order.
#
# Returns [nil, nil] when `candidates` holds fewer than two distinct
# values. Checking only the length is not enough: an input such as [3, 3]
# would keep the redraw loop below spinning forever.
def sample_distinct_pair(candidates)
  return [nil, nil] if candidates.uniq.length < 2

  first = candidates[@rng.rand(candidates.length)]
  second = candidates[@rng.rand(candidates.length)]
  # Redraw until the second pick differs; this terminates because at least
  # two distinct values exist.
  second = candidates[@rng.rand(candidates.length)] while second == first

  first < second ? [first, second] : [second, first]
end
|
422
|
+
end
|
423
|
+
end
|
424
|
+
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
module GEPA
|
6
|
+
module Proposer
|
7
|
+
# Type aliases and the prompt-signature helper shared by the reflective
# mutation proposer.
module ReflectiveMutation
  extend T::Sig

  # Proc that receives the optimizer state and returns an Integer.
  CandidateSelector = T.type_alias { T.proc.params(state: GEPA::Core::State).returns(Integer) }

  # Proc that receives the state, the captured trajectories, the subsample
  # scores and the current candidate (index and program) and returns an
  # Array of Strings.
  ComponentSelector = T.type_alias do
    T.proc.params(
      state: GEPA::Core::State,
      trajectories: T::Array[T.untyped],
      subsample_scores: T::Array[Float],
      candidate_idx: Integer,
      candidate: T::Hash[String, String]
    ).returns(T::Array[String])
  end

  # Proc that maps (trainset_size, iteration) to an Array of Integers.
  BatchSampler = T.type_alias do
    T.proc.params(trainset_size: Integer, iteration: Integer).returns(T::Array[Integer])
  end

  # Proc that turns a prompt String into the model's reply String.
  LanguageModel = T.type_alias { T.proc.params(prompt: String).returns(String) }

  # Bundles a prompt template with the procs that render it and parse the
  # language model's reply.
  class Signature < T::Struct
    extend T::Sig

    const :prompt_template, String
    const :input_keys, T::Array[String]
    const :output_keys, T::Array[String]
    const :prompt_renderer, T.proc.params(arg0: T::Hash[String, T.untyped]).returns(String)
    const :output_extractor, T.proc.params(arg0: String).returns(T::Hash[String, String])

    # Renders the prompt for `input_dict`, sends it to `lm`, strips the
    # reply and returns the fields extracted from it.
    #
    # This is an instance method because `prompt_renderer` and
    # `output_extractor` are per-instance props; as a class method the body
    # could not resolve them and raised NameError on every call.
    sig do
      params(lm: LanguageModel, input_dict: T::Hash[String, T.untyped]).returns(T::Hash[String, String])
    end
    def run(lm, input_dict)
      full_prompt = prompt_renderer.call(input_dict)
      output = lm.call(full_prompt).strip
      output_extractor.call(output)
    end
  end
end
|
47
|
+
end
|
48
|
+
end
|
@@ -0,0 +1,188 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
require_relative '../base'
|
6
|
+
require_relative 'base'
|
7
|
+
|
8
|
+
module GEPA
|
9
|
+
module Proposer
|
10
|
+
class ReflectiveMutationProposer
|
11
|
+
extend T::Sig
|
12
|
+
include ProposeNewCandidate
|
13
|
+
|
14
|
+
sig do
|
15
|
+
params(
|
16
|
+
logger: T.untyped,
|
17
|
+
trainset: T::Array[T.untyped],
|
18
|
+
adapter: T.untyped,
|
19
|
+
candidate_selector: T.untyped,
|
20
|
+
module_selector: T.untyped,
|
21
|
+
batch_sampler: T.untyped,
|
22
|
+
perfect_score: Float,
|
23
|
+
skip_perfect_score: T::Boolean,
|
24
|
+
experiment_tracker: T.untyped,
|
25
|
+
reflection_lm: T.nilable(T.proc.params(prompt: String).returns(String)),
|
26
|
+
telemetry: T.nilable(T.untyped)
|
27
|
+
).void
|
28
|
+
end
|
29
|
+
# Builds a reflective-mutation proposer.
#
# logger, experiment_tracker - sinks for log lines and metrics
# trainset                   - examples minibatches are drawn from
# adapter                    - evaluates programs and builds the reflective
#                              dataset; may also implement propose_new_texts
# candidate_selector, module_selector, batch_sampler
#                            - strategy objects used during a proposal
# perfect_score, skip_perfect_score
#                            - control skipping of all-perfect minibatches
# reflection_lm              - callable used when the adapter cannot propose
#                              new texts itself
# telemetry                  - telemetry sink; GEPA::Telemetry when omitted
def initialize(
  logger:,
  trainset:,
  adapter:,
  candidate_selector:,
  module_selector:,
  batch_sampler:,
  perfect_score:,
  skip_perfect_score:,
  experiment_tracker:,
  reflection_lm: nil,
  telemetry: nil
)
  # Reporting.
  @logger, @experiment_tracker = logger, experiment_tracker
  # Data and evaluation.
  @trainset, @adapter = trainset, adapter
  # Pluggable strategies.
  @candidate_selector, @module_selector, @batch_sampler = candidate_selector, module_selector, batch_sampler
  # Skip policy.
  @perfect_score, @skip_perfect_score = perfect_score, skip_perfect_score
  @reflection_lm = reflection_lm
  @telemetry = telemetry ? telemetry : GEPA::Telemetry
end
|
54
|
+
|
55
|
+
sig { override.params(state: GEPA::Core::State).returns(T.nilable(CandidateProposal)) }
|
56
|
+
# Produces a proposal for the iteration after state.i, wrapping the work
# in a telemetry span.
def propose(state)
  iteration = state.i + 1
  with_span('gepa.proposer.reflective_mutation.propose', iteration: iteration) { proposal_for_iteration(state, iteration) }
end
|
63
|
+
|
64
|
+
private
|
65
|
+
|
66
|
+
# Carries out one reflective-mutation step:
#   1. select a parent program and a training minibatch;
#   2. evaluate the parent with trace capture;
#   3. ask for new instruction texts for the selected components;
#   4. evaluate the mutated program on the same minibatch.
#
# Returns a CandidateProposal, or nil when no trajectories were captured,
# when every minibatch score is already perfect (and skipping is enabled),
# or when proposing new texts raises a StandardError.
def proposal_for_iteration(state, iteration)
  parent_id = @candidate_selector.select_candidate_idx(state)
  parent_program = state.program_candidates[parent_id]
  ensure_trace_slot(state)
  state.full_program_trace.last[:selected_program_candidate] = parent_id

  parent_score = state.per_program_tracked_scores[parent_id]
  @logger.log("Iteration #{iteration}: Selected program #{parent_id} score: #{parent_score}")
  @experiment_tracker.log_metrics({ iteration: iteration, selected_program_candidate: parent_id }, step: iteration)

  # The batch sampler is given the zero-based iteration number.
  subsample_ids = @batch_sampler.next_minibatch_indices(@trainset.length, iteration - 1)
  state.full_program_trace.last[:subsample_ids] = subsample_ids
  minibatch = subsample_ids.map { |idx| @trainset[idx] }

  eval_curr = with_span('gepa.proposer.evaluate_current', iteration: iteration) do
    @adapter.evaluate(minibatch, parent_program, capture_traces: true)
  end

  trajectories = eval_curr.trajectories
  if !trajectories || trajectories.empty?
    @logger.log("Iteration #{iteration}: No trajectories captured. Skipping.")
    return nil
  end

  state.total_num_evals += subsample_ids.length
  state.full_program_trace.last[:subsample_scores] = eval_curr.scores

  if @skip_perfect_score && eval_curr.scores.all? { |score| score >= @perfect_score }
    @logger.log("Iteration #{iteration}: All subsample scores perfect. Skipping.")
    return nil
  end

  @experiment_tracker.log_metrics({ subsample_score: eval_curr.scores.sum }, step: iteration)

  predictor_names = @module_selector.select_modules(
    state,
    eval_curr.trajectories,
    eval_curr.scores,
    parent_id,
    parent_program
  )

  # Built outside the rescue below, so a failure here propagates.
  reflective_dataset = with_span('gepa.proposer.build_reflective_dataset', iteration: iteration) do
    @adapter.make_reflective_dataset(parent_program, eval_curr, predictor_names)
  end

  begin
    new_texts = with_span('gepa.proposer.propose_texts', iteration: iteration) do
      propose_new_texts(parent_program, reflective_dataset, predictor_names)
    end

    new_texts.each do |name, text|
      @logger.log("Iteration #{iteration}: Proposed new text for #{name}: #{text}")
    end
    @experiment_tracker.log_metrics(new_texts.transform_keys { |name| "new_instruction_#{name}" }, step: iteration)
  rescue StandardError => e
    # A failed proposal skips this iteration instead of aborting the run.
    @logger.log("Iteration #{iteration}: Exception during reflection/proposal: #{e}")
    @logger.log(e.backtrace&.join("\n"))
    return nil
  end

  new_candidate = parent_program.dup
  new_texts.each do |name, text|
    raise ArgumentError, "Missing component #{name}" unless new_candidate.key?(name)

    new_candidate[name] = text
  end

  eval_new = with_span('gepa.proposer.evaluate_new_candidate', iteration: iteration) do
    @adapter.evaluate(minibatch, new_candidate, capture_traces: false)
  end

  state.total_num_evals += subsample_ids.length
  state.full_program_trace.last[:new_subsample_scores] = eval_new.scores
  @experiment_tracker.log_metrics({ new_subsample_score: eval_new.scores.sum }, step: iteration)

  CandidateProposal.new(
    candidate: new_candidate,
    parent_program_ids: [parent_id],
    subsample_indices: subsample_ids,
    subsample_scores_before: eval_curr.scores,
    subsample_scores_after: eval_new.scores,
    metadata: { iteration: iteration }
  )
end
|
151
|
+
|
152
|
+
sig do
|
153
|
+
params(
|
154
|
+
candidate: T::Hash[String, String],
|
155
|
+
reflective_dataset: T::Hash[String, T::Array[T::Hash[String, T.untyped]]],
|
156
|
+
components_to_update: T::Array[String]
|
157
|
+
).returns(T::Hash[String, String])
|
158
|
+
end
|
159
|
+
# Produces replacement instruction texts for the given components.
#
# Delegates to the adapter when it implements propose_new_texts; otherwise
# runs the instruction-proposal signature once per component with the
# reflection LM, feeding it the current instruction and that component's
# reflective dataset.
#
# Raises ArgumentError when neither the adapter hook nor a reflection LM is
# available; Hash#fetch raises KeyError for a component missing from
# `reflective_dataset`.
def propose_new_texts(candidate, reflective_dataset, components_to_update)
  if @adapter.respond_to?(:propose_new_texts)
    @adapter.propose_new_texts(candidate, reflective_dataset, components_to_update)
  elsif @reflection_lm
    components_to_update.to_h do |name|
      signature_input = {
        'current_instruction_doc' => candidate[name],
        'dataset_with_feedback' => reflective_dataset.fetch(name)
      }
      result = GEPA::Strategies::InstructionProposalSignature.run(@reflection_lm, signature_input)
      [name, result['new_instruction']]
    end
  else
    raise ArgumentError, 'reflection_lm is required when adapter lacks propose_new_texts'
  end
end
|
174
|
+
|
175
|
+
sig { params(state: GEPA::Core::State).void }
|
176
|
+
# Makes sure the trace ends with a non-nil entry that later writes can use,
# appending an empty Hash when the trace is empty or ends with nil.
def ensure_trace_slot(state)
  trace = state.full_program_trace
  # Array#last is nil for an empty array, so one check covers both cases.
  trace.push({}) if trace.last.nil?
end
|
179
|
+
|
180
|
+
sig do
|
181
|
+
params(operation: String, attrs: T::Hash[Symbol, T.untyped], block: T.proc.returns(T.untyped)).returns(T.untyped)
|
182
|
+
end
|
183
|
+
# Forwards the operation name, attributes and block to the configured
# telemetry object's with_span and returns whatever it returns.
def with_span(operation, attrs = {}, &block)
  sink = @telemetry
  sink.with_span(operation, attrs, &block)
end
|
186
|
+
end
|
187
|
+
end
|
188
|
+
end
|