desiru 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +1 -0
- data/.rubocop.yml +55 -0
- data/CLAUDE.md +22 -0
- data/Gemfile +36 -0
- data/Gemfile.lock +255 -0
- data/LICENSE +21 -0
- data/README.md +343 -0
- data/Rakefile +18 -0
- data/desiru.gemspec +44 -0
- data/examples/README.md +55 -0
- data/examples/async_processing.rb +135 -0
- data/examples/few_shot_learning.rb +66 -0
- data/examples/graphql_api.rb +190 -0
- data/examples/graphql_integration.rb +114 -0
- data/examples/rag_retrieval.rb +80 -0
- data/examples/simple_qa.rb +31 -0
- data/examples/typed_signatures.rb +45 -0
- data/lib/desiru/async_capable.rb +170 -0
- data/lib/desiru/cache.rb +116 -0
- data/lib/desiru/configuration.rb +40 -0
- data/lib/desiru/field.rb +171 -0
- data/lib/desiru/graphql/data_loader.rb +210 -0
- data/lib/desiru/graphql/executor.rb +115 -0
- data/lib/desiru/graphql/schema_generator.rb +301 -0
- data/lib/desiru/jobs/async_predict.rb +52 -0
- data/lib/desiru/jobs/base.rb +53 -0
- data/lib/desiru/jobs/batch_processor.rb +71 -0
- data/lib/desiru/jobs/optimizer_job.rb +45 -0
- data/lib/desiru/models/base.rb +112 -0
- data/lib/desiru/models/raix_adapter.rb +210 -0
- data/lib/desiru/module.rb +204 -0
- data/lib/desiru/modules/chain_of_thought.rb +106 -0
- data/lib/desiru/modules/predict.rb +142 -0
- data/lib/desiru/modules/retrieve.rb +199 -0
- data/lib/desiru/optimizers/base.rb +130 -0
- data/lib/desiru/optimizers/bootstrap_few_shot.rb +212 -0
- data/lib/desiru/program.rb +106 -0
- data/lib/desiru/registry.rb +74 -0
- data/lib/desiru/signature.rb +322 -0
- data/lib/desiru/version.rb +5 -0
- data/lib/desiru.rb +67 -0
- metadata +184 -0
@@ -0,0 +1,199 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Desiru
|
4
|
+
module Modules
|
5
|
+
# Retrieve module for RAG (Retrieval Augmented Generation)
|
6
|
+
# Implements vector search capabilities with pluggable backends
|
7
|
+
class Retrieve < Module
|
8
|
+
attr_reader :backend
|
9
|
+
|
10
|
+
def initialize(signature = nil, backend: nil, **)
|
11
|
+
# Default signature for retrieval operations
|
12
|
+
signature ||= 'query: string, k: integer? -> documents: list, scores: list'
|
13
|
+
|
14
|
+
super(signature, **)
|
15
|
+
|
16
|
+
# Initialize backend
|
17
|
+
@backend = backend || InMemoryBackend.new
|
18
|
+
validate_backend!
|
19
|
+
end
|
20
|
+
|
21
|
+
def forward(**inputs)
|
22
|
+
query = inputs[:query]
|
23
|
+
# Handle k parameter - it might come as nil if optional
|
24
|
+
k = inputs.fetch(:k, 5)
|
25
|
+
k = 5 if k.nil? # Ensure we have a value even if nil was passed
|
26
|
+
|
27
|
+
# Perform retrieval using the backend
|
28
|
+
results = backend.search(query, k: k)
|
29
|
+
|
30
|
+
# Separate documents and scores
|
31
|
+
documents = results.map { |r| r[:document] }
|
32
|
+
scores = results.map { |r| r[:score] }
|
33
|
+
|
34
|
+
{ documents: documents, scores: scores }
|
35
|
+
end
|
36
|
+
|
37
|
+
# Add documents to the retrieval index
|
38
|
+
def add_documents(documents, embeddings: nil)
|
39
|
+
backend.add(documents, embeddings: embeddings)
|
40
|
+
end
|
41
|
+
|
42
|
+
# Clear the retrieval index
|
43
|
+
def clear_index
|
44
|
+
backend.clear
|
45
|
+
end
|
46
|
+
|
47
|
+
# Get the current document count
|
48
|
+
def document_count
|
49
|
+
backend.size
|
50
|
+
end
|
51
|
+
|
52
|
+
private
|
53
|
+
|
54
|
+
def validate_backend!
|
55
|
+
required_methods = %i[add search clear size]
|
56
|
+
missing_methods = required_methods.reject { |m| backend.respond_to?(m) }
|
57
|
+
|
58
|
+
return unless missing_methods.any?
|
59
|
+
|
60
|
+
raise ConfigurationError, "Backend must implement: #{missing_methods.join(', ')}"
|
61
|
+
end
|
62
|
+
end
|
63
|
+
|
64
|
+
# Abstract base class for retrieval backends.
#
# Every method raises NotImplementedError; concrete backends override all four.
class Backend
  def add(_documents, embeddings: nil)
    raise NotImplementedError, 'Subclasses must implement #add'
  end

  def search(_query, k: 5)
    raise NotImplementedError, 'Subclasses must implement #search'
  end

  def clear
    raise NotImplementedError, 'Subclasses must implement #clear'
  end

  def size
    raise NotImplementedError, 'Subclasses must implement #size'
  end
