desiru 0.1.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,306 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Desiru
4
+ module Modules
5
+ # BestOfN module that samples N outputs from a predictor and selects the best one
6
+ # based on configurable criteria (confidence, consistency, or external validation)
7
+ class BestOfN < Desiru::Module
8
+ SELECTION_CRITERIA = %i[confidence consistency llm_judge custom].freeze
9
+
10
+ DEFAULT_SIGNATURE = 'question: string -> answer: string'
11
+
12
# Builds a BestOfN module.
#
# BestOfN-specific options are removed from +kwargs+ before the parent
# initializer runs; whatever remains is forwarded to the parent untouched.
#
# @param signature [Object, nil] signature handed to the parent class;
#   DEFAULT_SIGNATURE is used when nil
# @param model [Object, nil] model handed to the parent class
# @param kwargs [Hash] recognised keys: :n_samples (default 5),
#   :selection_criterion (default :consistency), :temperature (default 0.7),
#   :custom_selector (callable receiving the array of samples),
#   :base_module (default Modules::Predict), :include_metadata (default false)
# @raise [ArgumentError] when :selection_criterion is not one of SELECTION_CRITERIA
def initialize(signature = nil, model: nil, **kwargs)
  signature ||= DEFAULT_SIGNATURE

  @n_samples = kwargs.delete(:n_samples) || 5
  requested_criterion = kwargs.delete(:selection_criterion) || :consistency
  @selection_criterion = validate_criterion(requested_criterion)
  @temperature = kwargs.delete(:temperature) || 0.7
  @custom_selector = kwargs.delete(:custom_selector)
  @base_module = kwargs.delete(:base_module) || Modules::Predict
  @include_metadata = kwargs.delete(:include_metadata) || false

  # Bare `super` re-sends the method's parameters with their current values,
  # so the parent sees the defaulted signature and the kwargs hash with the
  # BestOfN-only keys already deleted.
  super
end
27
+
28
# Samples several candidate outputs and returns the one picked by the
# configured selection criterion.
#
# @param inputs [Hash] keyword inputs passed on to the sample generator
# @return [Object] the selected sample with the internal :_confidence_score
#   key removed, plus :selection_metadata when metadata was requested via the
#   constructor option or declared as an output field on the signature
# @raise [ArgumentError] any ArgumentError raised while sampling or selecting
#   (for example a missing custom selector) is propagated unchanged
def forward(**inputs)
  candidates = generate_samples(inputs)
  winner = select_best(candidates, inputs)

  wants_metadata = @include_metadata || signature.output_fields.key?(:selection_metadata)
  winner[:selection_metadata] = build_metadata(candidates, winner) if wants_metadata

  # The confidence score is bookkeeping for selection/metadata only.
  winner.delete(:_confidence_score)

  winner
rescue ArgumentError
  # Configuration problems must reach the caller instead of being masked by
  # the single-sample fallback below.
  raise
rescue StandardError => e
  Desiru.logger.error("BestOfN error: #{e.message}")
  # Any other failure degrades to a single plain sample.
  fallback_sample(inputs)
end
52
+
53
+ private
54
+
55
# Checks that a selection criterion is supported.
#
# @param criterion [Object] candidate criterion
# @return [Object] the criterion itself when it is listed in SELECTION_CRITERIA
# @raise [ArgumentError] when the criterion is not listed
def validate_criterion(criterion)
  return criterion if SELECTION_CRITERIA.include?(criterion)

  allowed = SELECTION_CRITERIA.join(', ')
  raise ArgumentError, "Invalid selection criterion: #{criterion}. Must be one of: #{allowed}"
