desiru 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.claude/settings.local.json +11 -0
- data/.env.example +34 -0
- data/.rubocop.yml +7 -4
- data/.ruby-version +1 -0
- data/CHANGELOG.md +73 -0
- data/CLAUDE.local.md +3 -0
- data/CLAUDE.md +10 -1
- data/Gemfile +21 -2
- data/Gemfile.lock +88 -13
- data/README.md +301 -2
- data/Rakefile +1 -0
- data/db/migrations/001_create_initial_tables.rb +96 -0
- data/db/migrations/002_create_job_results.rb +39 -0
- data/desiru-development-swarm.yml +185 -0
- data/desiru.db +0 -0
- data/desiru.gemspec +2 -5
- data/docs/background_processing_roadmap.md +87 -0
- data/docs/job_scheduling.md +167 -0
- data/dspy-analysis-swarm.yml +60 -0
- data/dspy-feature-analysis.md +121 -0
- data/examples/README.md +69 -0
- data/examples/api_with_persistence.rb +122 -0
- data/examples/assertions_example.rb +232 -0
- data/examples/async_processing.rb +2 -0
- data/examples/few_shot_learning.rb +1 -2
- data/examples/graphql_api.rb +4 -2
- data/examples/graphql_integration.rb +3 -3
- data/examples/graphql_optimization_summary.md +143 -0
- data/examples/graphql_performance_benchmark.rb +247 -0
- data/examples/persistence_example.rb +102 -0
- data/examples/react_agent.rb +203 -0
- data/examples/rest_api.rb +173 -0
- data/examples/rest_api_advanced.rb +333 -0
- data/examples/scheduled_job_example.rb +116 -0
- data/examples/simple_qa.rb +1 -2
- data/examples/sinatra_api.rb +109 -0
- data/examples/typed_signatures.rb +1 -2
- data/graphql_optimization_summary.md +53 -0
- data/lib/desiru/api/grape_integration.rb +284 -0
- data/lib/desiru/api/persistence_middleware.rb +148 -0
- data/lib/desiru/api/sinatra_integration.rb +217 -0
- data/lib/desiru/api.rb +42 -0
- data/lib/desiru/assertions.rb +74 -0
- data/lib/desiru/async_status.rb +65 -0
- data/lib/desiru/cache.rb +1 -1
- data/lib/desiru/configuration.rb +2 -1
- data/lib/desiru/core/compiler.rb +231 -0
- data/lib/desiru/core/example.rb +96 -0
- data/lib/desiru/core/prediction.rb +108 -0
- data/lib/desiru/core/trace.rb +330 -0
- data/lib/desiru/core/traceable.rb +61 -0
- data/lib/desiru/core.rb +12 -0
- data/lib/desiru/errors.rb +160 -0
- data/lib/desiru/field.rb +17 -14
- data/lib/desiru/graphql/batch_loader.rb +85 -0
- data/lib/desiru/graphql/data_loader.rb +242 -75
- data/lib/desiru/graphql/enum_builder.rb +75 -0
- data/lib/desiru/graphql/executor.rb +37 -4
- data/lib/desiru/graphql/schema_generator.rb +62 -158
- data/lib/desiru/graphql/type_builder.rb +138 -0
- data/lib/desiru/graphql/type_cache_warmer.rb +91 -0
- data/lib/desiru/jobs/async_predict.rb +1 -1
- data/lib/desiru/jobs/base.rb +67 -0
- data/lib/desiru/jobs/batch_processor.rb +6 -6
- data/lib/desiru/jobs/retriable.rb +119 -0
- data/lib/desiru/jobs/retry_strategies.rb +169 -0
- data/lib/desiru/jobs/scheduler.rb +219 -0
- data/lib/desiru/jobs/webhook_notifier.rb +242 -0
- data/lib/desiru/models/anthropic.rb +164 -0
- data/lib/desiru/models/base.rb +37 -3
- data/lib/desiru/models/open_ai.rb +151 -0
- data/lib/desiru/models/open_router.rb +161 -0
- data/lib/desiru/module.rb +67 -9
- data/lib/desiru/modules/best_of_n.rb +306 -0
- data/lib/desiru/modules/chain_of_thought.rb +3 -3
- data/lib/desiru/modules/majority.rb +51 -0
- data/lib/desiru/modules/multi_chain_comparison.rb +256 -0
- data/lib/desiru/modules/predict.rb +15 -1
- data/lib/desiru/modules/program_of_thought.rb +338 -0
- data/lib/desiru/modules/react.rb +273 -0
- data/lib/desiru/modules/retrieve.rb +4 -2
- data/lib/desiru/optimizers/base.rb +32 -4
- data/lib/desiru/optimizers/bootstrap_few_shot.rb +2 -2
- data/lib/desiru/optimizers/copro.rb +268 -0
- data/lib/desiru/optimizers/knn_few_shot.rb +185 -0
- data/lib/desiru/optimizers/mipro_v2.rb +889 -0
- data/lib/desiru/persistence/database.rb +71 -0
- data/lib/desiru/persistence/models/api_request.rb +38 -0
- data/lib/desiru/persistence/models/job_result.rb +138 -0
- data/lib/desiru/persistence/models/module_execution.rb +37 -0
- data/lib/desiru/persistence/models/optimization_result.rb +28 -0
- data/lib/desiru/persistence/models/training_example.rb +25 -0
- data/lib/desiru/persistence/models.rb +11 -0
- data/lib/desiru/persistence/repositories/api_request_repository.rb +98 -0
- data/lib/desiru/persistence/repositories/base_repository.rb +77 -0
- data/lib/desiru/persistence/repositories/job_result_repository.rb +116 -0
- data/lib/desiru/persistence/repositories/module_execution_repository.rb +85 -0
- data/lib/desiru/persistence/repositories/optimization_result_repository.rb +67 -0
- data/lib/desiru/persistence/repositories/training_example_repository.rb +102 -0
- data/lib/desiru/persistence/repository.rb +29 -0
- data/lib/desiru/persistence/setup.rb +77 -0
- data/lib/desiru/persistence.rb +49 -0
- data/lib/desiru/registry.rb +3 -5
- data/lib/desiru/signature.rb +91 -24
- data/lib/desiru/version.rb +1 -1
- data/lib/desiru.rb +33 -8
- data/missing-features-analysis.md +192 -0
- metadata +75 -45
- data/lib/desiru/models/raix_adapter.rb +0 -210
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative '../module'
|
|
4
|
+
require_relative 'chain_of_thought'
|
|
5
|
+
|
|
6
|
+
module Desiru
  module Modules
    # ReAct (Reasoning and Acting) module for tool-using AI agents.
    #
    # The language model iteratively reasons about the task and uses tools to
    # gather information before a final answer is produced. Each iteration asks
    # the reasoning module for a thought, a tool name and tool arguments; the
    # tool's return value (or an "Error: ..." message) is stored as the step's
    # observation. The loop ends when the model selects the built-in "finish"
    # tool or after +max_iterations+ steps, and the extraction module then
    # derives the signature's outputs from the inputs plus the trajectory.
    class ReAct < Desiru::Module
      # Name of the built-in tool the model selects to end the loop.
      FINISH_TOOL = 'finish'

      attr_reader :max_iterations, :tools, :react_module, :extract_module

      # @param signature task signature, passed on to Desiru::Module
      # @param tools [Array, Hash] tool declarations (see #normalize_tools)
      # @param max_iterations [Integer] maximum number of reason/act steps
      # @param model model handed to the parent class and both inner modules
      def initialize(signature, tools: [], max_iterations: 5, model: nil)
        super(signature, model: model)
        # Descriptions supplied by Hash tool declarations; filled by #normalize_tools.
        @tool_descriptions = {}
        @tools = normalize_tools(tools)
        @max_iterations = max_iterations

        # Chooses the next thought / tool / arguments.
        @react_module = ChainOfThought.new(build_react_signature, model: @model)
        # Turns the finished trajectory into the signature's outputs.
        @extract_module = ChainOfThought.new(build_extract_signature, model: @model)
      end

      # Run the reason/act loop, then extract the final outputs.
      #
      # @param inputs [Hash] values for the signature's input fields
      # @return the result of calling the extraction module
      def forward(inputs)
        trajectory = []

        max_iterations.times do
          react_output = react_module.call(prepare_react_inputs(inputs, trajectory))

          tool_name = normalize_tool_name(react_output[:next_tool_name])
          tool_args = parse_tool_args(react_output[:next_tool_args])

          step = { thought: react_output[:next_thought], tool: tool_name, args: tool_args }
          trajectory << step

          break if tool_name == FINISH_TOOL

          # Tool failures are fed back to the model as observations rather than raised.
          step[:observation] =
            begin
              execute_tool(tool_name, tool_args)
            rescue StandardError => e
              "Error: #{e.message}"
            end
        end

        extract_module.call(prepare_extract_inputs(inputs, trajectory))
      end

      private

      # Convert tool declarations into a Hash of String name => callable.
      #
      # Accepted element shapes:
      # * Hash with :name/"name", :function/"function" and optional
      #   :description/"description" (recorded for the tool list in the prompt)
      # * [name, callable] pair (also what iterating a Hash of name => callable yields)
      # * object responding to both #name and #call (a Method, a class with self.call, ...)
      # * Proc without a name, registered as "tool_<n>"
      # Other elements are ignored. Names are stringified because the model
      # produces String tool names while e.g. Method#name returns a Symbol.
      # The "finish" tool is always added (overriding any user tool of that name).
      def normalize_tools(tools)
        normalized = {}

        (tools || []).each do |tool|
          case tool
          when Hash
            name = (tool[:name] || tool['name']).to_s
            normalized[name] = tool[:function] || tool['function']
            description = tool[:description] || tool['description']
            @tool_descriptions[name] = description if description
          when Array
            name, function = tool
            normalized[name.to_s] = function
          else
            if tool.respond_to?(:name) && tool.respond_to?(:call)
              normalized[tool.name.to_s] = tool
            elsif tool.is_a?(Method) || tool.is_a?(Proc)
              name = tool.respond_to?(:name) ? tool.name.to_s : "tool_#{normalized.size}"
              normalized[name] = tool
            end
          end
        end

        normalized[FINISH_TOOL] = -> { 'Task completed' }
        normalized
      end

      # Signature for one reasoning step: task inputs plus the formatted
      # trajectory in; next thought, tool name and tool arguments out. The agent
      # instructions (including the tool list) are attached to next_thought.
      def build_react_signature
        react_sig = "#{input_list_with_trajectory} -> next_thought, next_tool_name, next_tool_args"

        instructions = <<~INST
          You are an AI agent that can use tools to accomplish tasks.

          Available tools:
          #{format_tool_descriptions}

          Based on the input and trajectory so far, reason about what to do next.
          Then select a tool to use and provide the arguments for that tool.

          When you have gathered enough information to answer the question,
          use the "finish" tool to complete the task.
        INST

        Signature.new(react_sig, descriptions: { 'next_thought' => instructions })
      end

      # Signature that maps task inputs plus the trajectory to the task outputs.
      def build_extract_signature
        output_names = signature.output_fields.keys
        output_list = output_names.join(', ')
        extract_sig = "#{input_list_with_trajectory} -> #{output_list}"

        instructions = <<~INST
          Based on the trajectory of thoughts and tool observations,
          extract the final #{output_list} to answer the original question.
        INST

        # Attach the instructions to the first output field. Keying them by the
        # joined list ("answer, confidence") matched no field once the signature
        # had more than one output.
        descriptions = output_names.empty? ? {} : { output_names.first.to_s => instructions }
        Signature.new(extract_sig, descriptions: descriptions)
      end

      # Task input names followed by "trajectory", comma separated. Built from a
      # list so a signature without inputs yields "trajectory", not ", trajectory".
      def input_list_with_trajectory
        [*signature.input_fields.keys, 'trajectory'].join(', ')
      end

      # One "- name: description" line per registered tool.
      def format_tool_descriptions
        tools.map { |name, function| "- #{name}: #{describe_tool(name, function)}" }.join("\n")
      end

      # Description of a single tool, in order of preference: the fixed text for
      # "finish", a description from a Hash declaration, the callable's own
      # #description, "Tool: <name>" for Procs/Methods (whose #to_s is only an
      # object address), and finally the callable's #to_s.
      def describe_tool(name, function)
        return 'Mark the task as complete when you have enough information' if name == FINISH_TOOL
        return @tool_descriptions[name] if @tool_descriptions.key?(name)
        return function.description if function.respond_to?(:description)
        return "Tool: #{name}" if function.nil? || function.is_a?(Proc) || function.is_a?(Method)

        function.to_s
      end

      # Inputs for the reasoning module: the task inputs plus the formatted trajectory.
      def prepare_react_inputs(inputs, trajectory)
        inputs.merge(trajectory: format_trajectory(trajectory))
      end

      # Inputs for the extraction module: the task inputs plus the formatted trajectory.
      def prepare_extract_inputs(inputs, trajectory)
        inputs.merge(trajectory: format_trajectory(trajectory))
      end

      # Render the trajectory as numbered steps separated by blank lines.
      def format_trajectory(trajectory)
        return 'No actions taken yet.' if trajectory.empty?

        trajectory.map.with_index do |step, i|
          parts = ["Step #{i + 1}:"]
          parts << "Thought: #{step[:thought]}" if step[:thought]
          parts << "Tool: #{step[:tool]}" if step[:tool]
          parts << "Args: #{step[:args]}" if args_present?(step[:args])
          parts << "Observation: #{step[:observation]}" if step[:observation]
          parts.join("\n")
        end.join("\n\n")
      end

      # True for non-nil/false args that are not empty. Parsed args need not be a
      # Hash (JSON such as `42` parses to an Integer), so #empty? is optional.
      def args_present?(args)
        return false unless args

        !(args.respond_to?(:empty?) && args.empty?)
      end

      # Clean up a model-produced tool name: strip whitespace and wrapping quotes
      # or backticks. If the result is not a registered name, fall back to a
      # case-insensitive match so "Finish" still ends the loop. nil stays nil.
      def normalize_tool_name(raw_name)
        return nil if raw_name.nil?

        name = raw_name.to_s.strip.gsub(/\A["'`]+|["'`]+\z/, '')
        return name if tools.key?(name)

        tools.keys.find { |known| known.to_s.casecmp?(name) } || name
      end

      # Parse the model's tool-argument output.
      #
      # A Hash is used as-is (keys symbolized). Anything else is treated as text:
      # blank -> {}, valid JSON -> the parsed value with symbolized keys (which may
      # be a non-Hash such as an Array), otherwise key:value / key=value pairs.
      def parse_tool_args(args_string)
        return args_string.transform_keys { |key| key.to_s.to_sym } if args_string.is_a?(Hash)

        text = args_string.to_s
        return {} if text.strip.empty?

        require 'json' # loaded lazily; only this method needs it
        begin
          JSON.parse(text, symbolize_names: true)
        rescue JSON::ParserError
          parse_simple_args(text)
        end
      end

      # Parse "key: value, key2=value2" text into a Symbol-keyed Hash. Values
      # lose one leading/trailing quote and are coerced by #coerce_simple_value.
      # Later duplicates of a key win.
      def parse_simple_args(args_string)
        args_string.scan(/(\w+)[:=]\s*([^,]+)/).each_with_object({}) do |(key, raw_value), args|
          value = raw_value.strip.gsub(/\A["']|["']\z/, '')
          args[key.to_sym] = coerce_simple_value(value)
        end
      end

      # Convert "true"/"false" (any case), integers and decimals (optionally
      # negative) to Ruby values; everything else stays a String.
      def coerce_simple_value(value)
        case value.downcase
        when 'true' then true
        when 'false' then false
        when /\A-?\d+\z/ then value.to_i
        when /\A-?\d+\.\d+\z/ then value.to_f
        else value
        end
      end

      # Invoke a registered tool with the parsed arguments.
      #
      # Arity/parameters come from the tool itself when it is a Proc or Method and
      # from its #call method otherwise, so callable objects (e.g. a class with
      # self.call) work too. Array args are splatted, nil means no arguments, and
      # any other non-Hash value is passed as the single argument.
      #
      # @raise [RuntimeError] when the tool is unknown or not callable
      def execute_tool(tool_name, args)
        tool = tools[tool_name]
        raise "Unknown tool: #{tool_name}" unless tool
        raise "Tool is not callable: #{tool_name}" unless tool.respond_to?(:call)

        callable = (tool.is_a?(Proc) || tool.is_a?(Method)) ? tool : tool.method(:call)
        return tool.call if callable.arity.zero?

        case args
        when Hash then call_with_hash_args(tool, callable, args)
        when Array then tool.call(*args)
        when nil then tool.call
        else tool.call(args)
        end
      end

      # Pass a Hash of arguments the way the tool's parameter list expects:
      # as keywords when it declares keywords and no required positionals,
      # as one Hash when it takes exactly one argument, otherwise as positional
      # values in key order.
      def call_with_hash_args(tool, callable, args)
        kinds = callable.parameters.map(&:first)
        keyword_only = (kinds & %i[key keyreq keyrest]).any? && !kinds.include?(:req)

        if keyword_only
          tool.call(**args)
        elsif callable.arity == 1
          tool.call(args)
        else
          tool.call(*args.values)
        end
      end

      # Shorten a trajectory so its formatted text fits in +max_length+
      # characters: drop the oldest steps, then, if a single step is still too
      # long, clip its observation and thought.
      #
      # Returns +trajectory+ itself when it already fits, otherwise a new Array;
      # the caller's step hashes are never modified.
      # NOTE: not currently called by #forward.
      def truncate_trajectory(trajectory, max_length: 3000)
        return trajectory if format_trajectory(trajectory).length <= max_length

        truncated = trajectory.dup
        truncated.shift while truncated.length > 1 && format_trajectory(truncated).length > max_length

        if truncated.length == 1 && format_trajectory(truncated).length > max_length
          # Copy the step: Array#dup is shallow, so writing into truncated[0]
          # would otherwise rewrite the caller's step too.
          step = truncated[0].dup
          step[:observation] = clip_text(step[:observation]) if step[:observation]
          step[:thought] = clip_text(step[:thought]) if step[:thought]
          truncated[0] = step
        end

        truncated
      end

      # When the stringified +value+ is longer than +limit+ characters, return its
      # first limit + 1 characters followed by "... (truncated)"; otherwise
      # return +value+ unchanged.
      def clip_text(value, limit = 100)
        text = value.to_s
        text.length > limit ? "#{text[0..limit]}... (truncated)" : value
      end
    end
  end
end
|
|
@@ -21,6 +21,7 @@ module Desiru
|
|
|
21
21
|
def forward(**inputs)
|
|
22
22
|
query = inputs[:query]
|
|
23
23
|
# Handle k parameter - it might come as nil if optional
|
|
24
|
+
# Note: 'k' is the standard parameter name in information retrieval
|
|
24
25
|
k = inputs.fetch(:k, 5)
|
|
25
26
|
k = 5 if k.nil? # Ensure we have a value even if nil was passed
|
|
26
27
|
|
|
@@ -67,7 +68,7 @@ module Desiru
|
|
|
67
68
|
raise NotImplementedError, 'Subclasses must implement #add'
|
|
68
69
|
end
|
|
69
70
|
|
|
70
|
-
def search(_query, k: 5)
|
|
71
|
+
def search(_query, k: 5) # rubocop:disable Naming/MethodParameterName
|
|
71
72
|
raise NotImplementedError, 'Subclasses must implement #search'
|
|
72
73
|
end
|
|
73
74
|
|
|
@@ -83,6 +84,7 @@ module Desiru
|
|
|
83
84
|
# In-memory backend implementation for development and testing
|
|
84
85
|
class InMemoryBackend < Backend
|
|
85
86
|
def initialize(distance_metric: :cosine)
|
|
87
|
+
super()
|
|
86
88
|
@documents = []
|
|
87
89
|
@embeddings = []
|
|
88
90
|
@distance_metric = distance_metric
|
|
@@ -107,7 +109,7 @@ module Desiru
|
|
|
107
109
|
@embeddings.concat(embeddings)
|
|
108
110
|
end
|
|
109
111
|
|
|
110
|
-
def search(query, k: 5)
|
|
112
|
+
def search(query, k: 5) # rubocop:disable Naming/MethodParameterName
|
|
111
113
|
return [] if @documents.empty?
|
|
112
114
|
|
|
113
115
|
# Generate query embedding
|
|
@@ -22,7 +22,21 @@ module Desiru
|
|
|
22
22
|
|
|
23
23
|
def evaluate(program, dataset)
|
|
24
24
|
scores = dataset.map do |example|
|
|
25
|
-
|
|
25
|
+
# Extract inputs (exclude answer/output fields)
|
|
26
|
+
inputs = {}
|
|
27
|
+
if example.respond_to?(:to_h)
|
|
28
|
+
example.to_h.each do |k, v|
|
|
29
|
+
inputs[k] = v unless %i[answer output].include?(k)
|
|
30
|
+
end
|
|
31
|
+
elsif example.is_a?(Hash)
|
|
32
|
+
example.each do |k, v|
|
|
33
|
+
inputs[k] = v unless %i[answer output].include?(k.to_sym)
|
|
34
|
+
end
|
|
35
|
+
else
|
|
36
|
+
inputs = example
|
|
37
|
+
end
|
|
38
|
+
|
|
39
|
+
prediction = program.call(inputs)
|
|
26
40
|
score_prediction(prediction, example)
|
|
27
41
|
end
|
|
28
42
|
|
|
@@ -55,6 +69,10 @@ module Desiru
|
|
|
55
69
|
f1_score(prediction, ground_truth)
|
|
56
70
|
when :accuracy
|
|
57
71
|
accuracy_score(prediction, ground_truth)
|
|
72
|
+
when :confidence
|
|
73
|
+
confidence_score(prediction, ground_truth)
|
|
74
|
+
when :consistency
|
|
75
|
+
consistency_score(prediction, ground_truth)
|
|
58
76
|
else
|
|
59
77
|
raise OptimizerError, "Unknown metric: #{@metric}"
|
|
60
78
|
end
|
|
@@ -86,13 +104,23 @@ module Desiru
|
|
|
86
104
|
exact_match_score(prediction, ground_truth)
|
|
87
105
|
end
|
|
88
106
|
|
|
107
|
+
# Placeholder confidence metric built on exact match: the match score is
# scaled by 0.9 and offset by 0.1 (so, assuming exact_match_score returns
# 1.0 / 0.0, a hit scores 1.0 and a miss 0.1).
# A real implementation would use model-reported confidence instead.
def confidence_score(prediction, ground_truth)
  weight = 0.9
  floor = 0.1
  (exact_match_score(prediction, ground_truth) * weight) + floor
end
|
|
112
|
+
|
|
113
|
+
# Placeholder consistency metric built on exact match: the match score is
# scaled by 0.8 and offset by 0.2 (so, assuming exact_match_score returns
# 1.0 / 0.0, a hit scores 1.0 and a miss 0.2).
# A real implementation would track consistency across examples.
def consistency_score(prediction, ground_truth)
  weight = 0.8
  baseline = 0.2
  (exact_match_score(prediction, ground_truth) * weight) + baseline
end
|
|
118
|
+
|
|
89
119
|
def extract_answer(data)
|
|
90
120
|
case data
|
|
91
|
-
when ModuleResult, ProgramResult
|
|
121
|
+
when ModuleResult, ProgramResult, Hash
|
|
92
122
|
# Try common answer fields
|
|
93
123
|
data[:answer] || data[:output] || data[:result] || data.values.first
|
|
94
|
-
when Hash
|
|
95
|
-
data[:answer] || data[:output] || data[:result] || data.values.first
|
|
96
124
|
else
|
|
97
125
|
data
|
|
98
126
|
end
|
|
@@ -80,7 +80,7 @@ module Desiru
|
|
|
80
80
|
|
|
81
81
|
begin
|
|
82
82
|
# Get module prediction
|
|
83
|
-
inputs = example.
|
|
83
|
+
inputs = example.except(:answer, :output)
|
|
84
84
|
prediction = module_instance.call(inputs)
|
|
85
85
|
|
|
86
86
|
# Score the prediction
|
|
@@ -110,7 +110,7 @@ module Desiru
|
|
|
110
110
|
# Add labeled examples if available
|
|
111
111
|
labeled = examples.select { |ex| ex[:answer] || ex[:output] }
|
|
112
112
|
labeled_demos = labeled.first(config[:max_labeled_demos]).map do |ex|
|
|
113
|
-
inputs = ex.
|
|
113
|
+
inputs = ex.except(:answer, :output)
|
|
114
114
|
{
|
|
115
115
|
input: format_demo_input(inputs),
|
|
116
116
|
output: format_demo_output(ex),
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Desiru
  module Optimizers
    # COPRO (Cooperative Prompt Optimization) optimizer.
    #
    # Generates and refines instructions for each predictor of a program using
    # coordinate ascent. For every predictor in turn it proposes candidate
    # instructions, scores the program with each candidate on the validation
    # set, and keeps the best one when it beats the current score by more than
    # +improvement_threshold+. Optimization stops after +max_iterations+ rounds
    # or after a round in which no predictor improved.
    class COPRO < Base
      # Maximum number of successful completions quoted in a variation prompt,
      # and therefore the number of good examples worth searching for.
      MAX_PROMPT_EXAMPLES = 3

      # Rewrite directions cycled through when requesting instruction variations.
      IMPROVEMENT_TYPES = [
        "Make the instruction more specific and detailed",
        "Add helpful constraints or guidelines",
        "Clarify any ambiguous requirements",
        "Add examples or patterns to follow",
        "Emphasize important aspects of the task"
      ].freeze

      # @param config [Hash] passed to Base; COPRO additionally reads
      #   :max_iterations (default 10), :num_candidates (default 5),
      #   :temperature (default 0.7) and :improvement_threshold (default 0.01)
      def initialize(config = {})
        super
        @max_iterations = config[:max_iterations] || 10
        @num_candidates = config[:num_candidates] || 5
        @temperature = config[:temperature] || 0.7
        @improvement_threshold = config[:improvement_threshold] || 0.01
      end

      # Optimize the instructions of +program+'s predictors.
      #
      # @param program program exposing #predictors (name => predictor) and #call
      # @param trainset [Array<Hash>] examples with :inputs and :outputs Hashes,
      #   used to build candidate instructions
      # @param valset [Array<Hash>, nil] examples used for scoring; defaults to +trainset+
      # @param kwargs [Hash] :metric - callable(prediction, expected_outputs) returning
      #   a Numeric or true/false. Without a callable metric, scoring falls back to
      #   #score_prediction.
      # @return the best program found (a dup of +program+ when nothing improved)
      def compile(program, trainset, valset = nil, **kwargs)
        valset ||= trainset # Use trainset for validation if no valset provided
        metric = kwargs[:metric]

        best_score = evaluate_program(program, valset, metric)
        best_program = program.dup
        Desiru.logger.info("[COPRO] Initial score: #{best_score}")

        @max_iterations.times do |iteration|
          Desiru.logger.info("[COPRO] Starting iteration #{iteration + 1}/#{@max_iterations}")
          improved = false

          program.predictors.each do |name, predictor|
            Desiru.logger.info("[COPRO] Optimizing predictor: #{name}")
            candidates = generate_instruction_candidates(predictor, trainset, name)

            winner, winner_score = best_candidate(candidates, best_program, name, valset, metric, best_score)
            next unless winner && (winner_score - best_score) > @improvement_threshold

            Desiru.logger.info("[COPRO] Improved #{name}: #{best_score} -> #{winner_score}")
            best_program = create_program_with_instruction(best_program, name, winner)
            best_score = winner_score
            improved = true
          end

          # Coordinate ascent has converged once a full round changes nothing.
          break unless improved
        end

        Desiru.logger.info("[COPRO] Final score: #{best_score}")
        best_program
      end

      private

      # Evaluate each candidate instruction for +predictor_name+ applied on top of
      # +base_program+. Returns [instruction, score] for the highest score that
      # strictly exceeds +floor+, or [nil, floor] when no candidate does.
      def best_candidate(candidates, base_program, predictor_name, valset, metric, floor)
        best_instruction = nil
        best_score = floor

        candidates.each do |instruction|
          candidate_program = create_program_with_instruction(base_program, predictor_name, instruction)
          score = evaluate_program(candidate_program, valset, metric)
          next unless score > best_score

          best_instruction = instruction
          best_score = score
        end

        [best_instruction, best_score]
      end

      # Build candidate instructions for one predictor: a base instruction derived
      # from its signature, followed by up to +num_candidates+ - 1 model-generated
      # variations. nil and duplicate candidates are dropped.
      # NOTE(review): #model is expected to come from Base (not visible here).
      def generate_instruction_candidates(predictor, trainset, predictor_name)
        # Only MAX_PROMPT_EXAMPLES are quoted in the prompt; stop searching after that many.
        good_examples = select_good_examples(predictor, trainset, limit: MAX_PROMPT_EXAMPLES)

        signature = predictor.signature
        base_instruction = generate_base_instruction(signature, predictor_name)
        candidates = [base_instruction]

        (@num_candidates - 1).times do |index|
          variation_prompt = build_variation_prompt(base_instruction, signature, good_examples, index)
          response = model.complete(
            messages: [{ role: 'user', content: variation_prompt }],
            temperature: @temperature
          )
          candidates << extract_instruction(response[:content])
        end

        candidates.compact.uniq
      end

      # Instruction listing the signature's input and output fields with their
      # descriptions (or types when no description is set).
      def generate_base_instruction(signature, predictor_name)
        instruction = +"You are solving a #{predictor_name} task.\n\n"

        if signature.input_fields.any?
          instruction << "Given the following inputs:\n"
          append_field_lines(instruction, signature.input_fields)
          instruction << "\n"
        end

        if signature.output_fields.any?
          instruction << "Produce the following outputs:\n"
          append_field_lines(instruction, signature.output_fields)
        end

        instruction
      end

      # Append one "- name: description-or-type" line per field to +buffer+.
      def append_field_lines(buffer, fields)
        fields.each { |name, field| buffer << "- #{name}: #{field.description || field.type}\n" }
      end

      # Prompt asking the model to rewrite +base_instruction+ in the direction
      # selected (cyclically) by +variation_index+.
      def build_variation_prompt(base_instruction, signature, good_examples, variation_index)
        prompt = +"Improve the following instruction for better performance:\n\n"
        prompt << "Current instruction:\n#{base_instruction}\n\n"
        prompt << "Task signature: #{signature}\n\n"

        if good_examples.any?
          prompt << "Examples of successful completions:\n"
          good_examples.take(MAX_PROMPT_EXAMPLES).each { |example| prompt << format_example(example) }
        end

        prompt << "\n#{IMPROVEMENT_TYPES[variation_index % IMPROVEMENT_TYPES.length]}.\n"
        prompt << "Provide only the improved instruction:"
        prompt
      end

      # Training examples for which +predictor+ already produces the expected
      # outputs, in trainset order. Stops once +limit+ examples are found (nil
      # means no limit). Examples whose prediction raises are skipped.
      def select_good_examples(predictor, trainset, limit: nil)
        good_examples = []

        trainset.each do |example|
          break if limit && good_examples.size >= limit

          result = predictor.call(example[:inputs])
          good_examples << example if outputs_match?(result, example[:outputs])
        rescue StandardError
          # Skip failed examples
        end

        good_examples
      end

      # True when every expected key matches +actual+. Strings are compared
      # case-insensitively after trimming, numerics within 0.001, and other
      # values with ==. Non-Hash results are converted with #to_h when they
      # support it, for example a result wrapper object.
      def outputs_match?(actual, expected)
        return false if actual.nil? || !expected.is_a?(Hash)

        actual = actual.to_h if !actual.is_a?(Hash) && actual.respond_to?(:to_h)
        return false unless actual.is_a?(Hash)

        expected.all? do |key, expected_value|
          actual_value = actual[key]

          case expected_value
          when String
            actual_value.to_s.strip.downcase == expected_value.strip.downcase
          when Numeric
            (actual_value.to_f - expected_value.to_f).abs < 0.001
          else
            actual_value == expected_value
          end
        end
      end

      # Render an example's inputs/outputs as "key=value" lists for a prompt.
      def format_example(example)
        formatted = +"\nExample:\n"
        formatted << "Inputs: #{format_pairs(example[:inputs])}\n" if example[:inputs]
        formatted << "Outputs: #{format_pairs(example[:outputs])}\n" if example[:outputs]
        formatted
      end

      # "k1=v1, k2=v2" for a Hash.
      def format_pairs(hash)
        hash.map { |k, v| "#{k}=#{v}" }.join(", ")
      end

      # Clean a model response into an instruction, or nil when nothing usable
      # remains. Strips a leading "Here's the improved instruction:" or
      # "Improved instruction:" preamble. Removes quotes only when they wrap the
      # whole text; the former per-line ^/$ anchors also removed quotes and
      # apostrophes at every line edge.
      def extract_instruction(response)
        return nil if response.nil?

        instruction = response.to_s.strip
        instruction = instruction.sub(/\A(?:Here's |This is )?the improved instruction:?\s*/i, '')
        instruction = instruction.sub(/\AImproved instruction:?\s*/i, '')
        instruction = instruction[1..-2].strip if instruction.match?(/\A(["']).*\1\z/m)

        instruction.empty? ? nil : instruction
      end

      # Copy of +program+ whose predictor +predictor_name+ is replaced by a dup
      # carrying +instruction+ in @instruction. The program's other predictors
      # are left as they are.
      # NOTE(review): assumes predictors are stored in ivars named after them.
      def create_program_with_instruction(program, predictor_name, instruction)
        new_program = program.dup

        predictor = new_program.predictors[predictor_name]
        return new_program unless predictor

        new_predictor = predictor.dup
        new_predictor.instance_variable_set(:@instruction, instruction)
        new_program.instance_variable_set("@#{predictor_name}", new_predictor)

        new_program
      end

      # Mean score of +program+ over +dataset+. Runs through #call, the same
      # entry point select_good_examples uses. An example that raises
      # scores 0.0. An empty dataset scores 0.0.
      def evaluate_program(program, dataset, metric)
        scores = dataset.map do |example|
          prediction = program.call(example[:inputs])
          score_example(prediction, example[:outputs], metric)
        rescue StandardError => e
          Desiru.logger.debug("[COPRO] Evaluation error: #{e.message}")
          0.0
        end

        scores.empty? ? 0.0 : scores.sum / scores.length
      end

      # Score one prediction as a Float. Uses +metric+ when it is callable;
      # otherwise falls back to #score_prediction, presumably Base's configured
      # metric. If that is unavailable, the error is rescued by
      # #evaluate_program and scored 0.0, as before. true/false/nil map to
      # 1.0/0.0/0.0 so boolean metrics can be averaged.
      def score_example(prediction, expected_outputs, metric)
        raw =
          if metric.respond_to?(:call)
            metric.call(prediction, expected_outputs)
          else
            score_prediction(prediction, expected_outputs)
          end

        case raw
        when true then 1.0
        when false, nil then 0.0
        else raw.to_f
        end
      end
    end
  end
end
|