end

# In-memory backend implementation for development and testing.
#
# Documents and their embeddings are kept in two parallel arrays. When no
# embeddings are supplied, a simple a-z character-frequency vector is generated
# from each document's string form (demo quality only).
class InMemoryBackend < Backend
  # @param distance_metric [Symbol] :cosine (higher score is better) or
  #   :euclidean (lower score is better); anything else raises on #search
  def initialize(distance_metric: :cosine)
    @documents = []
    @embeddings = []
    @distance_metric = distance_metric
  end

  # Appends documents to the index.
  #
  # @param documents [Object, Array] one document or a list of documents; a
  #   single Hash is stored as one document
  # @param embeddings [Array, nil] one embedding per document; generated from
  #   the document text when omitted
  # @raise [ArgumentError] when the embedding count differs from the document count
  def add(documents, embeddings: nil)
    documents = wrap_documents(documents)

    if embeddings
      embeddings = Array(embeddings)
      if embeddings.size != documents.size
        raise ArgumentError, "Embeddings count (#{embeddings.size}) must match documents count (#{documents.size})"
      end
    else
      # Generate simple embeddings based on document content (for demo purposes)
      embeddings = documents.map { |doc| generate_simple_embedding(doc) }
    end

    @documents.concat(documents)
    @embeddings.concat(embeddings)
  end

  # Scores every stored document against the query and returns the best k.
  #
  # @return [Array<Hash>] hashes with :document, :score and :index, best first
  # @raise [ArgumentError] when the configured metric is unknown
  def search(query, k: 5)
    return [] if @documents.empty?

    query_embedding = generate_simple_embedding(query)

    scored = @embeddings.map.with_index do |embedding, idx|
      { document: @documents[idx], score: calculate_distance(query_embedding, embedding), index: idx }
    end

    ranked = case @distance_metric
             when :cosine
               # Similarity: higher is better, so sort descending
               scored.sort_by { |entry| -entry[:score] }
             else
               # Distance: lower is better
               scored.sort_by { |entry| entry[:score] }
             end

    ranked.first(k)
  end

  def clear
    @documents.clear
    @embeddings.clear
  end

  def size
    @documents.size
  end

  private

  # Coerces the argument to a list of documents. Kernel#Array would turn a
  # single Hash document into its [key, value] pairs, so Hashes are wrapped
  # explicitly instead.
  def wrap_documents(documents)
    documents.is_a?(Hash) ? [documents] : Array(documents)
  end

  # Simple embedding: unit-length 26-dimensional a-z frequency vector of the
  # downcased text. In production, use proper embedding models.
  def generate_simple_embedding(text)
    embedding = Array.new(26, 0.0)

    text.to_s.downcase.each_char do |char|
      embedding[char.ord - 'a'.ord] += 1.0 if char.between?('a', 'z')
    end

    length = magnitude(embedding)
    embedding.map! { |x| x / length } if length.positive?

    embedding
  end

  def calculate_distance(vec1, vec2)
    case @distance_metric
    when :cosine
      cosine_similarity(vec1, vec2)
    when :euclidean
      euclidean_distance(vec1, vec2)
    else
      raise ArgumentError, "Unknown distance metric: #{@distance_metric}"
    end
  end

  # Cosine similarity (1.0 = same direction, 0.0 = orthogonal).
  #
  # The dot product is divided by both magnitudes because embeddings passed to
  # #add are stored as given and may not be unit length. A zero-length vector
  # yields 0.0.
  def cosine_similarity(vec1, vec2)
    dot = vec1.zip(vec2).sum { |a, b| a * b }
    scale = magnitude(vec1) * magnitude(vec2)

    scale.zero? ? 0.0 : dot / scale
  end

  def euclidean_distance(vec1, vec2)
    Math.sqrt(vec1.zip(vec2).sum { |a, b| (a - b)**2 })
  end

  # Euclidean length of a vector.
  def magnitude(vec)
    Math.sqrt(vec.sum { |x| x**2 })
  end