end
62
+
63
# Produces @n_samples candidate outputs from the base module.
#
# When @base_module is a Class it is instantiated with this module's
# signature and model; otherwise it is used as-is. Each call receives the
# caller's inputs plus a :_sample_index entry, and that key is deleted from
# the returned sample again.
#
# @param inputs [Hash] inputs for the generator
# @return [Array] one sample per requested draw, in generation order
def generate_samples(inputs)
  generator = @base_module.is_a?(Class) ? @base_module.new(signature, model: model) : @base_module

  with_sampling_temperature do
    @n_samples.times.map do |index|
      sample_inputs = inputs.merge(_sample_index: index)

      sample = if generator.respond_to?(:forward)
                 generator.forward(**sample_inputs)
               else
                 generator.call(**sample_inputs)
               end

      sample.delete(:_sample_index)
      sample
    end
  end
end

# Runs the block with the model's temperature set to @temperature and puts
# the previous value back afterwards, even when the block raises.
#
# The previous value is read through the public reader when there is one and
# from the @temperature instance variable otherwise. It is restored whenever
# it could be read, including when it was nil, so the sampling temperature
# never leaks onto a shared model. If the model exposes only a writer and no
# readable state, the previous value is unknown and nothing is restored.
def with_sampling_temperature
  target = model
  return yield unless target.respond_to?(:temperature=)

  restorable, previous =
    if target.respond_to?(:temperature)
      [true, target.temperature]
    elsif target.instance_variable_defined?(:@temperature)
      [true, target.instance_variable_get(:@temperature)]
    else
      [false, nil]
    end

  begin
    target.temperature = @temperature
    yield
  ensure
    target.temperature = previous if restorable
  end
end
102
+
103
# Dispatches to the selection strategy matching @selection_criterion.
#
# @param samples [Array] candidate outputs
# @param inputs [Hash] original inputs (only the LLM judge strategy uses them)
# @return [Object] the chosen sample; the first sample for an unrecognised
#   criterion
def select_best(samples, inputs)
  case @selection_criterion
  when :confidence  then select_by_confidence(samples)
  when :consistency then select_by_consistency(samples)
  when :llm_judge   then select_by_llm_judge(samples, inputs)
  when :custom      then select_by_custom(samples)
  else samples.first
  end
end
117
+
118
# Scores every sample with calculate_confidence and returns the best-scored one.
#
# @param samples [Array<Hash>] candidate outputs
# @return [Hash, nil] a copy of the winning sample carrying its score under
#   :_confidence_score (kept so metadata can report it); nil for no samples
def select_by_confidence(samples)
  scored = samples.map do |candidate|
    candidate.merge(_confidence_score: calculate_confidence(candidate))
  end

  scored.max_by { |candidate| candidate[:_confidence_score] }
end
128
+
129
# Asks the model to rate a sample and returns the rating as an integer.
#
# Keys starting with an underscore are internal and left out of the prompt.
# The first run of digits in the reply is taken as the score.
#
# @param sample [Hash] candidate output
# @return [Integer] score clamped to 0..100; 50 when the reply contains no
#   digits or carries no content at all
def calculate_confidence(sample)
  visible_fields = sample.reject { |key, _| key.to_s.start_with?('_') }
  field_lines = visible_fields.map { |key, value| "#{key}: #{value}\n" }.join

  prompt = "Rate the confidence (0-100) for this response:\n\n" \
           "#{field_lines}" \
           "\nProvide only a number between 0 and 100:"

  response = model.complete(
    messages: [{ role: 'user', content: prompt }],
    temperature: 0.1
  )

  # A reply without content is treated like a reply without a number, so one
  # empty rating does not abort the whole selection.
  digits = response[:content].to_s[/\d+/]
  digits ? digits.to_i.clamp(0, 100) : 50
