desiru 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.claude/settings.local.json +11 -0
- data/.env.example +34 -0
- data/.rubocop.yml +7 -4
- data/.ruby-version +1 -0
- data/CHANGELOG.md +73 -0
- data/CLAUDE.local.md +3 -0
- data/CLAUDE.md +10 -1
- data/Gemfile +21 -2
- data/Gemfile.lock +88 -13
- data/README.md +301 -2
- data/Rakefile +1 -0
- data/db/migrations/001_create_initial_tables.rb +96 -0
- data/db/migrations/002_create_job_results.rb +39 -0
- data/desiru-development-swarm.yml +185 -0
- data/desiru.db +0 -0
- data/desiru.gemspec +2 -5
- data/docs/background_processing_roadmap.md +87 -0
- data/docs/job_scheduling.md +167 -0
- data/dspy-analysis-swarm.yml +60 -0
- data/dspy-feature-analysis.md +121 -0
- data/examples/README.md +69 -0
- data/examples/api_with_persistence.rb +122 -0
- data/examples/assertions_example.rb +232 -0
- data/examples/async_processing.rb +2 -0
- data/examples/few_shot_learning.rb +1 -2
- data/examples/graphql_api.rb +4 -2
- data/examples/graphql_integration.rb +3 -3
- data/examples/graphql_optimization_summary.md +143 -0
- data/examples/graphql_performance_benchmark.rb +247 -0
- data/examples/persistence_example.rb +102 -0
- data/examples/react_agent.rb +203 -0
- data/examples/rest_api.rb +173 -0
- data/examples/rest_api_advanced.rb +333 -0
- data/examples/scheduled_job_example.rb +116 -0
- data/examples/simple_qa.rb +1 -2
- data/examples/sinatra_api.rb +109 -0
- data/examples/typed_signatures.rb +1 -2
- data/graphql_optimization_summary.md +53 -0
- data/lib/desiru/api/grape_integration.rb +284 -0
- data/lib/desiru/api/persistence_middleware.rb +148 -0
- data/lib/desiru/api/sinatra_integration.rb +217 -0
- data/lib/desiru/api.rb +42 -0
- data/lib/desiru/assertions.rb +74 -0
- data/lib/desiru/async_status.rb +65 -0
- data/lib/desiru/cache.rb +1 -1
- data/lib/desiru/configuration.rb +2 -1
- data/lib/desiru/core/compiler.rb +231 -0
- data/lib/desiru/core/example.rb +96 -0
- data/lib/desiru/core/prediction.rb +108 -0
- data/lib/desiru/core/trace.rb +330 -0
- data/lib/desiru/core/traceable.rb +61 -0
- data/lib/desiru/core.rb +12 -0
- data/lib/desiru/errors.rb +160 -0
- data/lib/desiru/field.rb +17 -14
- data/lib/desiru/graphql/batch_loader.rb +85 -0
- data/lib/desiru/graphql/data_loader.rb +242 -75
- data/lib/desiru/graphql/enum_builder.rb +75 -0
- data/lib/desiru/graphql/executor.rb +37 -4
- data/lib/desiru/graphql/schema_generator.rb +62 -158
- data/lib/desiru/graphql/type_builder.rb +138 -0
- data/lib/desiru/graphql/type_cache_warmer.rb +91 -0
- data/lib/desiru/jobs/async_predict.rb +1 -1
- data/lib/desiru/jobs/base.rb +67 -0
- data/lib/desiru/jobs/batch_processor.rb +6 -6
- data/lib/desiru/jobs/retriable.rb +119 -0
- data/lib/desiru/jobs/retry_strategies.rb +169 -0
- data/lib/desiru/jobs/scheduler.rb +219 -0
- data/lib/desiru/jobs/webhook_notifier.rb +242 -0
- data/lib/desiru/models/anthropic.rb +164 -0
- data/lib/desiru/models/base.rb +37 -3
- data/lib/desiru/models/open_ai.rb +151 -0
- data/lib/desiru/models/open_router.rb +161 -0
- data/lib/desiru/module.rb +67 -9
- data/lib/desiru/modules/best_of_n.rb +306 -0
- data/lib/desiru/modules/chain_of_thought.rb +3 -3
- data/lib/desiru/modules/majority.rb +51 -0
- data/lib/desiru/modules/multi_chain_comparison.rb +256 -0
- data/lib/desiru/modules/predict.rb +15 -1
- data/lib/desiru/modules/program_of_thought.rb +338 -0
- data/lib/desiru/modules/react.rb +273 -0
- data/lib/desiru/modules/retrieve.rb +4 -2
- data/lib/desiru/optimizers/base.rb +32 -4
- data/lib/desiru/optimizers/bootstrap_few_shot.rb +2 -2
- data/lib/desiru/optimizers/copro.rb +268 -0
- data/lib/desiru/optimizers/knn_few_shot.rb +185 -0
- data/lib/desiru/optimizers/mipro_v2.rb +889 -0
- data/lib/desiru/persistence/database.rb +71 -0
- data/lib/desiru/persistence/models/api_request.rb +38 -0
- data/lib/desiru/persistence/models/job_result.rb +138 -0
- data/lib/desiru/persistence/models/module_execution.rb +37 -0
- data/lib/desiru/persistence/models/optimization_result.rb +28 -0
- data/lib/desiru/persistence/models/training_example.rb +25 -0
- data/lib/desiru/persistence/models.rb +11 -0
- data/lib/desiru/persistence/repositories/api_request_repository.rb +98 -0
- data/lib/desiru/persistence/repositories/base_repository.rb +77 -0
- data/lib/desiru/persistence/repositories/job_result_repository.rb +116 -0
- data/lib/desiru/persistence/repositories/module_execution_repository.rb +85 -0
- data/lib/desiru/persistence/repositories/optimization_result_repository.rb +67 -0
- data/lib/desiru/persistence/repositories/training_example_repository.rb +102 -0
- data/lib/desiru/persistence/repository.rb +29 -0
- data/lib/desiru/persistence/setup.rb +77 -0
- data/lib/desiru/persistence.rb +49 -0
- data/lib/desiru/registry.rb +3 -5
- data/lib/desiru/signature.rb +91 -24
- data/lib/desiru/version.rb +1 -1
- data/lib/desiru.rb +33 -8
- data/missing-features-analysis.md +192 -0
- metadata +75 -45
- data/lib/desiru/models/raix_adapter.rb +0 -210
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# frozen_string_literal: true