end
|
193
|
+
end
|
194
|
+
end
|
195
|
+
|
196
|
+
# Convenience alias: exposes the retrieval module directly as Desiru::Retrieve.
Desiru::Retrieve = Desiru::Modules::Retrieve
|
@@ -0,0 +1,130 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Desiru
|
4
|
+
module Optimizers
|
5
|
+
# Base class for all optimizers.
#
# Holds the scoring metric and configuration, and provides the shared
# evaluation and scoring helpers. Subclasses implement #compile and
# #optimize_module.
class Base
  attr_reader :metric, :config

  # @param metric [Symbol, String, Proc] :exact_match, :f1, :accuracy, or a
  #   proc called with (prediction, ground_truth)
  # @param config [Hash] overrides merged over #default_config
  # @raise [OptimizerError] when metric is neither a symbol/string nor a proc
  def initialize(metric: :exact_match, **config)
    @metric = normalize_metric(metric)
    @config = default_config.merge(config)
    @optimization_trace = []
  end

  def compile(program, trainset:, valset: nil)
    raise NotImplementedError, 'Subclasses must implement #compile'
  end

  def optimize_module(module_instance, examples)
    raise NotImplementedError, 'Subclasses must implement #optimize_module'
  end

  # Runs the program on every example (minus its :answer/:output keys) and
  # scores each prediction against the full example.
  #
  # @return [Hash] :average_score, :scores and :total; average_score is 0.0
  #   for an empty dataset rather than NaN
  def evaluate(program, dataset)
    scores = dataset.map do |example|
      inputs = example.reject { |key, _| %i[answer output].include?(key) }
      score_prediction(program.call(inputs), example)
    end

    average = scores.empty? ? 0.0 : scores.sum.to_f / scores.size

    { average_score: average, scores: scores, total: scores.size }
  end

  protected

  def default_config
    {
      max_bootstrapped_demos: 3,
      max_labeled_demos: 16,
      max_errors: 5,
      num_candidates: 1,
      stop_at_score: 1.0
    }
  end

  # Dispatches to the configured metric.
  #
  # @raise [OptimizerError] when the metric symbol is not recognised
  def score_prediction(prediction, ground_truth)
    case @metric
    when Proc
      @metric.call(prediction, ground_truth)
    when :exact_match
      exact_match_score(prediction, ground_truth)
    when :f1
      f1_score(prediction, ground_truth)
    when :accuracy
      accuracy_score(prediction, ground_truth)
    else
      raise OptimizerError, "Unknown metric: #{@metric}"
    end
  end

  # 1.0 when both answers match after to_s, strip and downcase; otherwise 0.0.
  def exact_match_score(prediction, ground_truth)
    predicted = extract_answer(prediction).to_s.strip.downcase
    expected = extract_answer(ground_truth).to_s.strip.downcase

    predicted == expected ? 1.0 : 0.0
  end

  # Token-level F1 between the predicted and expected answers.
  #
  # Overlap is counted per occurrence (multiset intersection) so that it is
  # consistent with the duplicate-counting token totals used for precision and
  # recall; identical answers with repeated tokens therefore score 1.0.
  # Returns 0.0 when either side has no tokens.
  def f1_score(prediction, ground_truth)
    pred_tokens = tokenize(extract_answer(prediction))
    true_tokens = tokenize(extract_answer(ground_truth))

    return 0.0 if pred_tokens.empty? || true_tokens.empty?

    true_counts = true_tokens.tally
    overlap = pred_tokens.tally.sum { |token, count| [count, true_counts.fetch(token, 0)].min }

    return 0.0 if overlap.zero?

    precision = overlap.to_f / pred_tokens.size
    recall = overlap.to_f / true_tokens.size

    2 * (precision * recall) / (precision + recall)
  end

  # Currently identical to exact match.
  def accuracy_score(prediction, ground_truth)
    exact_match_score(prediction, ground_truth)
  end

  # Pulls the answer out of a result or hash by trying :answer, :output and
  # :result in turn, then falling back to the first value. Any other object is
  # returned unchanged.
  def extract_answer(data)
    case data
    when Hash, ModuleResult, ProgramResult
      data[:answer] || data[:output] || data[:result] || data.values.first
    else
      data
    end
  end

  # Lowercases and splits on runs of non-word characters.
  def tokenize(text)
    text.to_s.downcase.split(/\W+/).reject(&:empty?)
  end

  def normalize_metric(metric)
    case metric
    when Symbol, String
      metric.to_sym
    when Proc
      metric
    else
      raise OptimizerError, 'Metric must be a symbol or proc'
    end
  end

  # Records a step in the in-memory trace and logs it when a logger is configured.
  def trace_optimization(step, details)
    @optimization_trace << {
      step: step,
      timestamp: Time.now,
      details: details
    }

    Desiru.configuration.logger&.info("[Optimizer] #{step}: #{details}")
  end