end
150
+
151
# Picks a representative of the most common answer among the samples.
#
# The "main" output field is the first signature output field that is neither
# underscore-prefixed nor 'selection_metadata'. Samples are grouped by the
# normalized value of that field, and select_centroid chooses one sample from
# the largest group.
#
# @param samples [Array<Hash>] candidate outputs
# @return [Hash, nil] the selected sample. Falls back to the first sample
#   when the signature has no main field or when no sample carries a value
#   for it.
def select_by_consistency(samples)
  main_field = signature.output_fields.keys.find do |name|
    !name.to_s.start_with?('_') && name.to_s != 'selection_metadata'
  end
  return samples.first unless main_field

  # Sample keys are symbols even when the signature stores field names as strings.
  field_sym = main_field.to_sym

  # Only nil counts as "no answer"; a boolean false is a legitimate output.
  answered = samples.reject { |sample| sample[field_sym].nil? }
  return samples.first if answered.empty?

  groups = answered.group_by { |sample| normalize_output(sample[field_sym]) }
  largest_group = groups.values.max_by(&:length)

  select_centroid(largest_group)
end
179
+
180
# Reduces an output value to a canonical form so equivalent answers compare
# equal when samples are grouped.
#
# - String: lower-cased, punctuation removed, whitespace runs collapsed and
#   the ends trimmed (trimming happens after punctuation removal so that
#   "Paris ." and "Paris." agree)
# - Numeric: rounded to two decimals
# - Array: elements normalized, then put in a canonical order. The order is
#   by each element's inspect string, not natural order, so arrays mixing
#   strings, numbers or hashes never raise a comparison error
# - Hash: values normalized, keys ordered by their string form
# - anything else: its to_s
#
# @param value [Object] raw output value
# @return [Object] canonical representation
def normalize_output(value)
  case value
  when String
    value.downcase.gsub(/[[:punct:]]/, '').gsub(/\s+/, ' ').strip
  when Numeric
    value.round(2)
  when Array
    value.map { |element| normalize_output(element) }.sort_by(&:inspect)
  when Hash
    value.sort_by { |key, _| key.to_s }.to_h { |key, element| [key, normalize_output(element)] }
  else
    value.to_s
  end
end
194
+
195
# Chooses a representative sample from a group of agreeing samples.
#
# Currently just the element at the middle index; no similarity measure is
# applied.
#
# @param group [Array] samples sharing the same normalized answer
# @return [Object, nil] the sole element of a one-element group, otherwise the
#   element at index length / 2 (nil for an empty group)
def select_centroid(group)
  size = group.length
  size == 1 ? group.first : group[size / 2]
