desiru 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.claude/settings.local.json +11 -0
- data/CHANGELOG.md +73 -0
- data/CLAUDE.local.md +3 -0
- data/CLAUDE.md +6 -1
- data/Gemfile.lock +1 -1
- data/README.md +7 -1
- data/desiru-development-swarm.yml +185 -0
- data/lib/desiru/core/compiler.rb +231 -0
- data/lib/desiru/core/example.rb +96 -0
- data/lib/desiru/core/prediction.rb +108 -0
- data/lib/desiru/core/trace.rb +330 -0
- data/lib/desiru/core/traceable.rb +61 -0
- data/lib/desiru/core.rb +12 -0
- data/lib/desiru/module.rb +8 -0
- data/lib/desiru/modules/best_of_n.rb +306 -0
- data/lib/desiru/modules/multi_chain_comparison.rb +72 -20
- data/lib/desiru/modules/predict.rb +7 -0
- data/lib/desiru/modules/program_of_thought.rb +227 -28
- data/lib/desiru/optimizers/base.rb +31 -1
- data/lib/desiru/optimizers/mipro_v2.rb +889 -0
- data/lib/desiru/persistence/repositories/base_repository.rb +1 -1
- data/lib/desiru/version.rb +1 -1
- data/lib/desiru.rb +10 -0
- metadata +13 -1
@@ -0,0 +1,306 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Desiru
  module Modules
    # BestOfN samples N outputs from a base predictor and returns the one judged best
    # by a configurable criterion:
    #
    # * :consistency - majority vote on the main output field (after normalization)
    # * :confidence  - the model rates each sample 0-100 and the highest score wins
    # * :llm_judge   - the model is shown every sample and picks one
    # * :custom      - a caller-supplied callable picks from the sample array
    #
    # BestOfN-specific keyword options (removed before the remaining kwargs reach the parent):
    #   n_samples:           number of samples to draw (default 5)
    #   selection_criterion: one of SELECTION_CRITERIA (default :consistency)
    #   temperature:         temperature set on the model while sampling, when the model
    #                        has a `temperature=` writer (default 0.7)
    #   custom_selector:     callable receiving the sample array (required for :custom)
    #   base_module:         a module class (instantiated with the signature and model) or an
    #                        object responding to #forward or #call (default Modules::Predict)
    #   include_metadata:    attach :selection_metadata to the result (default false)
    class BestOfN < Desiru::Module
      SELECTION_CRITERIA = %i[confidence consistency llm_judge custom].freeze

      DEFAULT_SIGNATURE = 'question: string -> answer: string'

      # Used when normalizing text answers for voting. Group 1 captures punctuation that
      # changes a value's meaning and is kept: '.' or ',' between two digits (decimal and
      # thousands separators) and '-' directly before a digit. Any other punctuation matches
      # the second alternative and is dropped.
      PUNCTUATION_PATTERN = /((?<=\d)[.,](?=\d)|-(?=\d))|[[:punct:]]/
      private_constant :PUNCTUATION_PATTERN

      # @param signature [String, nil] signature definition; DEFAULT_SIGNATURE when nil
      # @param model [Object, nil] model used for sampling and for confidence/judge calls
      # @raise [ArgumentError] when selection_criterion is not one of SELECTION_CRITERIA
      def initialize(signature = nil, model: nil, **kwargs)
        # Extract our specific options so they are not forwarded to the parent.
        @n_samples = kwargs.delete(:n_samples) || 5
        @selection_criterion = validate_criterion(kwargs.delete(:selection_criterion) || :consistency)
        @temperature = kwargs.delete(:temperature) || 0.7
        @custom_selector = kwargs.delete(:custom_selector) # Proc that takes array of results
        @base_module = kwargs.delete(:base_module) || Modules::Predict
        @include_metadata = kwargs.delete(:include_metadata) || false

        signature ||= DEFAULT_SIGNATURE

        # Zero-argument super forwards the current values of signature, model and the
        # trimmed kwargs hash (config, demos, metadata).
        super
      end

      # Draws the configured number of samples, selects one and returns it.
      #
      # @param inputs [Hash] keyword inputs forwarded to the base module
      # @return [Hash] the selected sample, plus :selection_metadata when requested
      # @raise [ArgumentError] re-raised unchanged (e.g. :custom without a selector)
      def forward(**inputs)
        samples = generate_samples(inputs)
        # Nothing to choose from (e.g. n_samples <= 0): behave like a single prediction.
        return fallback_sample(inputs) if samples.empty?

        best_sample = select_best(samples, inputs) || samples.first

        if @include_metadata || output_field?(:selection_metadata)
          best_sample[:selection_metadata] = build_metadata(samples, best_sample)
        end

        # Internal bookkeeping only; build_metadata has already read it.
        best_sample.delete(:_confidence_score)
        best_sample
      rescue ArgumentError
        # Usage/configuration errors surface instead of being masked by the fallback.
        raise
      rescue StandardError => e
        Desiru.logger.error("BestOfN error: #{e.message}")
        fallback_sample(inputs)
      end

      private

      # @return [Symbol] the criterion, when it is one of SELECTION_CRITERIA
      # @raise [ArgumentError] otherwise
      def validate_criterion(criterion)
        unless SELECTION_CRITERIA.include?(criterion)
          raise ArgumentError, "Invalid selection criterion: #{criterion}. " \
                               "Must be one of: #{SELECTION_CRITERIA.join(', ')}"
        end
        criterion
      end

      # Runs the base module @n_samples times under the sampling temperature. Each call
      # gets an extra :_sample_index input as a variation seed; the key is removed from
      # the returned sample.
      def generate_samples(inputs)
        generator = build_generator

        with_sampling_temperature do
          @n_samples.times.map do |index|
            sample = run_generator(generator, inputs.merge(_sample_index: index))
            sample.delete(:_sample_index)
            sample
          end
        end
      end

      # Sets the model's temperature to @temperature for the duration of the block and
      # always restores the previous value afterwards, including when it was nil.
      # Without the unconditional restore, the model would keep the sampling temperature
      # for later calls. Models without a `temperature=` writer are used unchanged.
      def with_sampling_temperature
        return yield unless model.respond_to?(:temperature=)

        previous = current_temperature
        model.temperature = @temperature
        begin
          yield
        ensure
          model.temperature = previous
        end
      end

      # The model's current temperature: via its reader when it has one, otherwise the
      # @temperature instance variable.
      def current_temperature
        return model.temperature if model.respond_to?(:temperature)

        model.instance_variable_get(:@temperature)
      end

      # Instantiates @base_module when it is a class; otherwise uses it as given.
      def build_generator
        @base_module.is_a?(Class) ? @base_module.new(signature, model: model) : @base_module
      end

      # Invokes a generator through #forward when available, falling back to #call.
      def run_generator(generator, inputs)
        generator.respond_to?(:forward) ? generator.forward(**inputs) : generator.call(**inputs)
      end

      # Dispatches to the strategy selected by @selection_criterion.
      def select_best(samples, inputs)
        case @selection_criterion
        when :confidence then select_by_confidence(samples)
        when :consistency then select_by_consistency(samples)
        when :llm_judge then select_by_llm_judge(samples, inputs)
        when :custom then select_by_custom(samples)
        else samples.first # Fallback
        end
      end

      # Scores every sample with calculate_confidence and returns a copy of the
      # highest-scoring one. The copy carries :_confidence_score so it can be used
      # in the metadata.
      def select_by_confidence(samples)
        samples
          .map { |sample| sample.merge(_confidence_score: calculate_confidence(sample)) }
          .max_by { |scored| scored[:_confidence_score] }
      end

      # Asks the model (temperature 0.1) to rate a sample from 0 to 100. Uses the first
      # integer in the reply, 50 when there is none, clamped to 0..100.
      def calculate_confidence(sample)
        prompt = "Rate the confidence (0-100) for this response:\n\n#{format_fields(sample)}" \
                 "\nProvide only a number between 0 and 100:"

        response = model.complete(
          messages: [{ role: 'user', content: prompt }],
          temperature: 0.1
        )

        score = response[:content].to_s[/\d+/]&.to_i || 50
        score.clamp(0, 100)
      end

      # Majority vote on the main output field. Samples whose value is nil do not vote;
      # false is a legitimate answer and does. Returns the first sample when there is no
      # main field or no sample has a value for it.
      def select_by_consistency(samples)
        field = main_output_field
        return samples.first unless field

        voters = samples.reject { |sample| sample[field].nil? }
        return samples.first if voters.empty?

        groups = voters.group_by { |sample| normalize_output(sample[field]) }
        # On ties, max_by keeps the group whose answer appeared first.
        select_centroid(groups.values.max_by(&:length))
      end

      # Canonical form of an output value, used only for grouping and comparison.
      # Arrays and hashes are ordered by #inspect of their normalized contents. This
      # ordering never raises, even for mixed types or arrays of hashes.
      def normalize_output(value)
        case value
        when String
          normalize_text(value)
        when Numeric
          value.round(2)
        when Array
          value.map { |v| normalize_output(v) }.sort_by(&:inspect)
        when Hash
          value.transform_values { |v| normalize_output(v) }.sort_by { |key, _| key.inspect }.to_h
        else
          value.to_s
        end
      end

      # Lowercases text and drops punctuation, except punctuation that is part of a
      # number (see PUNCTUATION_PATTERN). Whitespace runs are collapsed to single spaces
      # and the ends are trimmed.
      def normalize_text(text)
        text.downcase.gsub(PUNCTUATION_PATTERN) { Regexp.last_match(1).to_s }.split.join(' ')
      end

      # Picks a representative from a group of agreeing samples.
      def select_centroid(group)
        return group.first if group.length == 1

        # For now, return the middle element (could be improved with similarity metrics)
        group[group.length / 2]
      end

      # Shows the inputs and every sample to the model (temperature 0.1) and returns
      # the sample named by the first "Option N" / "option #N" in its reply. Defaults
      # to the first sample; out-of-range choices are clamped.
      def select_by_llm_judge(samples, inputs)
        # No judging is needed for fewer than two candidates. This guard also keeps the
        # clamp below from receiving an empty range, which raises.
        return samples.first if samples.length < 2

        input_lines = inputs.map { |key, value| " #{key}: #{value}\n" }.join
        option_blocks = samples.map.with_index(1) do |sample, number|
          "\n--- Option #{number} ---\n#{format_fields(sample)}"
        end

        judge_prompt = "Given the following input and multiple response options, " \
                       "select the best response:\n\n" \
                       "Input:\n#{input_lines}" \
                       "\nResponse Options:\n#{option_blocks.join}" \
                       "\nSelect the best option (1-#{samples.length}) and briefly explain why:"

        response = model.complete(
          messages: [{ role: 'user', content: judge_prompt }],
          temperature: 0.1
        )

        choice = response[:content].to_s.match(/option\s*#?(\d+)/i)
        index = choice ? choice[1].to_i - 1 : 0
        samples[index.clamp(0, samples.length - 1)]
      end

      # Delegates to the user-supplied selector. Falls back to the first sample when the
      # selector returns nil or false.
      # @raise [ArgumentError] when no callable selector was configured
      def select_by_custom(samples)
        unless @custom_selector.respond_to?(:call)
          raise ArgumentError, "Custom selector must be provided when using :custom criterion"
        end

        @custom_selector.call(samples) || samples.first
      end

      # Summary of how the selection was made, plus criterion-specific details.
      def build_metadata(samples, selected)
        metadata = {
          total_samples: samples.length,
          selection_criterion: @selection_criterion,
          temperature: @temperature
        }

        case @selection_criterion
        when :consistency
          rate = agreement_rate(samples, selected)
          metadata[:agreement_rate] = rate if rate
        when :confidence
          score = selected[:_confidence_score]
          metadata[:selected_confidence] = score if score
        end

        metadata
      end

      # Fraction of samples whose normalized main output equals the selected sample's.
      # Returns nil when there is no main field or the selected sample has no value for
      # it. As in the vote, samples with a nil value never count as agreeing.
      def agreement_rate(samples, selected)
        field = main_output_field
        return if field.nil? || selected[field].nil?

        target = normalize_output(selected[field])
        agreeing = samples.count { |sample| !sample[field].nil? && normalize_output(sample[field]) == target }
        agreeing.to_f / samples.length
      end

      # First output field that is neither internal (leading underscore) nor the
      # selection_metadata slot, as a symbol to match sample keys; nil if none.
      def main_output_field
        field = signature.output_fields.keys.find do |key|
          name = key.to_s
          !name.start_with?('_') && name != 'selection_metadata'
        end
        field&.to_sym
      end

      # True when the signature declares the named output field; the key is looked up
      # in both symbol and string form.
      def output_field?(name)
        fields = signature.output_fields
        fields.key?(name.to_sym) || fields.key?(name.to_s)
      end

      # Renders a sample's public (non-underscore) fields as "key: value" lines.
      def format_fields(sample)
        sample.reject { |key, _| key.to_s.start_with?('_') }
              .map { |key, value| "#{key}: #{value}\n" }
              .join
      end

      # Single sample from the base module, without selection. Used on errors and when
      # no samples were drawn.
      def fallback_sample(inputs)
        run_generator(build_generator, inputs)
      end
    end
  end
end
|
302
|
+
|
303
|
+
# Alias the class at the top of the Desiru namespace so callers can write
# Desiru::BestOfN instead of Desiru::Modules::BestOfN.
module Desiru
  BestOfN = Desiru::Modules::BestOfN
end
|
@@ -5,14 +5,25 @@ module Desiru
|
|
5
5
|
# MultiChainComparison module that generates multiple chain-of-thought
|
6
6
|
# reasoning paths and compares them to produce the best answer
|
7
7
|
class MultiChainComparison < Desiru::Module
|
8
|
+
DEFAULT_SIGNATURE = 'question: string -> answer: string, reasoning: string'
|
9
|
+
|
8
10
|
def initialize(signature = nil, model: nil, **kwargs)
|
11
|
+
# Extract our specific options before passing to parent
|
12
|
+
@num_chains = kwargs.delete(:num_chains) || 3
|
13
|
+
@comparison_strategy = kwargs.delete(:comparison_strategy) || :vote
|
14
|
+
@temperature = kwargs.delete(:temperature) || 0.7
|
15
|
+
|
16
|
+
# Use default signature if none provided
|
17
|
+
signature ||= DEFAULT_SIGNATURE
|
18
|
+
|
19
|
+
# Pass remaining kwargs to parent (config, demos, metadata)
|
9
20
|
super
|
10
|
-
@num_chains = kwargs[:num_chains] || 3
|
11
|
-
@comparison_strategy = kwargs[:comparison_strategy] || :vote
|
12
|
-
@temperature = kwargs[:temperature] || 0.7
|
13
21
|
end
|
14
22
|
|
15
23
|
def forward(**inputs)
|
24
|
+
# Handle edge case of zero chains
|
25
|
+
return {} if @num_chains <= 0
|
26
|
+
|
16
27
|
# Generate multiple reasoning chains
|
17
28
|
chains = generate_chains(inputs)
|
18
29
|
|
@@ -25,11 +36,14 @@ module Desiru
|
|
25
36
|
when :confidence
|
26
37
|
select_by_confidence(chains)
|
27
38
|
else
|
28
|
-
chains.first # Fallback to first chain
|
39
|
+
chains.first || {} # Fallback to first chain or empty hash
|
29
40
|
end
|
30
41
|
|
42
|
+
# Ensure best_result is not nil
|
43
|
+
best_result ||= {}
|
44
|
+
|
31
45
|
# Include comparison metadata if requested
|
32
|
-
if signature.output_fields.key?(:comparison_data)
|
46
|
+
if signature.output_fields.key?('comparison_data') || signature.output_fields.key?(:comparison_data)
|
33
47
|
best_result[:comparison_data] = {
|
34
48
|
num_chains: chains.length,
|
35
49
|
strategy: @comparison_strategy,
|
@@ -77,7 +91,7 @@ module Desiru
|
|
77
91
|
if signature.output_fields.any?
|
78
92
|
prompt += "\nMake sure your answer includes:\n"
|
79
93
|
signature.output_fields.each do |name, field|
|
80
|
-
next if %
|
94
|
+
next if %w[reasoning comparison_data].include?(name.to_s)
|
81
95
|
|
82
96
|
prompt += "- #{name}: #{field.description || field.type}\n"
|
83
97
|
end
|
@@ -95,15 +109,33 @@ module Desiru
|
|
95
109
|
|
96
110
|
# Extract answer
|
97
111
|
answer_match = response.match(/ANSWER:\s*(.+)/mi)
|
98
|
-
answer_text = answer_match ? answer_match[1].strip : ""
|
99
112
|
|
100
|
-
|
101
|
-
|
102
|
-
|
113
|
+
if answer_match
|
114
|
+
answer_text = answer_match[1].strip
|
115
|
+
|
116
|
+
# Try to parse structured answer
|
117
|
+
if answer_text.include?(':') || answer_text.include?('{')
|
118
|
+
result.merge!(parse_structured_answer(answer_text))
|
119
|
+
elsif !answer_text.empty?
|
120
|
+
# Single value answer
|
121
|
+
main_output_field = signature.output_fields.keys.map(&:to_sym).find do |k|
|
122
|
+
!%i[reasoning comparison_data].include?(k)
|
123
|
+
end
|
124
|
+
result[main_output_field] = answer_text if main_output_field
|
125
|
+
end
|
103
126
|
else
|
104
|
-
#
|
105
|
-
|
106
|
-
|
127
|
+
# No ANSWER: section found - check if we should extract from reasoning
|
128
|
+
signature.output_fields.keys.map(&:to_sym).find do |k|
|
129
|
+
!%i[reasoning comparison_data].include?(k)
|
130
|
+
end
|
131
|
+
# Don't set the field if there's no clear answer
|
132
|
+
# result[main_output_field] = nil if main_output_field
|
133
|
+
end
|
134
|
+
|
135
|
+
# Parse any additional fields that might be in the response
|
136
|
+
response.scan(/(\w+):\s*([^\n]+)/).each do |key, value|
|
137
|
+
key_sym = key.downcase.to_sym
|
138
|
+
result[key_sym] = value.strip if signature.output_fields.key?(key_sym) && !result.key?(key_sym)
|
107
139
|
end
|
108
140
|
|
109
141
|
result
|
@@ -115,31 +147,42 @@ module Desiru
|
|
115
147
|
# Try to parse as key-value pairs
|
116
148
|
answer_text.scan(/(\w+):\s*([^\n,}]+)/).each do |key, value|
|
117
149
|
key_sym = key.downcase.to_sym
|
118
|
-
|
150
|
+
if signature.output_fields.key?(key_sym) || signature.output_fields.key?(key.downcase)
|
151
|
+
parsed[key_sym] =
|
152
|
+
value.strip
|
153
|
+
end
|
119
154
|
end
|
120
155
|
|
121
156
|
parsed
|
122
157
|
end
|
123
158
|
|
124
159
|
def vote_on_chains(chains)
|
160
|
+
return {} if chains.empty?
|
161
|
+
|
125
162
|
# Count votes for each unique answer
|
126
163
|
votes = Hash.new(0)
|
127
164
|
answer_to_chain = {}
|
128
165
|
|
129
166
|
chains.each do |chain|
|
130
167
|
# Get the main answer field (first non-metadata field)
|
131
|
-
answer_key = signature.output_fields.keys.find
|
168
|
+
answer_key = signature.output_fields.keys.map(&:to_sym).find do |k|
|
169
|
+
!%i[reasoning comparison_data].include?(k)
|
170
|
+
end
|
132
171
|
answer_value = chain[answer_key]
|
133
172
|
|
134
|
-
if answer_value
|
173
|
+
if answer_value && !answer_value.to_s.empty?
|
135
174
|
votes[answer_value] += 1
|
136
175
|
answer_to_chain[answer_value] ||= chain
|
137
176
|
end
|
138
177
|
end
|
139
178
|
|
140
179
|
# Return the chain with the most common answer
|
141
|
-
|
142
|
-
|
180
|
+
if votes.empty?
|
181
|
+
chains.first || {}
|
182
|
+
else
|
183
|
+
winning_answer = votes.max_by { |_, count| count }.first
|
184
|
+
answer_to_chain[winning_answer] || chains.first || {}
|
185
|
+
end
|
143
186
|
end
|
144
187
|
|
145
188
|
def llm_judge_chains(chains, original_inputs)
|
@@ -157,7 +200,9 @@ module Desiru
|
|
157
200
|
judge_prompt += "\n--- Attempt #{i + 1} ---\n"
|
158
201
|
judge_prompt += "Reasoning: #{chain[:reasoning]}\n"
|
159
202
|
|
160
|
-
answer_key = signature.output_fields.keys.find
|
203
|
+
answer_key = signature.output_fields.keys.map(&:to_sym).find do |k|
|
204
|
+
!%i[reasoning comparison_data].include?(k)
|
205
|
+
end
|
161
206
|
judge_prompt += "Answer: #{chain[answer_key]}\n" if chain[answer_key]
|
162
207
|
end
|
163
208
|
|
@@ -182,7 +227,9 @@ module Desiru
|
|
182
227
|
confidence_prompt = "Rate your confidence (0-100) in this reasoning and answer:\n"
|
183
228
|
confidence_prompt += "Reasoning: #{chain[:reasoning]}\n"
|
184
229
|
|
185
|
-
answer_key = signature.output_fields.keys.find
|
230
|
+
answer_key = signature.output_fields.keys.map(&:to_sym).find do |k|
|
231
|
+
!%i[reasoning comparison_data].include?(k)
|
232
|
+
end
|
186
233
|
confidence_prompt += "Answer: #{chain[answer_key]}\n" if chain[answer_key]
|
187
234
|
|
188
235
|
confidence_prompt += "\nRespond with just a number between 0 and 100:"
|
@@ -202,3 +249,8 @@ module Desiru
|
|
202
249
|
end
|
203
250
|
end
|
204
251
|
end
|
252
|
+
|
253
|
+
# Register in the main module namespace for convenience
|
254
|
+
module Desiru
|
255
|
+
MultiChainComparison = Modules::MultiChainComparison
|
256
|
+
end
|
@@ -4,6 +4,13 @@ module Desiru
|
|
4
4
|
module Modules
|
5
5
|
# Basic prediction module - the fundamental building block
|
6
6
|
class Predict < Module
|
7
|
+
DEFAULT_SIGNATURE = 'question: string -> answer: string'
|
8
|
+
|
9
|
+
# Builds a Predict module. A missing (nil or false) signature is replaced with
# DEFAULT_SIGNATURE before zero-argument super forwards signature, model and any
# remaining keywords to the parent class.
def initialize(signature = nil, model: nil, **)
  signature = DEFAULT_SIGNATURE unless signature
  super
end
|
13
|
+
|
7
14
|
def forward(inputs)
|
8
15
|
prompt = build_prompt(inputs)
|
9
16
|
|