end
|
126
|
+
|
127
|
+
# Error type for optimizer-related failures, such as an unsupported metric.
class OptimizerError < Error
end
|
129
|
+
end
|
130
|
+
end
|
@@ -0,0 +1,212 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Desiru
|
4
|
+
module Optimizers
|
5
|
+
# Bootstrap Few-Shot optimizer - automatically selects effective demonstrations.
#
# Config keys read here: :max_bootstrapped_demos, :max_labeled_demos,
# :max_errors, and the optional :min_demo_score (minimum score a bootstrapped
# prediction needs to be kept as a demonstration; 0.5 when unset).
class BootstrapFewShot < Base
  # Example keys that hold the expected answer rather than a module input.
  LABEL_KEYS = %i[answer output].freeze
  DEFAULT_MIN_DEMO_SCORE = 0.5

  # Builds a fresh program instance, optimizes each module found on it, and
  # traces a validation score when a valset is given.
  #
  # @return [Object] the new program instance
  def compile(program, trainset:, valset: nil)
    trace_optimization('Starting BootstrapFewShot optimization', {
                         trainset_size: trainset.size,
                         valset_size: valset&.size || 0
                       })

    optimized_program = deep_copy_program(program)
    optimize_modules(optimized_program, trainset, valset)

    trace_optimization('Final validation score', evaluate(optimized_program, valset)) if valset

    optimized_program
  end

  # Bootstraps demonstrations for one module and returns the result of
  # module_instance.with_demos with the selected demonstrations.
  def optimize_module(module_instance, examples)
    trace_optimization('Optimizing module', {
                         module: module_instance.class.name,
                         examples_available: examples.size
                       })

    bootstrapped_demos = bootstrap_demonstrations(module_instance, examples)
    selected_demos = select_demonstrations(module_instance, bootstrapped_demos, examples)

    module_instance.with_demos(selected_demos)
  end

  private

  # This is a simplified version: it constructs a new instance of the program's
  # class from its config and metadata rather than performing a true deep copy.
  def deep_copy_program(program)
    program.class.new(config: program.config, metadata: program.metadata)
  end

  def optimize_modules(program, trainset, _valset)
    extract_modules(program).each do |module_name, module_instance|
      trace_optimization('Processing module', { name: module_name })

      module_examples = create_module_examples(module_instance, trainset)
      optimized_module = optimize_module(module_instance, module_examples)

      replace_module(program, module_name, optimized_module)
    end
  end

  # Runs the module on each example and keeps predictions that score at least
  # the configured threshold. Stops once enough demonstrations are collected or
  # once low scores plus raised errors reach :max_errors.
  def bootstrap_demonstrations(module_instance, examples)
    threshold = config[:min_demo_score] || DEFAULT_MIN_DEMO_SCORE
    demonstrations = []
    errors = 0

    examples.each do |example|
      break if demonstrations.size >= config[:max_bootstrapped_demos]
      break if errors >= config[:max_errors]

      begin
        inputs = input_fields(example)
        prediction = module_instance.call(inputs)
        score = score_prediction(prediction, example)

        if score >= threshold
          demonstrations << {
            input: format_demo_input(inputs),
            output: format_demo_output(prediction),
            score: score
          }
        else
          errors += 1
        end
      rescue StandardError => e
        trace_optimization('Error during bootstrap', { error: e.message })
        errors += 1
      end
    end

    demonstrations
  end

  # Combines bootstrapped demos with demos built from labeled examples, then
  # picks a diverse top set.
  def select_demonstrations(_module_instance, bootstrapped, examples)
    labeled = examples.select { |ex| ex[:answer] || ex[:output] }

    labeled_demos = labeled.first(config[:max_labeled_demos]).map do |ex|
      {
        input: format_demo_input(input_fields(ex)),
        # Only the label fields belong in the demo output; the inputs are
        # already rendered in :input.
        output: format_demo_output(label_fields(ex)),
        score: 1.0 # Perfect score for labeled examples
      }
    end

    select_diverse_demos(bootstrapped + labeled_demos).first(config[:max_bootstrapped_demos])
  end

  # Greedy selection: repeatedly take the highest-scoring remaining demo and
  # drop remaining demos whose input is more than 0.8 similar to it.
  def select_diverse_demos(demos)
    selected = []
    remaining = demos.sort_by { |demo| -demo[:score] }

    while selected.size < config[:max_bootstrapped_demos] && remaining.any?
      best = remaining.shift
      selected << best

      remaining.reject! { |demo| similarity(demo[:input], best[:input]) > 0.8 }
    end

    selected
  end

  # Ratio of shared unique tokens to all unique tokens; 0.0 when either text
  # has no tokens.
  def similarity(text1, text2)
    tokens1 = tokenize(text1)
    tokens2 = tokenize(text2)

    return 0.0 if tokens1.empty? || tokens2.empty?

    (tokens1 & tokens2).size.to_f / (tokens1 | tokens2).size
  end

  # Example fields passed to the module (everything except the label keys).
  def input_fields(example)
    example.reject { |key, _| LABEL_KEYS.include?(key) }
  end

  # Label fields of an example that actually carry a value.
  def label_fields(example)
    example.select { |key, value| LABEL_KEYS.include?(key) && !value.nil? }
  end

  def format_demo_input(inputs)
    inputs.map { |k, v| "#{k}: #{v}" }.join("\n")
  end

  def format_demo_output(output)
    case output
    when ModuleResult
      output.to_h.map { |k, v| "#{k}: #{v}" }.join("\n")
    when Hash
      output.map { |k, v| "#{k}: #{v}" }.join("\n")
    else
      output.to_s
    end
  end

  # Collects instance variables of the program whose value is a Module, keyed
  # by the variable name without its leading "@".
  # NOTE(review): presumably Module resolves to the project's module class in
  # this namespace - confirm.
  def extract_modules(program)
    program.instance_variables.each_with_object({}) do |var, modules|
      value = program.instance_variable_get(var)
      modules[var.to_s.delete('@').to_sym] = value if value.is_a?(Module)
    end
  end

  # Simplified: returns the trainset examples unchanged (in a new array);
  # proper field mapping to the module's signature is not implemented.
  def create_module_examples(_module_instance, trainset)
    trainset.map(&:itself)
  end

  # Overwrites the program's instance variable for the module, if it exists.
  def replace_module(program, module_name, new_module)
    var_name = "@#{module_name}"
    return unless program.instance_variable_defined?(var_name)

    program.instance_variable_set(var_name, new_module)
  end