end
201
+
202
# Lets the model pick the best sample from a numbered list of options.
#
# The verdict is parsed for "option N" first. If that is absent, the first
# standalone integer in the reply is used, so a bare "2" is honoured. This is
# a heuristic: a reply that mentions several numbers is read by its first one.
# The resulting index is clamped into range, and option 1 is the default when
# nothing can be parsed.
#
# @param samples [Array<Hash>] candidate outputs
# @param inputs [Hash] original inputs, shown to the judge for context
# @return [Hash, nil] the chosen sample. With fewer than two samples the
#   model is not consulted: the only sample is returned, or nil for none.
def select_by_llm_judge(samples, inputs)
  # Nothing to compare; this also avoids clamping against an empty range.
  return samples.first if samples.length < 2

  input_lines = inputs.map { |key, value| " #{key}: #{value}\n" }.join

  option_blocks = samples.each_with_index.map do |sample, index|
    visible_fields = sample.reject { |key, _| key.to_s.start_with?('_') }
    field_lines = visible_fields.map { |key, value| "#{key}: #{value}\n" }.join
    "\n--- Option #{index + 1} ---\n#{field_lines}"
  end.join

  judge_prompt = "Given the following input and multiple response options, " \
                 "select the best response:\n\n" \
                 "Input:\n#{input_lines}" \
                 "\nResponse Options:\n#{option_blocks}" \
                 "\nSelect the best option (1-#{samples.length}) and briefly explain why:"

  response = model.complete(
    messages: [{ role: 'user', content: judge_prompt }],
    temperature: 0.1
  )

  verdict = response[:content].to_s
  choice = verdict[/option\s*#?(\d+)/i, 1] || verdict[/\b(\d+)\b/, 1]
  selected_index = choice ? choice.to_i - 1 : 0

  samples[selected_index.clamp(0, samples.length - 1)]
end
242
+
243
# Delegates the choice to the user-supplied selector.
#
# @param samples [Array] candidate outputs handed to the selector
# @return [Object] whatever the selector returns, or the first sample when it
#   returns nil or false
# @raise [ArgumentError] when no callable selector was configured
def select_by_custom(samples)
  selector = @custom_selector
  return selector.call(samples) || samples.first if selector.respond_to?(:call)

  raise ArgumentError, "Custom selector must be provided when using :custom criterion"
end
250
+
251
# Describes how the selected sample was chosen.
#
# @param samples [Array<Hash>] all generated samples
# @param selected [Hash] the sample that was chosen
# @return [Hash] always :total_samples, :selection_criterion and :temperature;
#   plus :agreement_rate for the :consistency criterion (when it can be
#   computed) or :selected_confidence for the :confidence criterion (when the
#   selected sample carries a score)
def build_metadata(samples, selected)
  metadata = {
    total_samples: samples.length,
    selection_criterion: @selection_criterion,
    temperature: @temperature
  }

  case @selection_criterion
  when :consistency
    rate = consistency_agreement_rate(samples, selected)
    metadata[:agreement_rate] = rate unless rate.nil?
  when :confidence
    score = selected[:_confidence_score]
    metadata[:selected_confidence] = score if score
  end

  metadata
end

# Fraction of all samples whose main output field normalizes to the same value
# as the selected sample's.
#
# The main field is the first signature output field that is neither
# underscore-prefixed nor 'selection_metadata'. Only nil counts as a missing
# value (false is a real answer), and samples without a value never count as
# agreeing.
#
# @return [Float, nil] nil when there is no main field or the selected sample
#   has no value for it
def consistency_agreement_rate(samples, selected)
  main_field = signature.output_fields.keys.find do |name|
    !name.to_s.start_with?('_') && name.to_s != 'selection_metadata'
  end
  return nil unless main_field

  # Sample keys are symbols even when the signature stores field names as strings.
  field_sym = main_field.to_sym
  return nil if selected[field_sym].nil?

  selected_value = normalize_output(selected[field_sym])
  agreeing = samples.count do |sample|
    !sample[field_sym].nil? && normalize_output(sample[field_sym]) == selected_value
  end

  agreeing.to_f / samples.length
end
284
+
285
# Produces one plain sample; used when best-of-N selection itself failed.
#
# @param inputs [Hash] inputs for the generator
# @return [Object] whatever the base module returns from #forward, or from
#   #call when it does not respond to #forward
def fallback_sample(inputs)
  # A Class is instantiated with this module's signature and model; anything
  # else is treated as a ready-to-use generator.
  generator = @base_module.is_a?(Class) ? @base_module.new(signature, model: model) : @base_module

  return generator.forward(**inputs) if generator.respond_to?(:forward)

  generator.call(**inputs)
end
299
+ end
300
+ end
301
+ end
302
+
303
# Convenience alias: expose the class directly under the top-level namespace
# so callers can write Desiru::BestOfN instead of Desiru::Modules::BestOfN.
Desiru::BestOfN = Desiru::Modules::BestOfN
@@ -5,14 +5,25 @@ module Desiru
5
5
  # MultiChainComparison module that generates multiple chain-of-thought
6
6
  # reasoning paths and compares them to produce the best answer
7
7
  class MultiChainComparison < Desiru::Module
8
+ DEFAULT_SIGNATURE = 'question: string -> answer: string, reasoning: string'
9
+
8
10
  def initialize(signature = nil, model: nil, **kwargs)
11
+ # Extract our specific options before passing to parent
12
+ @num_chains = kwargs.delete(:num_chains) || 3
13
+ @comparison_strategy = kwargs.delete(:comparison_strategy) || :vote
14
+ @temperature = kwargs.delete(:temperature) || 0.7
15
+
16
+ # Use default signature if none provided
17
+ signature ||= DEFAULT_SIGNATURE
18
+
19
+ # Pass remaining kwargs to parent (config, demos, metadata)
9
20
  super
10
- @num_chains = kwargs[:num_chains] || 3
11
- @comparison_strategy = kwargs[:comparison_strategy] || :vote
12
- @temperature = kwargs[:temperature] || 0.7
13
21
  end
14
22
 
15
23
  def forward(**inputs)
24
+ # Handle edge case of zero chains
25
+ return {} if @num_chains <= 0
26
+
16
27
  # Generate multiple reasoning chains
17
28
  chains = generate_chains(inputs)
18
29
 
@@ -25,11 +36,14 @@ module Desiru
25
36
  when :confidence
26
37
  select_by_confidence(chains)
27
38
  else
28
- chains.first # Fallback to first chain
39
+ chains.first || {} # Fallback to first chain or empty hash
29
40
  end
30
41
 
42
+ # Ensure best_result is not nil
43
+ best_result ||= {}
44
+
31
45
  # Include comparison metadata if requested
32
- if signature.output_fields.key?(:comparison_data)
46
+ if signature.output_fields.key?('comparison_data') || signature.output_fields.key?(:comparison_data)
33
47
  best_result[:comparison_data] = {
34
48
  num_chains: chains.length,
35
49
  strategy: @comparison_strategy,
@@ -77,7 +91,7 @@ module Desiru
77
91
  if signature.output_fields.any?
78
92
  prompt += "\nMake sure your answer includes:\n"
79
93
  signature.output_fields.each do |name, field|
80
- next if %i[reasoning comparison_data].include?(name)
94
+ next if %w[reasoning comparison_data].include?(name.to_s)
81
95
 
82
96
  prompt += "- #{name}: #{field.description || field.type}\n"
83
97
  end
@@ -95,15 +109,33 @@ module Desiru
95
109
 
96
110
  # Extract answer
97
111
  answer_match = response.match(/ANSWER:\s*(.+)/mi)
98
- answer_text = answer_match ? answer_match[1].strip : ""
99
112
 
100
- # Try to parse structured answer
101
- if answer_text.include?(':') || answer_text.include?('{')
102
- result.merge!(parse_structured_answer(answer_text))
113
+ if answer_match
114
+ answer_text = answer_match[1].strip
115
+
116
+ # Try to parse structured answer
117
+ if answer_text.include?(':') || answer_text.include?('{')
118
+ result.merge!(parse_structured_answer(answer_text))
119
+ elsif !answer_text.empty?
120
+ # Single value answer
121
+ main_output_field = signature.output_fields.keys.map(&:to_sym).find do |k|
122
+ !%i[reasoning comparison_data].include?(k)
123
+ end
124
+ result[main_output_field] = answer_text if main_output_field
125
+ end
103
126
  else
104
- # Single value answer
105
- main_output_field = signature.output_fields.keys.find { |k| !%i[reasoning comparison_data].include?(k) }
106
- result[main_output_field] = answer_text if main_output_field
127
+ # No ANSWER: section found - check if we should extract from reasoning
128
+ signature.output_fields.keys.map(&:to_sym).find do |k|
129
+ !%i[reasoning comparison_data].include?(k)
130
+ end
131
+ # Don't set the field if there's no clear answer
132
+ # result[main_output_field] = nil if main_output_field
133
+ end
134
+
135
+ # Parse any additional fields that might be in the response
136
+ response.scan(/(\w+):\s*([^\n]+)/).each do |key, value|
137
+ key_sym = key.downcase.to_sym
138
+ result[key_sym] = value.strip if signature.output_fields.key?(key_sym) && !result.key?(key_sym)
107
139
  end
108
140
 
109
141
  result
@@ -115,31 +147,42 @@ module Desiru
115
147
  # Try to parse as key-value pairs
116
148
  answer_text.scan(/(\w+):\s*([^\n,}]+)/).each do |key, value|
117
149
  key_sym = key.downcase.to_sym
118
- parsed[key_sym] = value.strip if signature.output_fields.key?(key_sym)
150
+ if signature.output_fields.key?(key_sym) || signature.output_fields.key?(key.downcase)
151
+ parsed[key_sym] =
152
+ value.strip
153
+ end
119
154
  end
120
155
 
121
156
  parsed
122
157
  end
123
158
 
124
159
  def vote_on_chains(chains)
160
+ return {} if chains.empty?
161
+
125
162
  # Count votes for each unique answer
126
163
  votes = Hash.new(0)
127
164
  answer_to_chain = {}
128
165
 
129
166
  chains.each do |chain|
130
167
  # Get the main answer field (first non-metadata field)
131
- answer_key = signature.output_fields.keys.find { |k| !%i[reasoning comparison_data].include?(k) }
168
+ answer_key = signature.output_fields.keys.map(&:to_sym).find do |k|
169
+ !%i[reasoning comparison_data].include?(k)
170
+ end
132
171
  answer_value = chain[answer_key]
133
172
 
134
- if answer_value
173
+ if answer_value && !answer_value.to_s.empty?
135
174
  votes[answer_value] += 1
136
175
  answer_to_chain[answer_value] ||= chain
137
176
  end
138
177
  end
139
178
 
140
179
  # Return the chain with the most common answer
141
- winning_answer = votes.max_by { |_, count| count }&.first
142
- answer_to_chain[winning_answer] || chains.first
180
+ if votes.empty?
181
+ chains.first || {}
182
+ else
183
+ winning_answer = votes.max_by { |_, count| count }.first
184
+ answer_to_chain[winning_answer] || chains.first || {}
185
+ end
143
186
  end
144
187
 
145
188
  def llm_judge_chains(chains, original_inputs)
@@ -157,7 +200,9 @@ module Desiru
157
200
  judge_prompt += "\n--- Attempt #{i + 1} ---\n"
158
201
  judge_prompt += "Reasoning: #{chain[:reasoning]}\n"
159
202
 
160
- answer_key = signature.output_fields.keys.find { |k| !%i[reasoning comparison_data].include?(k) }
203
+ answer_key = signature.output_fields.keys.map(&:to_sym).find do |k|
204
+ !%i[reasoning comparison_data].include?(k)
205
+ end
161
206
  judge_prompt += "Answer: #{chain[answer_key]}\n" if chain[answer_key]
162
207
  end
163
208
 
@@ -182,7 +227,9 @@ module Desiru
182
227
  confidence_prompt = "Rate your confidence (0-100) in this reasoning and answer:\n"
183
228
  confidence_prompt += "Reasoning: #{chain[:reasoning]}\n"
184
229
 
185
- answer_key = signature.output_fields.keys.find { |k| !%i[reasoning comparison_data].include?(k) }
230
+ answer_key = signature.output_fields.keys.map(&:to_sym).find do |k|
231
+ !%i[reasoning comparison_data].include?(k)
232
+ end
186
233
  confidence_prompt += "Answer: #{chain[answer_key]}\n" if chain[answer_key]
187
234
 
188
235
  confidence_prompt += "\nRespond with just a number between 0 and 100:"
@@ -202,3 +249,8 @@ module Desiru
202
249
  end
203
250
  end
204
251
  end
252
+
253
+ # Register in the main module namespace for convenience
254
+ module Desiru
255
+ MultiChainComparison = Modules::MultiChainComparison
256
+ end
@@ -4,6 +4,13 @@ module Desiru
4
4
  module Modules
5
5
  # Basic prediction module - the fundamental building block
6
6
  class Predict < Module
7
+ DEFAULT_SIGNATURE = 'question: string -> answer: string'
8
+
9
+ def initialize(signature = nil, model: nil, **)
10
+ signature ||= DEFAULT_SIGNATURE
11
+ super
12
+ end
13
+
7
14
  def forward(inputs)
8
15
  prompt = build_prompt(inputs)
9
16