require 'open_router'

module Desiru
  module Models
    # OpenRouter model adapter - provides access to multiple models through a single API.
    #
    # Wraps the `open_router` gem client. Responses follow the OpenAI-compatible
    # format (choices/message/content, usage token counts), which is what
    # {#format_response} normalizes into the hash this adapter returns.
    class OpenRouter < Base
      # Model used when neither the call options nor the adapter config name one.
      DEFAULT_MODEL = 'anthropic/claude-3-haiku'

      # @param config [Hash] adapter configuration; recognized keys:
      #   :api_key (falls back to ENV['OPENROUTER_API_KEY']), :site_name,
      #   :site_url, and defaults consumed later (:model, :temperature, :max_tokens).
      # @raise [ArgumentError] when no API key is available from config or ENV.
      def initialize(config = {})
        super
        @api_key = config[:api_key] || ENV.fetch('OPENROUTER_API_KEY', nil)
        raise ArgumentError, 'OpenRouter API key is required' unless @api_key

        # Configure OpenRouter client
        # NOTE(review): this mutates the gem's global configuration, so the
        # last-constructed adapter wins if several are created with different
        # keys — confirm whether per-instance credentials are needed.
        ::OpenRouter.configure do |c|
          c.access_token = @api_key
          c.site_name = config[:site_name] || 'Desiru'
          c.site_url = config[:site_url] || 'https://github.com/obie/desiru'
        end

        @client = ::OpenRouter::Client.new
        # Model-catalog cache; populated lazily by #models / #fetch_models.
        @models_cache = nil
        @models_fetched_at = nil
      end

      # Returns the catalog of models available through OpenRouter, keyed by
      # model id. Results are cached in-process and refreshed after one hour.
      # @return [Hash{String=>Hash}] id => { name:, context_length:, pricing:, top_provider: }
      def models
        # Cache models for 1 hour
        fetch_models if @models_cache.nil? || @models_fetched_at.nil? || (Time.now - @models_fetched_at) > 3600
        @models_cache
      end

      protected

      # Performs a single (non-streaming) chat completion.
      # @param messages [Array<Hash>] chat messages in OpenAI role/content form
      # @param options [Hash] per-call overrides (:model, :temperature,
      #   :max_tokens, :provider, :response_format, :tools, :tool_choice)
      # @return [Hash] normalized response from #format_response
      def perform_completion(messages, options)
        # Per-call options override adapter config, which overrides defaults.
        model = options[:model] || @config[:model] || DEFAULT_MODEL
        temperature = options[:temperature] || @config[:temperature] || 0.7
        max_tokens = options[:max_tokens] || @config[:max_tokens] || 4096

        # Prepare request parameters
        params = {
          model: model,
          messages: messages,
          temperature: temperature,
          max_tokens: max_tokens
        }

        # Add provider-specific options if needed
        params[:provider] = options[:provider] if options[:provider]

        # Add response format if specified
        params[:response_format] = options[:response_format] if options[:response_format]

        # Add tools if provided (for models that support function calling)
        if options[:tools]
          params[:tools] = options[:tools]
          params[:tool_choice] = options[:tool_choice] if options[:tool_choice]
        end

        # Make API call
        response = @client.complete(params)

        # Format response
        format_response(response, model)
      rescue StandardError => e
        handle_api_error(e)
      end

      # Streams a completion, yielding each content delta to the given block.
      # Unlike #perform_completion this builds messages itself via
      # prepare_messages (presumably defined on Base — confirm).
      # @yieldparam content [String] incremental text chunk
      def stream_complete(prompt, **options, &block)
        messages = prepare_messages(prompt, options[:messages])
        model = options[:model] || @config[:model] || DEFAULT_MODEL
        temperature = options[:temperature] || @config[:temperature] || 0.7
        max_tokens = options[:max_tokens] || @config[:max_tokens] || 4096

        # Prepare streaming request
        params = {
          model: model,
          messages: messages,
          temperature: temperature,
          max_tokens: max_tokens,
          stream: true
        }

        # Stream response
        @client.complete(params) do |chunk|
          # Only forward chunks that actually carry text content.
          if chunk.dig('choices', 0, 'delta', 'content')
            content = chunk.dig('choices', 0, 'delta', 'content')
            block.call(content) if block_given?
          end
        end
      rescue StandardError => e
        handle_api_error(e)
      end

      private

      # Fetches the model catalog from the API and rebuilds the cache.
      # On any error, logs a warning and installs a small hard-coded fallback
      # catalog instead (so #models never raises).
      def fetch_models
        # OpenRouter provides models at https://openrouter.ai/api/v1/models
        response = @client.models

        @models_cache = {}
        response['data'].each do |model|
          @models_cache[model['id']] = {
            name: model['name'] || model['id'],
            context_length: model['context_length'],
            pricing: model['pricing'],
            top_provider: model['top_provider']
          }
        end

        @models_fetched_at = Time.now
        @models_cache
      rescue StandardError => e
        Desiru.logger.warn("Failed to fetch OpenRouter models: #{e.message}")
        # Fallback to commonly used models
        @models_cache = {
          'anthropic/claude-3-haiku' => { name: 'Claude 3 Haiku' },
          'anthropic/claude-3-sonnet' => { name: 'Claude 3 Sonnet' },
          'openai/gpt-4o-mini' => { name: 'GPT-4o Mini' },
          'openai/gpt-4o' => { name: 'GPT-4o' },
          'google/gemini-pro' => { name: 'Gemini Pro' }
        }
        # Timestamp the fallback too, so it is cached for the same hour window.
        @models_fetched_at = Time.now
        @models_cache
      end

      # Normalizes a raw API response into the adapter's return shape.
      # @return [Hash] { content:, raw:, model:, usage: {prompt_tokens:, completion_tokens:, total_tokens:} }
      def format_response(response, model)
        # OpenRouter uses OpenAI-compatible response format
        content = response.dig('choices', 0, 'message', 'content') || ''
        usage = response['usage'] || {}

        {
          content: content,
          raw: response,
          model: model,
          usage: {
            prompt_tokens: usage['prompt_tokens'] || 0,
            completion_tokens: usage['completion_tokens'] || 0,
            total_tokens: usage['total_tokens'] || 0
          }
        }
      end

      # Maps transport-level Faraday errors onto Desiru's error hierarchy.
      # Always raises; never returns normally.
      # @raise [AuthenticationError, InvalidRequestError, RateLimitError, APIError]
      def handle_api_error(error)
        case error
        when ::Faraday::UnauthorizedError
          raise AuthenticationError, 'Invalid OpenRouter API key'
        when ::Faraday::BadRequestError
          raise InvalidRequestError, "Invalid request: #{error.message}"
        when ::Faraday::TooManyRequestsError
          raise RateLimitError, 'OpenRouter API rate limit exceeded'
        when ::Faraday::PaymentRequiredError
          raise APIError, 'OpenRouter payment required - check your account balance'
        else
          raise APIError, "OpenRouter API error: #{error.message}"
        end
      end
    end
  end