end
|
206
|
+
end
|
207
|
+
end
|
208
|
+
|
209
|
+
# Convenience alias: exposes the optimizer directly as Desiru::BootstrapFewShot.
Desiru::BootstrapFewShot = Desiru::Optimizers::BootstrapFewShot
|
@@ -0,0 +1,106 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Desiru
|
4
|
+
# Base class for composing multiple modules into programs.
# Implements composition patterns for complex AI workflows.
#
# Subclasses implement #forward and may override #setup_modules to populate
# the modules hash.
class Program
  attr_reader :modules, :config, :metadata

  # @param config [Hash] overrides merged over #default_config
  # @param metadata [Hash] stored as given
  def initialize(config: {}, metadata: {})
    @modules = {}
    @config = default_config.merge(config)
    @metadata = metadata
    @execution_trace = []

    setup_modules
  end

  # Runs #forward and wraps its return value in a ProgramResult whose metadata
  # carries the elapsed seconds, a copy of the execution trace and the program
  # class name.
  #
  # @raise [ProgramError] when #forward raises a StandardError
  def call(inputs = {})
    @execution_trace.clear
    # Monotonic clock: the measured duration is unaffected by wall-clock adjustments.
    started_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)

    result = forward(inputs)

    execution_time = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started_at

    ProgramResult.new(
      result,
      metadata: {
        execution_time: execution_time,
        trace: @execution_trace.dup,
        program: self.class.name
      }
    )
  rescue StandardError => e
    handle_error(e)
  end

  def forward(_inputs)
    raise NotImplementedError, 'Subclasses must implement #forward'
  end

  # Resets every registered module and empties the execution trace.
  def reset
    modules.each_value(&:reset)
    @execution_trace.clear
  end

  # Delegates to optimizer.compile and returns whatever it returns.
  def optimize(optimizer, trainset, valset = nil)
    optimizer.compile(self, trainset: trainset, valset: valset)
  end

  def to_h
    {
      class: self.class.name,
      modules: modules.transform_values(&:to_h),
      config: config,
      metadata: metadata
    }
  end

  protected

  def setup_modules
    # Override in subclasses to initialize modules
  end

  # Appends one step to the execution trace. Nothing is recorded when
  # config[:trace_execution] is disabled (it is enabled by default).
  def trace_execution(module_name, inputs, outputs)
    return unless config[:trace_execution]

    @execution_trace << {
      module: module_name,
      inputs: inputs,
      outputs: outputs.is_a?(ModuleResult) ? outputs.to_h : outputs,
      timestamp: Time.now
    }
  end

  def default_config
    {
      max_iterations: 10,
      early_stopping: true,
      trace_execution: true
    }
  end

  private

  # Logs the failure (when a logger is configured) and re-raises it as a
  # ProgramError.
  def handle_error(error)
    Desiru.configuration.logger&.error("Program execution failed: #{error.message}")

    # Programs don't retry by default - let individual modules handle retries
    raise ProgramError, "Program execution failed: #{error.message}"
  end
end
|
92
|
+
|
93
|
+
# Result object for program outputs; adds readers for the trace and timing
# entries stored in the result metadata.
class ProgramResult < ModuleResult
  # @return the :trace metadata entry, or an empty array when it is missing
  def trace
    recorded = metadata[:trace]
    recorded || []
  end

  # @return the :execution_time metadata entry (nil when it is missing)
  def execution_time
    metadata[:execution_time]
  end
end
|
103
|
+
|
104
|
+
# Error type for program-related failures, such as a failed program execution.
class ProgramError < Error
end
|
106
|
+
end
|