desiru 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.env.example +34 -0
- data/.rubocop.yml +7 -4
- data/.ruby-version +1 -0
- data/CLAUDE.md +4 -0
- data/Gemfile +21 -2
- data/Gemfile.lock +87 -12
- data/README.md +295 -2
- data/Rakefile +1 -0
- data/db/migrations/001_create_initial_tables.rb +96 -0
- data/db/migrations/002_create_job_results.rb +39 -0
- data/desiru.db +0 -0
- data/desiru.gemspec +2 -5
- data/docs/background_processing_roadmap.md +87 -0
- data/docs/job_scheduling.md +167 -0
- data/dspy-analysis-swarm.yml +60 -0
- data/dspy-feature-analysis.md +121 -0
- data/examples/README.md +69 -0
- data/examples/api_with_persistence.rb +122 -0
- data/examples/assertions_example.rb +232 -0
- data/examples/async_processing.rb +2 -0
- data/examples/few_shot_learning.rb +1 -2
- data/examples/graphql_api.rb +4 -2
- data/examples/graphql_integration.rb +3 -3
- data/examples/graphql_optimization_summary.md +143 -0
- data/examples/graphql_performance_benchmark.rb +247 -0
- data/examples/persistence_example.rb +102 -0
- data/examples/react_agent.rb +203 -0
- data/examples/rest_api.rb +173 -0
- data/examples/rest_api_advanced.rb +333 -0
- data/examples/scheduled_job_example.rb +116 -0
- data/examples/simple_qa.rb +1 -2
- data/examples/sinatra_api.rb +109 -0
- data/examples/typed_signatures.rb +1 -2
- data/graphql_optimization_summary.md +53 -0
- data/lib/desiru/api/grape_integration.rb +284 -0
- data/lib/desiru/api/persistence_middleware.rb +148 -0
- data/lib/desiru/api/sinatra_integration.rb +217 -0
- data/lib/desiru/api.rb +42 -0
- data/lib/desiru/assertions.rb +74 -0
- data/lib/desiru/async_status.rb +65 -0
- data/lib/desiru/cache.rb +1 -1
- data/lib/desiru/configuration.rb +2 -1
- data/lib/desiru/errors.rb +160 -0
- data/lib/desiru/field.rb +17 -14
- data/lib/desiru/graphql/batch_loader.rb +85 -0
- data/lib/desiru/graphql/data_loader.rb +242 -75
- data/lib/desiru/graphql/enum_builder.rb +75 -0
- data/lib/desiru/graphql/executor.rb +37 -4
- data/lib/desiru/graphql/schema_generator.rb +62 -158
- data/lib/desiru/graphql/type_builder.rb +138 -0
- data/lib/desiru/graphql/type_cache_warmer.rb +91 -0
- data/lib/desiru/jobs/async_predict.rb +1 -1
- data/lib/desiru/jobs/base.rb +67 -0
- data/lib/desiru/jobs/batch_processor.rb +6 -6
- data/lib/desiru/jobs/retriable.rb +119 -0
- data/lib/desiru/jobs/retry_strategies.rb +169 -0
- data/lib/desiru/jobs/scheduler.rb +219 -0
- data/lib/desiru/jobs/webhook_notifier.rb +242 -0
- data/lib/desiru/models/anthropic.rb +164 -0
- data/lib/desiru/models/base.rb +37 -3
- data/lib/desiru/models/open_ai.rb +151 -0
- data/lib/desiru/models/open_router.rb +161 -0
- data/lib/desiru/module.rb +59 -9
- data/lib/desiru/modules/chain_of_thought.rb +3 -3
- data/lib/desiru/modules/majority.rb +51 -0
- data/lib/desiru/modules/multi_chain_comparison.rb +204 -0
- data/lib/desiru/modules/predict.rb +8 -1
- data/lib/desiru/modules/program_of_thought.rb +139 -0
- data/lib/desiru/modules/react.rb +273 -0
- data/lib/desiru/modules/retrieve.rb +4 -2
- data/lib/desiru/optimizers/base.rb +2 -4
- data/lib/desiru/optimizers/bootstrap_few_shot.rb +2 -2
- data/lib/desiru/optimizers/copro.rb +268 -0
- data/lib/desiru/optimizers/knn_few_shot.rb +185 -0
- data/lib/desiru/persistence/database.rb +71 -0
- data/lib/desiru/persistence/models/api_request.rb +38 -0
- data/lib/desiru/persistence/models/job_result.rb +138 -0
- data/lib/desiru/persistence/models/module_execution.rb +37 -0
- data/lib/desiru/persistence/models/optimization_result.rb +28 -0
- data/lib/desiru/persistence/models/training_example.rb +25 -0
- data/lib/desiru/persistence/models.rb +11 -0
- data/lib/desiru/persistence/repositories/api_request_repository.rb +98 -0
- data/lib/desiru/persistence/repositories/base_repository.rb +77 -0
- data/lib/desiru/persistence/repositories/job_result_repository.rb +116 -0
- data/lib/desiru/persistence/repositories/module_execution_repository.rb +85 -0
- data/lib/desiru/persistence/repositories/optimization_result_repository.rb +67 -0
- data/lib/desiru/persistence/repositories/training_example_repository.rb +102 -0
- data/lib/desiru/persistence/repository.rb +29 -0
- data/lib/desiru/persistence/setup.rb +77 -0
- data/lib/desiru/persistence.rb +49 -0
- data/lib/desiru/registry.rb +3 -5
- data/lib/desiru/signature.rb +91 -24
- data/lib/desiru/version.rb +1 -1
- data/lib/desiru.rb +23 -8
- data/missing-features-analysis.md +192 -0
- metadata +63 -45
- data/lib/desiru/models/raix_adapter.rb +0 -210
@@ -0,0 +1,139 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Desiru
  module Modules
    # ProgramOfThought module that generates executable code to solve problems.
    #
    # Similar to ChainOfThought, but instead of reasoning steps the model is
    # asked for a program that defines a `solve` method; the Hash returned by
    # `solve` supplies the signature's output fields.
    #
    # WARNING: generated code runs in-process via `instance_eval`.
    # #safe_to_execute? is a coarse, pattern-based screen, not a sandbox.
    class ProgramOfThought < Desiru::Module
      # Source patterns that cause generated code to be rejected before it is
      # executed. Matching is deliberately broad: a false positive only costs
      # an error output, while a miss can run shell commands, delete files,
      # open sockets, or terminate the host process.
      DANGEROUS_PATTERNS = [
        /system\b/,                                    # Kernel#system, with or without parentheses
        /exec\b/,                                      # Kernel#exec (and *_exec variants)
        /eval\b/,                                      # eval, instance_eval, class_eval, ...
        /\bspawn\b/,
        /\bfork\b/,
        /%x[^\w\s]/,                                   # %x{...} shell literal, any delimiter
        /`/,                                           # backtick shell literal
        /\bIO\s*\.\s*popen\b/,
        /\bOpen3\b/,
        /\bopen\s*\(?\s*["']\|/,                       # Kernel#open("|cmd") starts a subprocess
        /File\s*\.\s*delete/,
        /FileUtils\s*\.\s*rm/,
        /Dir\s*\.\s*delete/,
        /require\s*\(?\s*["'](?:net|open-uri|socket)/,
        /Socket/,
        /Process\s*\.\s*kill/,
        /\b(?:at_exit|exit|abort)\b/                   # SystemExit is not rescued; would stop the host
      ].freeze

      # @param signature signature describing the inputs and expected outputs
      # @param model model used for code generation (handled by Desiru::Module)
      # @param kwargs options read here: :max_iterations (default 1; stored but
      #   not read by this class) and :code_language (default 'ruby'). The
      #   zero-argument `super` also forwards every keyword to Desiru::Module.
      def initialize(signature = nil, model: nil, **kwargs)
        super
        @max_iterations = kwargs[:max_iterations] || 1
        @code_language = kwargs[:code_language] || 'ruby'
      end

      # Asks the model for a program, screens and runs it, and maps the
      # result onto the signature's output fields.
      #
      # @return [Hash] one entry per output field; an :error entry is added
      #   when screening or execution failed
      def forward(**inputs)
        code_prompt = build_code_prompt(inputs)

        response = model.complete(
          messages: [{ role: 'user', content: code_prompt }],
          temperature: 0.3 # Lower temperature for more deterministic code
        )

        generated_code = extract_code(response[:content])

        result = if safe_to_execute?(generated_code)
                   execute_code(generated_code, inputs)
                 else
                   { error: 'Generated code deemed unsafe to execute', code: generated_code }
                 end

        format_outputs(result, generated_code)
      end

      private

      # Builds the code-generation prompt: the concrete input values, the
      # expected output fields, and the `solve` calling convention.
      def build_code_prompt(inputs)
        prompt = +"You are a programming assistant. Generate #{@code_language} code to solve this problem.\n\n"

        prompt << "Given inputs:\n"
        inputs.each { |key, value| prompt << "#{key}: #{value}\n" }

        prompt << "\nExpected outputs:\n"
        signature.output_fields.each do |name, field|
          prompt << "- #{name} (#{field.type}): #{field.description || 'No description'}\n"
        end

        prompt << "\nGenerate executable #{@code_language} code that processes the inputs "
        prompt << 'and returns the expected outputs. '
        prompt << "Wrap your code in triple backticks with the language identifier.\n"
        prompt << "The code should define a method called 'solve' that takes the inputs "
        prompt << 'as keyword arguments and returns a hash with the output values.'
        prompt
      end

      # Pulls the program out of the model response.
      #
      # Order of preference: a fenced block tagged with the configured
      # language (case-insensitive), any fenced block (tagged or not), and
      # finally the whole response.
      #
      # @param response [String, nil] raw model text; nil is treated as ""
      # @return [String] stripped code
      def extract_code(response)
        text = response.to_s
        language = Regexp.escape(@code_language.to_s)

        match = text.match(/```[ \t]*#{language}[ \t]*\r?\n(.*?)```/mi)
        return match[1].strip if match

        match = text.match(/```[^\n`]*\r?\n(.*?)```/m)
        return match[1].strip if match

        text.strip
      end

      # @return [Boolean] true when +code+ matches none of DANGEROUS_PATTERNS.
      #   This is a best-effort filter; it does not make execution safe.
      def safe_to_execute?(code)
        DANGEROUS_PATTERNS.none? { |pattern| pattern.match?(code) }
      end

      # Evaluates +code+ on a fresh Object (so `def solve` becomes a singleton
      # method there) and calls `solve` with the inputs as keyword arguments.
      # The fresh Object only scopes the definitions; it does not isolate them.
      #
      # @return [Hash] solve's result with keys converted to Symbols, or
      #   { error: String } on any failure
      def execute_code(code, inputs)
        context = Object.new
        context.instance_eval(code)

        return { error: "Generated code does not define a 'solve' method" } unless context.respond_to?(:solve)

        result = context.solve(**inputs.transform_keys(&:to_sym))
        unless result.is_a?(Hash)
          return { error: "Generated 'solve' method must return a Hash, got #{result.class}" }
        end

        # Output fields are Symbol-named; accept "answer" as well as :answer.
        result.transform_keys { |key| key.respond_to?(:to_sym) ? key.to_sym : key }
      rescue StandardError, ScriptError => e
        # ScriptError covers SyntaxError from malformed generated code, which
        # StandardError alone would let escape to the caller.
        { error: "Code execution failed: #{e.message}" }
      end

      # Maps the execution result onto the signature's output fields.
      #
      # :code receives the generated program when the signature declares it.
      # On error, :error is copied from the result and every other field gets
      # its default. Otherwise each field takes the result's value, falling
      # back to the field default only when that value is nil (false is kept).
      def format_outputs(result, generated_code)
        outputs = {}
        fields = signature.output_fields

        outputs[:code] = generated_code if fields.key?(:code)

        if result[:error]
          outputs[:error] = result[:error]
          fields.each do |name, field|
            next if %i[code error].include?(name)

            outputs[name] = field.default
          end
        else
          fields.each do |name, field|
            next if name == :code

            value = result[name]
            outputs[name] = value.nil? ? field.default : value
          end
        end

        outputs
      end
    end
  end
end
|
@@ -0,0 +1,273 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative '../module'
|
4
|
+
require_relative 'chain_of_thought'
|
5
|
+
|
6
|
+
module Desiru
  module Modules
    # ReAct (Reasoning and Acting) module for tool-using AI agents.
    #
    # On each iteration a ChainOfThought module sees the task inputs plus the
    # trajectory so far and proposes a thought, a tool name, and tool
    # arguments. The tool's return value (or error message) is appended to the
    # trajectory as an observation. The loop ends when the model selects the
    # built-in "finish" tool or after +max_iterations+ steps; a second
    # ChainOfThought module then extracts the signature's outputs from the
    # trajectory.
    class ReAct < Desiru::Module
      # Name of the built-in tool the model selects to end the loop.
      FINISH_TOOL = 'finish'

      # Parameter kinds (Method#parameters / Proc#parameters) indicating that a
      # tool accepts keyword arguments.
      KEYWORD_PARAM_TYPES = %i[key keyreq keyrest].freeze

      attr_reader :max_iterations, :tools, :react_module, :extract_module

      # @param signature signature whose input and output fields the agent serves
      # @param tools [Array, Hash] tool declarations: {name:, function:} hashes,
      #   [name, callable] pairs (so a Hash of name => callable also works), or
      #   objects responding to #call (named by #name when available, otherwise
      #   "tool_<n>"). Tool names are normalized to Strings.
      # @param max_iterations [Integer] maximum number of reason/act steps
      # @param model model handed to the superclass and both ChainOfThought modules
      def initialize(signature, tools: [], max_iterations: 5, model: nil)
        super(signature, model: model)
        @tools = normalize_tools(tools)
        @max_iterations = max_iterations

        # Reasoning and tool selection (needs @tools for the prompt).
        @react_module = ChainOfThought.new(build_react_signature, model: @model)
        # Final output extraction from the trajectory.
        @extract_module = ChainOfThought.new(build_extract_signature, model: @model)
      end

      # Runs the reason/act loop and extracts the final outputs.
      #
      # Inputs may be given as a positional Hash (original convention) or as
      # keyword arguments (the `forward(**inputs)` style used by other modules).
      #
      # @return whatever the extraction module's #call returns
      def forward(inputs = {}, **kwargs)
        inputs = inputs.merge(kwargs)
        trajectory = []

        max_iterations.times do
          react_output = react_module.call(prepare_react_inputs(inputs, trajectory))

          tool_name = react_output[:next_tool_name].to_s.strip
          tool_args = parse_tool_args(react_output[:next_tool_args])

          trajectory << { thought: react_output[:next_thought], tool: tool_name, args: tool_args }

          break if finish_tool?(tool_name)

          # Tool failures become observations so the model can recover.
          trajectory.last[:observation] =
            begin
              execute_tool(tool_name, tool_args)
            rescue StandardError => e
              "Error: #{e.message}"
            end
        end

        extract_module.call(prepare_extract_inputs(inputs, trajectory))
      end

      private

      # Converts the supported tool declarations into a Hash of
      # String name => callable and always registers the finish tool
      # (replacing any user tool with that name). Declarations matching none
      # of the supported shapes are ignored.
      def normalize_tools(tools)
        normalized = {}

        Array(tools).each do |tool|
          case tool
          when Hash
            name = tool[:name] || tool['name']
            normalized[name.to_s] = tool[:function] || tool['function']
          when Array
            name, function = tool
            normalized[name.to_s] = function
          else
            if tool.respond_to?(:name) && tool.respond_to?(:call)
              # Method objects land here; Method#name is a Symbol, hence #to_s
              # so lookups by the model's String tool name succeed.
              normalized[tool.name.to_s] = tool
            elsif tool.respond_to?(:call)
              # Anonymous callables (Procs, lambdas) get a positional name.
              normalized["tool_#{normalized.size}"] = tool
            end
          end
        end

        normalized[FINISH_TOOL] = -> { 'Task completed' }
        normalized
      end

      # Signature for the reasoning/tool-selection step:
      # "<inputs>, trajectory -> next_thought, next_tool_name, next_tool_args".
      def build_react_signature
        input_list = [*signature.input_fields.keys, 'trajectory'].join(', ')
        react_sig = "#{input_list} -> next_thought, next_tool_name, next_tool_args"

        instructions = <<~INST
          You are an AI agent that can use tools to accomplish tasks.

          Available tools:
          #{format_tool_descriptions}

          Based on the input and trajectory so far, reason about what to do next.
          Then select a tool to use and provide the arguments for that tool.

          When you have gathered enough information to answer the question,
          use the "finish" tool to complete the task.
        INST

        Signature.new(react_sig, descriptions: { 'next_thought' => instructions })
      end

      # Signature for the extraction step: "<inputs>, trajectory -> <outputs>".
      # The instructions are attached to every output field by name; the
      # comma-joined field list is not itself a field name.
      def build_extract_signature
        input_list = [*signature.input_fields.keys, 'trajectory'].join(', ')
        output_names = signature.output_fields.keys.map(&:to_s)
        output_list = output_names.join(', ')

        instructions = <<~INST
          Based on the trajectory of thoughts and tool observations,
          extract the final #{output_list} to answer the original question.
        INST

        descriptions = output_names.to_h { |name| [name, instructions] }
        Signature.new("#{input_list} -> #{output_list}", descriptions: descriptions)
      end

      # One "- name: description" line per registered tool.
      def format_tool_descriptions
        tools.map do |name, function|
          if name == FINISH_TOOL
            "- #{FINISH_TOOL}: Mark the task as complete when you have enough information"
          else
            "- #{name}: #{describe_tool(name, function)}"
          end
        end.join("\n")
      end

      # Prompt description of a tool: its own #description when it has one,
      # otherwise its name plus parameter names. Avoids Proc#to_s, which would
      # put object addresses and source paths into the prompt.
      def describe_tool(name, function)
        return function.description if function.respond_to?(:description)

        callable = callable_for(function)
        param_names = callable ? callable.parameters.map { |_type, param| param }.compact : []
        return "Tool: #{name}" if param_names.empty?

        "Tool: #{name} (arguments: #{param_names.join(', ')})"
      end

      def prepare_react_inputs(inputs, trajectory)
        inputs.merge(
          trajectory: format_trajectory(trajectory)
        )
      end

      def prepare_extract_inputs(inputs, trajectory)
        inputs.merge(
          trajectory: format_trajectory(trajectory)
        )
      end

      # Renders the trajectory as numbered steps for the prompt.
      def format_trajectory(trajectory)
        return 'No actions taken yet.' if trajectory.empty?

        trajectory.each_with_index.map do |step, index|
          parts = ["Step #{index + 1}:"]
          parts << "Thought: #{step[:thought]}" if step[:thought]
          parts << "Tool: #{step[:tool]}" if step[:tool]
          parts << "Args: #{step[:args]}" if args_present?(step[:args])
          parts << "Observation: #{step[:observation]}" unless step[:observation].nil?
          parts.join("\n")
        end.join("\n\n")
      end

      # True for truthy args that are not an empty collection/string.
      # Scalars (e.g. from JSON "42") have no #empty? and count as present.
      def args_present?(args)
        return false unless args

        !(args.respond_to?(:empty?) && args.empty?)
      end

      # Parses the model's tool arguments.
      #
      # @return [Hash, Array, Object] a Symbol-keyed Hash for structured input;
      #   whatever JSON.parse yields for valid non-object JSON; {} for blank input
      def parse_tool_args(raw_args)
        case raw_args
        when nil then {}
        when Hash then raw_args.transform_keys { |key| key.to_s.to_sym }
        else
          text = raw_args.to_s.strip
          return {} if text.empty?

          require 'json'
          begin
            JSON.parse(text, symbolize_names: true)
          rescue JSON::ParserError
            parse_simple_args(text)
          end
        end
      end

      # Fallback parser for "key: value, other=value" style arguments.
      def parse_simple_args(args_string)
        args_string.scan(/(\w+)\s*[:=]\s*([^,]+)/).to_h do |key, raw_value|
          [key.to_sym, coerce_scalar(raw_value.strip.gsub(/\A["']|["']\z/, ''))]
        end
      end

      # Converts "true"/"false" and (optionally negative) integer/decimal
      # strings to Ruby values; anything else is returned unchanged.
      def coerce_scalar(value)
        case value
        when /\Atrue\z/i then true
        when /\Afalse\z/i then false
        when /\A-?\d+\z/ then value.to_i
        when /\A-?\d+\.\d+\z/ then value.to_f
        else value
        end
      end

      def finish_tool?(tool_name)
        tool_name.casecmp?(FINISH_TOOL)
      end

      # Exact name match first, then a case-insensitive match.
      def find_tool(tool_name)
        tools[tool_name] || tools.find { |name, _| name.casecmp?(tool_name) }&.last
      end

      # Object to introspect for arity/parameters: the tool itself for
      # Procs/Methods, or its bound #call method for other callables.
      def callable_for(tool)
        return tool if tool.respond_to?(:arity) && tool.respond_to?(:parameters)

        tool.method(:call) if tool.respond_to?(:call)
      end

      # Looks up +tool_name+ and invokes it, adapting +args+ to its parameters.
      #
      # @raise [RuntimeError] when the tool is unknown or not callable
      def execute_tool(tool_name, args)
        tool = find_tool(tool_name)
        raise "Unknown tool: #{tool_name}" unless tool

        callable = callable_for(tool)
        raise "Tool is not callable: #{tool_name}" unless callable

        return tool.call if callable.arity.zero?

        case args
        when Hash then call_with_hash(tool, callable, args)
        when Array then tool.call(*args)
        else tool.call(args)
        end
      end

      # Invokes +tool+ with a Hash of arguments:
      # - tools accepting keywords get arguments named after their leading
      #   positional parameters positionally and the rest as keywords;
      # - single-parameter tools get the whole Hash;
      # - other tools get the Hash values positionally, in insertion order.
      def call_with_hash(tool, callable, args)
        params = callable.parameters

        if params.any? { |type, _| KEYWORD_PARAM_TYPES.include?(type) }
          positional_names = params.select { |type, _| %i[req opt].include?(type) }.map { |_type, name| name }
          leading = positional_names.take_while { |name| args.key?(name) }
          keywords = args.reject { |key, _| leading.include?(key) }
          tool.call(*leading.map { |name| args[name] }, **keywords)
        elsif callable.arity == 1
          tool.call(args)
        else
          tool.call(*args.values)
        end
      end

      # Drops the oldest steps until the formatted trajectory fits in
      # +max_length+ characters, always keeping at least one step. If the one
      # remaining step is still too long, its observation and thought are cut
      # to 100 characters. Works on copies; the caller's trajectory and step
      # hashes are not modified.
      # NOTE(review): not called anywhere in this class.
      def truncate_trajectory(trajectory, max_length: 3000)
        return trajectory if format_trajectory(trajectory).length <= max_length

        truncated = trajectory.dup
        truncated.shift while truncated.length > 1 && format_trajectory(truncated).length > max_length

        if truncated.length == 1 && format_trajectory(truncated).length > max_length
          truncated[0] = shorten_step(truncated[0])
        end

        truncated
      end

      # Copy of +step+ with long :observation / :thought values cut to +limit+ chars.
      def shorten_step(step, limit: 100)
        shortened = step.dup
        %i[observation thought].each do |key|
          next if shortened[key].nil?

          text = shortened[key].to_s
          shortened[key] = "#{text[0, limit]}... (truncated)" if text.length > limit
        end
        shortened
      end
    end
  end
end
|
@@ -21,6 +21,7 @@ module Desiru
|
|
21
21
|
def forward(**inputs)
|
22
22
|
query = inputs[:query]
|
23
23
|
# Handle k parameter - it might come as nil if optional
|
24
|
+
# Note: 'k' is the standard parameter name in information retrieval
|
24
25
|
k = inputs.fetch(:k, 5)
|
25
26
|
k = 5 if k.nil? # Ensure we have a value even if nil was passed
|
26
27
|
|
@@ -67,7 +68,7 @@ module Desiru
|
|
67
68
|
raise NotImplementedError, 'Subclasses must implement #add'
|
68
69
|
end
|
69
70
|
|
70
|
-
def search(_query, k: 5)
|
71
|
+
def search(_query, k: 5) # rubocop:disable Naming/MethodParameterName
|
71
72
|
raise NotImplementedError, 'Subclasses must implement #search'
|
72
73
|
end
|
73
74
|
|
@@ -83,6 +84,7 @@ module Desiru
|
|
83
84
|
# In-memory backend implementation for development and testing
|
84
85
|
class InMemoryBackend < Backend
|
85
86
|
def initialize(distance_metric: :cosine)
|
87
|
+
super()
|
86
88
|
@documents = []
|
87
89
|
@embeddings = []
|
88
90
|
@distance_metric = distance_metric
|
@@ -107,7 +109,7 @@ module Desiru
|
|
107
109
|
@embeddings.concat(embeddings)
|
108
110
|
end
|
109
111
|
|
110
|
-
def search(query, k: 5)
|
112
|
+
def search(query, k: 5) # rubocop:disable Naming/MethodParameterName
|
111
113
|
return [] if @documents.empty?
|
112
114
|
|
113
115
|
# Generate query embedding
|
@@ -22,7 +22,7 @@ module Desiru
|
|
22
22
|
|
23
23
|
def evaluate(program, dataset)
|
24
24
|
scores = dataset.map do |example|
|
25
|
-
prediction = program.call(example.
|
25
|
+
prediction = program.call(example.except(:answer, :output))
|
26
26
|
score_prediction(prediction, example)
|
27
27
|
end
|
28
28
|
|
@@ -88,11 +88,9 @@ module Desiru
|
|
88
88
|
|
89
89
|
def extract_answer(data)
|
90
90
|
case data
|
91
|
-
when ModuleResult, ProgramResult
|
91
|
+
when ModuleResult, ProgramResult, Hash
|
92
92
|
# Try common answer fields
|
93
93
|
data[:answer] || data[:output] || data[:result] || data.values.first
|
94
|
-
when Hash
|
95
|
-
data[:answer] || data[:output] || data[:result] || data.values.first
|
96
94
|
else
|
97
95
|
data
|
98
96
|
end
|
@@ -80,7 +80,7 @@ module Desiru
|
|
80
80
|
|
81
81
|
begin
|
82
82
|
# Get module prediction
|
83
|
-
inputs = example.
|
83
|
+
inputs = example.except(:answer, :output)
|
84
84
|
prediction = module_instance.call(inputs)
|
85
85
|
|
86
86
|
# Score the prediction
|
@@ -110,7 +110,7 @@ module Desiru
|
|
110
110
|
# Add labeled examples if available
|
111
111
|
labeled = examples.select { |ex| ex[:answer] || ex[:output] }
|
112
112
|
labeled_demos = labeled.first(config[:max_labeled_demos]).map do |ex|
|
113
|
-
inputs = ex.
|
113
|
+
inputs = ex.except(:answer, :output)
|
114
114
|
{
|
115
115
|
input: format_demo_input(inputs),
|
116
116
|
output: format_demo_output(ex),
|