end
|
data/lib/desiru/module.rb
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
require_relative 'async_capable'
|
|
4
|
+
require_relative 'assertions'
|
|
5
|
+
require_relative 'core/traceable'
|
|
4
6
|
|
|
5
7
|
module Desiru
|
|
6
8
|
# Base class for all Desiru modules
|
|
7
9
|
# Implements the core module pattern with service-oriented design
|
|
8
10
|
class Module
|
|
9
11
|
extend Forwardable
|
|
10
|
-
|
|
12
|
+
include AsyncCapable
|
|
13
|
+
prepend Core::Traceable
|
|
11
14
|
|
|
12
15
|
attr_reader :signature, :model, :config, :demos, :metadata
|
|
13
16
|
|
|
@@ -40,23 +43,23 @@ module Desiru
|
|
|
40
43
|
|
|
41
44
|
begin
|
|
42
45
|
# Validate inputs first, then coerce
|
|
43
|
-
signature.
|
|
46
|
+
signature.valid_inputs?(inputs)
|
|
44
47
|
coerced_inputs = signature.coerce_inputs(inputs)
|
|
45
48
|
|
|
46
49
|
# Execute the module logic
|
|
47
50
|
result = forward(**coerced_inputs)
|
|
48
51
|
|
|
49
52
|
# Validate outputs first, then coerce
|
|
50
|
-
signature.
|
|
53
|
+
signature.valid_outputs?(result)
|
|
51
54
|
coerced_outputs = signature.coerce_outputs(result)
|
|
52
55
|
|
|
53
56
|
# Return result object
|
|
54
57
|
ModuleResult.new(coerced_outputs, metadata: execution_metadata)
|
|
55
58
|
rescue StandardError => e
|
|
56
|
-
if
|
|
59
|
+
if should_retry?(e)
|
|
57
60
|
@retry_count += 1
|
|
58
|
-
|
|
59
|
-
sleep(
|
|
61
|
+
log_retry(e)
|
|
62
|
+
sleep(retry_delay_for(e))
|
|
60
63
|
retry
|
|
61
64
|
else
|
|
62
65
|
handle_error(e)
|
|
@@ -110,6 +113,44 @@ module Desiru
|
|
|
110
113
|
|
|
111
114
|
private
|
|
112
115
|
|
|
116
|
+
def should_retry?(error)
|
|
117
|
+
return false unless config[:retry_on_failure]
|
|
118
|
+
|
|
119
|
+
# Handle assertion errors specifically
|
|
120
|
+
return error.retriable? && @retry_count < max_retries_for(error) if error.is_a?(Assertions::AssertionError)
|
|
121
|
+
|
|
122
|
+
# Default retry logic for other errors
|
|
123
|
+
@retry_count < Desiru.configuration.max_retries
|
|
124
|
+
end
|
|
125
|
+
|
|
126
|
+
def max_retries_for(error)
|
|
127
|
+
if error.is_a?(Assertions::AssertionError)
|
|
128
|
+
Assertions.configuration.max_assertion_retries
|
|
129
|
+
else
|
|
130
|
+
Desiru.configuration.max_retries
|
|
131
|
+
end
|
|
132
|
+
end
|
|
133
|
+
|
|
134
|
+
def retry_delay_for(error)
|
|
135
|
+
if error.is_a?(Assertions::AssertionError)
|
|
136
|
+
Assertions.configuration.assertion_retry_delay
|
|
137
|
+
else
|
|
138
|
+
Desiru.configuration.retry_delay
|
|
139
|
+
end
|
|
140
|
+
end
|
|
141
|
+
|
|
142
|
+
def log_retry(error)
|
|
143
|
+
if error.is_a?(Assertions::AssertionError)
|
|
144
|
+
Desiru.configuration.logger&.warn(
|
|
145
|
+
"[ASSERTION RETRY] #{error.message} (attempt #{@retry_count}/#{max_retries_for(error)})"
|
|
146
|
+
)
|
|
147
|
+
else
|
|
148
|
+
Desiru.configuration.logger&.warn(
|
|
149
|
+
"Retrying module execution (attempt #{@retry_count}/#{Desiru.configuration.max_retries})"
|
|
150
|
+
)
|
|
151
|
+
end
|
|
152
|
+
end
|
|
153
|
+
|
|
113
154
|
def validate_model!
|
|
114
155
|
return if model.nil? # Will use default
|
|
115
156
|
|
|
@@ -133,8 +174,19 @@ module Desiru
|
|
|
133
174
|
end
|
|
134
175
|
|
|
135
176
|
def handle_error(error)
|
|
136
|
-
|
|
137
|
-
|
|
177
|
+
if error.is_a?(Assertions::AssertionError)
|
|
178
|
+
# Update the assertion error with module context
|
|
179
|
+
error.instance_variable_set(:@module_name, self.class.name)
|
|
180
|
+
error.instance_variable_set(:@retry_count, @retry_count)
|
|
181
|
+
|
|
182
|
+
Desiru.configuration.logger&.error(
|
|
183
|
+
"[ASSERTION FAILED] #{error.message} in #{self.class.name} after #{@retry_count} retries"
|
|
184
|
+
)
|
|
185
|
+
raise error
|
|
186
|
+
else
|
|
187
|
+
Desiru.configuration.logger&.error("Module execution failed: #{error.message}")
|
|
188
|
+
raise ModuleError, "Module execution failed: #{error.message}"
|
|
189
|
+
end
|
|
138
190
|
end
|
|
139
191
|
end
|
|
140
192
|
|
|
@@ -166,6 +218,12 @@ module Desiru
|
|
|
166
218
|
end
|
|
167
219
|
end
|
|
168
220
|
|
|
221
|
+
def key?(key)
|
|
222
|
+
@data.key?(key.to_sym) || @data.key?(key.to_s)
|
|
223
|
+
end
|
|
224
|
+
|
|
225
|
+
alias has_key? key?
|
|
226
|
+
|
|
169
227
|
def method_missing(method_name, *args, &)
|
|
170
228
|
method_str = method_name.to_s
|
|
171
229
|
if method_str.end_with?('?')
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# frozen_string_literal: true

module Desiru
  module Modules
    # BestOfN module that samples N outputs from a predictor and selects the best one
    # based on configurable criteria (confidence, consistency, or external validation).
    #
    # Samples are assumed to behave like hashes (they are merged, indexed by
    # symbol, and mutated with []= / delete) — confirm against the base
    # module's return type.
    class BestOfN < Desiru::Module
      # The selection strategies accepted by :selection_criterion.
      SELECTION_CRITERIA = %i[confidence consistency llm_judge custom].freeze

      # Used when the caller does not supply a signature.
      DEFAULT_SIGNATURE = 'question: string -> answer: string'

      # @param signature [String, nil] module signature; defaults to DEFAULT_SIGNATURE
      # @param model [Object, nil] language model handle, forwarded to the parent
      # Option kwargs consumed here (everything else goes to the parent):
      #   :n_samples (default 5), :selection_criterion (default :consistency),
      #   :temperature (default 0.7), :custom_selector (a callable),
      #   :base_module (class or instance, default Modules::Predict),
      #   :include_metadata (default false).
      def initialize(signature = nil, model: nil, **kwargs)
        # Extract our specific options before passing to parent
        @n_samples = kwargs.delete(:n_samples) || 5
        @selection_criterion = validate_criterion(kwargs.delete(:selection_criterion) || :consistency)
        @temperature = kwargs.delete(:temperature) || 0.7
        @custom_selector = kwargs.delete(:custom_selector) # Proc that takes array of results
        @base_module = kwargs.delete(:base_module) || Modules::Predict
        @include_metadata = kwargs.delete(:include_metadata) || false

        # Use default signature if none provided
        signature ||= DEFAULT_SIGNATURE

        # Pass remaining kwargs to parent (config, demos, metadata)
        super
      end

      # Generates @n_samples candidate outputs and returns the one chosen by
      # the configured selection criterion. On unexpected errors, logs and
      # falls back to a single un-ranked sample; ArgumentError (e.g. a missing
      # custom selector) is deliberately re-raised to the caller.
      def forward(**inputs)
        # Generate N samples
        samples = generate_samples(inputs)

        # Select the best sample based on criterion
        best_sample = select_best(samples, inputs)

        # Include metadata if requested
        if @include_metadata || signature.output_fields.key?(:selection_metadata)
          best_sample[:selection_metadata] = build_metadata(samples, best_sample)
        end

        # Clean up internal fields
        best_sample.delete(:_confidence_score)

        best_sample
      rescue ArgumentError => e
        # Re-raise ArgumentError for missing custom selector
        raise e
      rescue StandardError => e
        Desiru.logger.error("BestOfN error: #{e.message}")
        # Fallback to single sample
        fallback_sample(inputs)
      end

      private

      # Validates and returns the criterion symbol; raises for unknown values.
      # @raise [ArgumentError] when criterion is not in SELECTION_CRITERIA
      def validate_criterion(criterion)
        unless SELECTION_CRITERIA.include?(criterion)
          raise ArgumentError, "Invalid selection criterion: #{criterion}. " \
                               "Must be one of: #{SELECTION_CRITERIA.join(', ')}"
        end
        criterion
      end

      # Runs the base module @n_samples times, temporarily raising the model
      # temperature (when the model supports temperature=) to encourage
      # output diversity. Returns the array of sample hashes.
      def generate_samples(inputs)
        samples = []

        # Create module instance for generation
        generator = if @base_module.is_a?(Class)
                      @base_module.new(signature, model: model)
                    else
                      @base_module
                    end

        @n_samples.times do |i|
          # Add variation seed to inputs for diversity
          sample_inputs = inputs.merge(_sample_index: i)

          # Use higher temperature for diversity
          # NOTE(review): every object responds to instance_variable_get, so
          # this guard never skips; also, if original_temp is nil/false the
          # ensure below never restores it — the @temperature override would
          # then persist on the model. Confirm whether that is intended.
          original_temp = model.instance_variable_get(:@temperature) if model.respond_to?(:instance_variable_get)

          begin
            # Temporarily set temperature if possible
            model.temperature = @temperature if model.respond_to?(:temperature=)

            # Generate sample
            sample = if generator.respond_to?(:forward)
                       generator.forward(**sample_inputs)
                     else
                       generator.call(**sample_inputs)
                     end

            # Remove the sample index from results
            sample.delete(:_sample_index)
            samples << sample
          ensure
            # Restore original temperature
            model.temperature = original_temp if model.respond_to?(:temperature=) && original_temp
          end
        end

        samples
      end

      # Dispatches to the strategy method matching @selection_criterion.
      def select_best(samples, inputs)
        case @selection_criterion
        when :confidence
          select_by_confidence(samples)
        when :consistency
          select_by_consistency(samples)
        when :llm_judge
          select_by_llm_judge(samples, inputs)
        when :custom
          select_by_custom(samples)
        else
          samples.first # Fallback
        end
      end

      # Scores each sample via calculate_confidence and returns the
      # highest-scoring one, with its score kept under :_confidence_score
      # (stripped later by #forward, used by #build_metadata).
      def select_by_confidence(samples)
        # Ask model to rate confidence for each sample
        samples_with_scores = samples.map do |sample|
          confidence = calculate_confidence(sample)
          sample.merge(_confidence_score: confidence)
        end

        # Return sample with highest confidence (keep score for metadata)
        samples_with_scores.max_by { |s| s[:_confidence_score] }
      end

      # Asks the model for a 0-100 confidence rating of one sample.
      # Defaults to 50 when no number can be parsed from the reply.
      def calculate_confidence(sample)
        # Build confidence prompt
        prompt = "Rate the confidence (0-100) for this response:\n\n"

        sample.each do |key, value|
          # Internal bookkeeping fields are excluded from the prompt.
          next if key.to_s.start_with?('_')

          prompt += "#{key}: #{value}\n"
        end

        prompt += "\nProvide only a number between 0 and 100:"

        response = model.complete(
          messages: [{ role: 'user', content: prompt }],
          temperature: 0.1
        )

        # Extract confidence score
        score = response[:content].scan(/\d+/).first&.to_i || 50
        score.clamp(0, 100)
      end

      # Majority-vote strategy: groups samples by their normalized main output
      # and returns a representative of the largest group.
      def select_by_consistency(samples)
        # Group samples by their main output values
        output_groups = Hash.new { |h, k| h[k] = [] }

        # Find the main output field (first non-metadata field)
        main_field = signature.output_fields.keys.find do |k|
          !k.to_s.start_with?('_') && k.to_s != 'selection_metadata'
        end

        return samples.first unless main_field

        # Convert to symbol to match sample keys
        field_sym = main_field.to_sym

        # Group samples by their main output
        samples.each do |sample|
          if sample[field_sym]
            key = normalize_output(sample[field_sym])
            output_groups[key] << sample
          end
        end

        # Select the most consistent group
        largest_group = output_groups.values.max_by(&:length)

        # From the largest group, select the "centroid" - the one most similar to others
        select_centroid(largest_group)
      end

      # Canonicalizes a value for equality comparison across samples:
      # strings are case/punctuation-insensitive, floats rounded to 2 places,
      # arrays sorted after element-wise normalization, hashes normalized
      # value-wise, everything else stringified.
      def normalize_output(value)
        case value
        when String
          value.downcase.strip.gsub(/[[:punct:]]/, '')
        when Numeric
          value.round(2)
        when Array
          value.map { |v| normalize_output(v) }.sort
        when Hash
          value.transform_values { |v| normalize_output(v) }
        else
          value.to_s
        end
      end

      # Picks a representative from a group of equivalent samples.
      def select_centroid(group)
        return group.first if group.length == 1

        # For now, return the middle element (could be improved with similarity metrics)
        group[group.length / 2]
      end

      # LLM-as-judge strategy: presents all options to the model and parses
      # its "Option N" choice (defaulting to the first option, clamped to range).
      def select_by_llm_judge(samples, inputs)
        # Build judge prompt
        judge_prompt = "Given the following input and multiple response options, " \
                       "select the best response:\n\n"

        # Add original inputs
        judge_prompt += "Input:\n"
        inputs.each do |key, value|
          judge_prompt += "  #{key}: #{value}\n"
        end

        # Add all samples
        judge_prompt += "\nResponse Options:\n"
        samples.each_with_index do |sample, i|
          judge_prompt += "\n--- Option #{i + 1} ---\n"
          sample.each do |key, value|
            next if key.to_s.start_with?('_')

            judge_prompt += "#{key}: #{value}\n"
          end
        end

        judge_prompt += "\nSelect the best option (1-#{samples.length}) and briefly explain why:"

        response = model.complete(
          messages: [{ role: 'user', content: judge_prompt }],
          temperature: 0.1
        )

        # Extract selected index
        selection_match = response[:content].match(/option\s*#?(\d+)/i)
        selected_index = if selection_match
                           selection_match[1].to_i - 1
                         else
                           0
                         end

        selected_index = selected_index.clamp(0, samples.length - 1)
        samples[selected_index]
      end

      # Delegates selection to the user-supplied callable; falls back to the
      # first sample if the selector returns nil/false.
      # @raise [ArgumentError] when no callable selector was configured
      def select_by_custom(samples)
        unless @custom_selector.respond_to?(:call)
          raise ArgumentError, "Custom selector must be provided when using :custom criterion"
        end

        @custom_selector.call(samples) || samples.first
      end

      # Builds the :selection_metadata hash describing how the winner was
      # chosen (sample count, criterion, temperature, plus criterion-specific
      # extras: agreement_rate for :consistency, selected_confidence for :confidence).
      def build_metadata(samples, selected)
        metadata = {
          total_samples: samples.length,
          selection_criterion: @selection_criterion,
          temperature: @temperature
        }

        # Add criterion-specific metadata
        case @selection_criterion
        when :consistency
          # Count how many samples agree with the selected one
          main_field = signature.output_fields.keys.find do |k|
            !k.to_s.start_with?('_') && k.to_s != 'selection_metadata'
          end

          if main_field
            # Convert to symbol to match sample keys
            field_sym = main_field.to_sym
            if selected[field_sym]
              selected_value = normalize_output(selected[field_sym])
              agreement_count = samples.count do |s|
                normalize_output(s[field_sym]) == selected_value
              end
              metadata[:agreement_rate] = agreement_count.to_f / samples.length
            end
          end
        when :confidence
          # Include confidence scores if available
          metadata[:selected_confidence] = selected[:_confidence_score] if selected[:_confidence_score]
        end

        metadata
      end

      # Single-shot fallback used when multi-sample selection raises:
      # runs the base module once with the original inputs, no ranking.
      def fallback_sample(inputs)
        # Generate a single sample as fallback
        generator = if @base_module.is_a?(Class)
                      @base_module.new(signature, model: model)
                    else
                      @base_module
                    end

        if generator.respond_to?(:forward)
          generator.forward(**inputs)
        else
          generator.call(**inputs)
        end
      end
    end
  end
end

# Register in the main module namespace for convenience
module Desiru
  BestOfN = Modules::BestOfN
end
|
|
@@ -21,9 +21,9 @@ module Desiru
|
|
|
21
21
|
|
|
22
22
|
Before providing the final answer, you must show your reasoning process. Think through the problem step by step.
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
reasoning:
|
|
26
|
-
|
|
24
|
+
Always format your response with each field on its own line like this:
|
|
25
|
+
reasoning: Your step-by-step thought process here
|
|
26
|
+
#{@original_signature.output_fields.keys.map { |field| "#{field}: Your #{field} here" }.join("\n")}
|
|
27
27
|
|
|
28
28
|
#{format_descriptions}
|
|
29
29
|
PROMPT
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# frozen_string_literal: true

module Desiru
  module Modules
    # Function-style helper implementing majority voting.
    #
    # Runs the given module several times with the same inputs and returns
    # the result whose primary output value occurs most often. Ties resolve
    # to whichever value `max_by` encounters first; the returned result is
    # the first one produced for the winning value.
    #
    # @param module_instance [#call] a Desiru module (must also expose #signature)
    # @param inputs [Hash] keyword inputs forwarded to every call;
    #   :num_completions (default 5) controls how many runs are made
    # @raise [ArgumentError] when module_instance is not callable
    def self.majority(module_instance, **inputs)
      raise ArgumentError, "First argument must be a Desiru module instance" unless module_instance.respond_to?(:call)

      num_completions = inputs.delete(:num_completions) || 5

      # Run the module repeatedly, collecting every result.
      results = Array.new(num_completions) { module_instance.call(**inputs) }

      # Vote on the first declared output field only.
      field_names = module_instance.signature.output_fields.keys
      primary_field = field_names.first

      # Tally each distinct answer and remember the first result that produced it.
      tallies = Hash.new(0)
      first_result_for = {}
      results.each do |candidate|
        value = candidate[primary_field]
        tallies[value] += 1
        first_result_for[value] ||= candidate
      end

      # The winning answer is the one with the highest vote count.
      top_value = tallies.max_by { |_, count| count }&.first
      chosen = first_result_for[top_value] || results.first

      # Attach vote statistics when the signature declares a :voting_data field.
      if field_names.include?(:voting_data)
        chosen[:voting_data] = {
          votes: tallies,
          num_completions: num_completions,
          consensus_rate: tallies[top_value].to_f / num_completions
        }
      end

      chosen
    end
  end
end
|