dspy 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +374 -3
- data/lib/dspy/chain_of_thought.rb +22 -0
- data/lib/dspy/ext/dry_schema.rb +94 -0
- data/lib/dspy/field.rb +23 -0
- data/lib/dspy/lm.rb +76 -0
- data/lib/dspy/module.rb +13 -0
- data/lib/dspy/predict.rb +72 -0
- data/lib/dspy/re_act.rb +253 -0
- data/lib/dspy/signature.rb +26 -0
- data/lib/dspy/sorbet_chain_of_thought.rb +91 -0
- data/lib/dspy/sorbet_module.rb +47 -0
- data/lib/dspy/sorbet_predict.rb +180 -0
- data/lib/dspy/sorbet_re_act.rb +332 -0
- data/lib/dspy/sorbet_signature.rb +218 -0
- data/lib/dspy/tools/sorbet_tool.rb +226 -0
- data/lib/dspy/tools.rb +21 -0
- data/lib/dspy/types.rb +3 -0
- data/lib/dspy.rb +29 -2
- metadata +117 -3
data/lib/dspy/re_act.rb
ADDED
@@ -0,0 +1,253 @@
|
|
1
|
+
module DSPy
|
2
|
+
# Signature for a single ReAct reasoning step. From the question, the history
# of earlier steps and the textual tool catalogue, the LM produces a thought,
# the name of the action to take, and that action's input.
class Thought < DSPy::Signature
  description("Generate a thought about what to do next to answer the question.")

  input do
    required(:question).value(:string).meta(description: 'The question to answer')
    required(:history).value(:array).meta(description: 'Previous thoughts and actions, including observations from tools. The agent MUST use information from the history to inform its actions and final answer. Each entry is a hash representing a step in the reasoning process.')
    required(:available_tools).value(:string).meta(description: 'List of available tools and their descriptions. The agent MUST choose an action from this list or use "finish".')
  end

  output do
    required(:thought).value(:string).meta(description: 'Reasoning about what to do next, considering the history and observations.')
    required(:action).value(:string).meta(description: 'The action to take. MUST be one of the tool names listed in `available_tools` input, or the literal string "finish" to provide the final answer.')
    required(:action_input).value(:string).meta(description: 'Input for the chosen action. If action is "finish", this field MUST contain the final answer to the original question. This answer MUST be directly taken from the relevant Observation in the history if available. For example, if an observation showed "Observation: 100.0", and you are finishing, this field MUST be "100.0". Do not leave empty if finishing with an observed answer.')
  end
end
|
18
|
+
|
19
|
+
# Signature for interpreting a tool result. From the original question, the
# step history and the latest observation, the LM produces an interpretation
# and a next_step (described to the LM as "continue" or "finish").
class ReActObservation < DSPy::Signature
  description("Process the observation from a tool and decide what to do next.")

  input do
    required(:question).value(:string).meta(description: 'The original question')
    required(:history).value(:array).meta(description: 'Previous thoughts, actions, and observations. Each entry is a hash representing a step in the reasoning process.')
    required(:observation).value(:string).meta(description: 'The result from the last action')
  end

  output do
    required(:interpretation).value(:string).meta(description: 'Interpretation of the observation')
    required(:next_step).value(:string).meta(description: 'What to do next: "continue" or "finish"')
  end
end
|
34
|
+
|
35
|
+
# ReAct Agent Module
#
# Runs a Reason + Act loop. Each iteration asks a ChainOfThought over the
# Thought signature for a thought, an action and an action input:
# - the action "finish" ends the loop, with action_input as the answer;
# - any other action is looked up in the tool table and executed, and the
#   result is recorded as that step's observation. A Predict over
#   ReActObservation then decides the next step. When it answers "finish",
#   one more thought is generated to produce the final answer.
# The loop runs at most max_iterations times.
class ReAct < DSPy::Module
  attr_reader :signature_class, :internal_output_schema, :tools, :max_iterations

  # One step of the ReAct loop. Struct#to_h yields
  # { step:, thought:, action:, action_input:, observation: } in member order,
  # which is the hash form sent to the LM and validated in the output.
  HistoryEntry = Struct.new(:step, :thought, :action, :action_input, :observation, keyword_init: true)

  # @param signature_class [Class] the user's DSPy::Signature subclass. Its first
  #   input key is treated as the question and its first output key receives
  #   the final answer.
  # @param tools [Array] objects responding to #name, #description and #call
  # @param max_iterations [Integer] maximum number of thought/action steps
  def initialize(signature_class, tools: [], max_iterations: 5)
    super()
    @signature_class = signature_class
    @thought_generator = DSPy::ChainOfThought.new(Thought)
    @observation_processor = DSPy::Predict.new(ReActObservation)
    # Tool names are stored lowercased so lookups are case-insensitive.
    @tools = tools.to_h { |tool| [tool.name.downcase, tool] }
    @max_iterations = max_iterations

    # Schema for the fields ReAct adds to every result.
    react_added_output_schema = Dry::Schema.JSON do
      optional(:history).array(:hash) do
        required(:step).value(:integer)
        optional(:thought).value(:string)
        optional(:action).value(:string)
        optional(:action_input).maybe(:string)
        optional(:observation).maybe(:string)
      end
      optional(:iterations).value(:integer).meta(description: 'Number of iterations taken by the ReAct agent.')
    end

    # The user's output schema augmented with the ReAct fields.
    @internal_output_schema = Dry::Schema.JSON(parent: [signature_class.output_schema, react_added_output_schema])
  end

  # Answers the question in +input_values+ by alternating thoughts and tool calls.
  #
  # @param input_values [Hash] keyword inputs validated against signature_class.input_schema
  # @return [Data] instance of an ad-hoc Data class whose (sorted) attributes are
  #   the validated output keys: the user's output fields plus :history and :iterations
  # @raise [DSPy::PredictionInvalidError] if the inputs or the assembled output
  #   fail schema validation
  def forward(**input_values)
    input_validation_result = @signature_class.input_schema.call(input_values)
    raise DSPy::PredictionInvalidError.new(input_validation_result.errors) unless input_validation_result.success?

    # Convention: the first declared input field is the question driving the loop.
    question_field_name = @signature_class.input_schema.key_map.first.name.to_sym
    question = input_values[question_field_name]

    history = []
    available_tools_desc = @tools.map { |name, tool| "- #{name}: #{tool.description}" }.join("\n")
    final_answer = nil
    iterations_count = 0

    @max_iterations.times do |i|
      iterations_count = i + 1

      thought_result = @thought_generator.call(
        question: question,
        history: history.map(&:to_h),
        available_tools: available_tools_desc
      )
      action = thought_result.action
      action_input = thought_result.action_input

      if finish_action?(action)
        action_input = resolve_finish_input(
          action_input, history,
          label: "Finish action",
          missing_status: "Finish action had empty input, no prior Observation found in history."
        )
        final_answer = action_input
        DSPy.logger.info(module: "ReAct", status: "Finishing loop after thought", action: action, final_answer: final_answer, question: question)
        history << HistoryEntry.new(step: iterations_count, thought: thought_result.thought, action: action, action_input: action_input)
        break
      end

      observation_text = execute_action(action, action_input)
      history << HistoryEntry.new(
        step: iterations_count,
        thought: thought_result.thought,
        action: action,
        action_input: action_input,
        observation: observation_text
      )

      obs_result = @observation_processor.call(
        question: question,
        history: history.map(&:to_h),
        observation: observation_text
      )
      next unless finish_action?(obs_result.next_step)

      DSPy.logger.info(module: "ReAct", status: "Observation processor suggests finish. Generating final thought.", question: question, history_before_final_thought: history.map(&:to_h))
      final_answer = generate_final_answer(question, history, available_tools_desc, step: iterations_count + 1)
      iterations_count += 1 # the final thought counts as an extra iteration
      break
    end

    unanswered_message = "Unable to find answer within #{@max_iterations} iterations"
    final_answer ||= unanswered_message
    if final_answer.to_s.empty? || final_answer == unanswered_message
      DSPy.logger.info(module: "ReAct", status: "Final answer determined", final_answer: final_answer, question: question)
    end

    build_output(final_answer, history, iterations_count)
  end

  private

  # Generates one more thought after the observation processor signalled
  # "finish". Appends it to +history+ as step +step+ and returns the final answer.
  def generate_final_answer(question, history, available_tools_desc, step:)
    result = @thought_generator.call(
      question: question,
      history: history.map(&:to_h), # includes the latest observation
      available_tools: available_tools_desc
    )
    DSPy.logger.info(module: "ReAct", status: "Finishing after observation and final thought", final_action: result.action, final_action_input: result.action_input, question: question)

    action_input = result.action_input
    answer = action_input

    if finish_action?(result.action)
      action_input = resolve_finish_input(
        action_input, history,
        label: "Final thought 'finish' action",
        missing_status: "Final thought 'finish' action had empty input, last observation also empty/not found cleanly."
      )
      answer = action_input
    else
      # A non-finish action's input is a tool argument, not an answer. The
      # observation processor judged the latest observation sufficient, so
      # answer with that observation when one exists.
      fallback = last_observation(history)
      if fallback
        DSPy.logger.warn(module: "ReAct", status: "Final thought chose a non-finish action; using last observation as the final answer.", action: result.action, original_input: action_input, derived_input: fallback)
        answer = fallback
      end
    end

    history << HistoryEntry.new(step: step, thought: result.thought, action: result.action, action_input: action_input)
    answer
  end

  # Returns +action_input+ unless it is blank. If it is blank, returns the
  # latest non-blank observation in +history+ (logging the override). If no
  # such observation exists, returns +action_input+ unchanged and logs
  # +missing_status+ as a warning.
  def resolve_finish_input(action_input, history, label:, missing_status:)
    return action_input unless blank_text?(action_input)

    fallback = last_observation(history)
    if fallback
      DSPy.logger.info(
        module: "ReAct",
        status: "#{label} had empty input. Overriding with last observation.",
        original_input: action_input,
        derived_input: fallback
      )
      fallback
    else
      DSPy.logger.warn(module: "ReAct", status: missing_status, original_input: action_input)
      action_input
    end
  end

  # Stripped text of the most recent non-blank observation, or nil.
  def last_observation(history)
    entry = history.reverse_each.find { |e| !blank_text?(e.observation) }
    return nil unless entry

    entry.observation.to_s.strip
  end

  # Validates the answer, history and iteration count against
  # internal_output_schema and wraps the validated hash in a Data instance.
  def build_output(final_answer, history, iterations_count)
    # Convention: the first declared output field receives the final answer.
    user_primary_output_field = @signature_class.output_schema.key_map.first.name.to_sym
    output_data = {}
    output_data[user_primary_output_field] = final_answer
    output_data[:history] = history.map(&:to_h)
    output_data[:iterations] = iterations_count

    output_validation_result = @internal_output_schema.call(output_data)
    unless output_validation_result.success?
      DSPy.logger.error(module: "ReAct", status: "Internal output validation failed", errors: output_validation_result.errors.to_h, data: output_data)
      raise DSPy::PredictionInvalidError.new(output_validation_result.errors)
    end

    validated = output_validation_result.to_h
    # Sorting keys gives the Data class a consistent attribute order.
    Data.define(*validated.keys.sort).new(**validated)
  end

  # Runs the tool named by +action+ (case-insensitive) with +action_input+.
  # Always returns a String: the tool result, or an error message for unknown
  # tools and tool exceptions, so failures become observations the agent can see.
  def execute_action(action, action_input)
    tool = @tools[normalize_action(action)]
    return "Error: Unknown tool '#{action}'. Available tools: #{@tools.keys.join(', ')}" if tool.nil?

    begin
      # Observations are stripped and validated as strings downstream, so a
      # non-String tool result (e.g. a number) must be stringified.
      tool.call(action_input).to_s
    rescue StandardError => e
      "Error executing #{action}: #{e.message}"
    end
  end

  # Canonical form of an LM-chosen action / next_step; tolerant of nil and padding.
  def normalize_action(action)
    action.to_s.strip.downcase
  end

  def finish_action?(action)
    normalize_action(action) == "finish"
  end

  def blank_text?(value)
    value.nil? || value.to_s.strip.empty?
  end
end
|
253
|
+
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module DSPy
|
4
|
+
class Signature
  # Class-level DSL for signature subclasses. Each subclass keeps its own
  # description and schemas in class instance variables.
  class << self
    attr_reader :input_schema
    attr_accessor :output_schema

    # Setter/getter in one. With a truthy +text+, stores and returns it;
    # otherwise returns the currently stored description.
    def description(text = nil)
      return @description unless text

      @description = text
    end

    # Builds and stores the input schema from a Dry::Schema JSON DSL block.
    def input(&)
      @input_schema = Dry::Schema.JSON(&)
    end

    # Builds and stores the output schema from a Dry::Schema JSON DSL block.
    def output(&)
      @output_schema = Dry::Schema.JSON(&)
    end
  end
end
|
26
|
+
end
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require 'sorbet-runtime'
|
5
|
+
require_relative 'sorbet_predict'
|
6
|
+
require_relative 'sorbet_signature'
|
7
|
+
|
8
|
+
module DSPy
|
9
|
+
# Enhances prediction by encouraging step-by-step reasoning
# before providing a final answer using Sorbet signatures.
#
# Wraps the given signature in an anonymous SorbetSignature subclass. That
# subclass reuses the original input struct and has an output struct with
# every original output prop plus a String `reasoning` prop. Its description
# is the original one followed by " Think step by step."
class SorbetChainOfThought < SorbetPredict
  extend T::Sig

  FieldDescriptor = DSPy::SorbetSignature::FieldDescriptor

  sig { params(signature_class: T.class_of(DSPy::SorbetSignature)).void }
  def initialize(signature_class)
    @original_signature = signature_class

    enhanced_output_struct = create_enhanced_output_struct(signature_class)

    enhanced_signature = Class.new(DSPy::SorbetSignature) do
      description "#{signature_class.description} Think step by step."

      # Reuse the original input struct and its field descriptors.
      @input_struct_class = signature_class.input_struct_class
      @input_field_descriptors = signature_class.instance_variable_get(:@input_field_descriptors) || {}

      # Output: the enhanced struct, described by a fresh hash holding the
      # original output descriptors plus one for `reasoning`.
      @output_struct_class = enhanced_output_struct
      @output_field_descriptors = {}
      original_output_descriptors = signature_class.instance_variable_get(:@output_field_descriptors) || {}
      @output_field_descriptors.merge!(original_output_descriptors)
      @output_field_descriptors[:reasoning] = FieldDescriptor.new(String, nil)

      class << self
        attr_reader :input_struct_class, :output_struct_class
      end
    end

    super(enhanced_signature)
  end

  private

  # Copies every output prop of +signature_class+ into a new T::Struct and
  # appends a required String `reasoning` prop.
  sig { params(signature_class: T.class_of(DSPy::SorbetSignature)).returns(T.class_of(T::Struct)) }
  def create_enhanced_output_struct(signature_class)
    original_props = signature_class.output_struct_class.props

    Class.new(T::Struct) do
      original_props.each do |name, prop|
        type = prop[:type]
        # For T.nilable props sorbet-runtime records the underlying non-nil
        # type under :type; re-wrap so the copied field stays optional.
        type = T.nilable(type) if T::Props::Utils.optional_prop?(prop)

        options = prop.except(:type, :type_object, :accessor_key, :sensitivity, :redaction)

        # Test key presence, not truthiness: `default: false` and
        # `default: nil` are real defaults and must not make the field required.
        if options.key?(:default)
          const name, type, default: options[:default]
        elsif options.key?(:factory)
          const name, type, factory: options[:factory]
        else
          const name, type
        end
      end

      const :reasoning, String
    end
  end

  sig { override.returns(T::Hash[Symbol, T.untyped]) }
  def generate_example_output
    example = super
    example[:reasoning] = "Let me think through this step by step..."
    example
  end
end
|
91
|
+
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
|
5
|
+
module DSPy
|
6
|
+
class SorbetModule
  extend T::Sig
  extend T::Generic

  # Typed entry point. Delegates to #forward_untyped and casts the result to
  # the caller's output type parameter O.
  sig { type_parameters(:I, :O).params(input_values: T.type_parameter(:I)).returns(T.type_parameter(:O)) }
  def forward(**input_values)
    untyped_result = forward_untyped(**input_values)
    T.cast(untyped_result, T.type_parameter(:O))
  end

  # Implementation hook for subclasses; this base version always raises
  # NotImplementedError.
  sig { params(input_values: T.untyped).returns(T.untyped) }
  def forward_untyped(**input_values)
    raise NotImplementedError, "Subclasses must implement forward_untyped method"
  end

  # Typed alias of #forward.
  sig { type_parameters(:I, :O).params(input_values: T.type_parameter(:I)).returns(T.type_parameter(:O)) }
  def call(**input_values) = forward(**input_values)

  # Untyped alias of #forward_untyped.
  sig { params(input_values: T.untyped).returns(T.untyped) }
  def call_untyped(**input_values) = forward_untyped(**input_values)
end
|
47
|
+
end
|
@@ -0,0 +1,180 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
require_relative 'sorbet_module'
|
5
|
+
|
6
|
+
module DSPy
|
7
|
+
class SorbetPredict < DSPy::SorbetModule
  extend T::Sig

  # Placeholder values used in the prompt's worked example, keyed by scalar class.
  INPUT_EXAMPLE_VALUES = { String => "example text", Integer => 42, Float => 3.14 }.freeze
  OUTPUT_EXAMPLE_VALUES = { String => "example result", Integer => 1, Float => 0.95 }.freeze

  sig { returns(T.class_of(SorbetSignature)) }
  attr_reader :signature_class

  sig { params(signature_class: T.class_of(SorbetSignature)).void }
  def initialize(signature_class)
    @signature_class = signature_class
  end

  # System prompt: the input/output JSON schemas, a worked example built from
  # placeholder values, the interaction template and the signature description.
  # NOTE: `<<-` keeps each line's leading source indentation in the prompt.
  sig { returns(String) }
  def system_signature
    <<-PROMPT
    Your input schema fields are:
    ```json
    #{JSON.generate(@signature_class.input_json_schema)}
    ```
    Your output schema fields are:
    ```json
    #{JSON.generate(@signature_class.output_json_schema)}
    ```

    For example, based on the schemas above, a valid interaction would be:
    ## Input values
    ```json
    #{JSON.generate(generate_example_input)}
    ```
    ## Output values
    ```json
    #{JSON.generate(generate_example_output)}
    ```

    All interactions will be structured in the following way, with the appropriate values filled in.

    ## Input values
    ```json
    {input_values}
    ```
    ## Output values
    Respond exclusively with the output schema fields in the json block below.
    ```json
    {output_values}
    ```

    In adhering to this structure, your objective is: #{@signature_class.description}

    PROMPT
  end

  # Example input hash: first serialized value for enum props, a placeholder
  # for String/Integer/Float props, "example" otherwise.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  def generate_example_input
    @signature_class.input_struct_class.props.to_h do |name, prop|
      [name, example_value_for(prop[:type], INPUT_EXAMPLE_VALUES)]
    end
  end

  # Example output hash, built with the same rules as #generate_example_input.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  def generate_example_output
    @signature_class.output_struct_class.props.to_h do |name, prop|
      [name, example_value_for(prop[:type], OUTPUT_EXAMPLE_VALUES)]
    end
  end

  sig { params(input_values: T::Hash[Symbol, T.untyped]).returns(String) }
  def user_signature(input_values)
    <<-PROMPT
    ## Input Values
    ```json
    #{JSON.generate(input_values)}
    ```

    Respond with the corresponding output schema fields wrapped in a ```json ``` block,
    starting with the heading `## Output values`.
    PROMPT
  end

  sig { returns(DSPy::LM) }
  def lm
    DSPy.config.lm
  end

  # Validates the inputs against the input struct, asks the LM, coerces the
  # returned attributes (enums, Float, Integer), and builds the output struct.
  # Raises PredictionInvalidError for invalid inputs or outputs.
  sig { params(input_values: T.untyped).returns(T.untyped) }
  def forward_untyped(**input_values)
    DSPy.logger.info(module: self.class.to_s, **input_values)

    # Constructing the struct validates the inputs: missing/unknown fields
    # raise ArgumentError and wrong types raise TypeError.
    begin
      @signature_class.input_struct_class.new(**input_values)
    rescue ArgumentError, TypeError => e
      raise PredictionInvalidError.new({ input: e.message })
    end

    # The original input_values are sent on; they were validated above.
    output_attributes = lm.chat(self, input_values)

    # Debug: log what we got from LM
    DSPy.logger.info("LM returned: #{output_attributes.inspect}")
    DSPy.logger.info("Output attributes class: #{output_attributes.class}")

    output_props = @signature_class.output_struct_class.props

    begin
      coerced = output_attributes.transform_keys(&:to_sym).to_h do |key, value|
        prop = output_props[key]
        [key, prop ? coerce_output_value(prop[:type], value) : value]
      end
      @signature_class.output_struct_class.new(**coerced)
    rescue ArgumentError, TypeError, KeyError => e
      # KeyError comes from T::Enum.deserialize on an unknown enum value.
      raise PredictionInvalidError.new({ output: e.message })
    end
  end

  private

  # Underlying module of a prop type. sorbet-runtime may store it as a raw
  # class (e.g. `String`) or wrapped in T::Types::Simple. Returns nil for
  # other type objects such as unions.
  sig { params(prop_type: T.untyped).returns(T.untyped) }
  def raw_prop_type(prop_type)
    if prop_type.is_a?(T::Types::Simple)
      prop_type.raw_type
    elsif prop_type.is_a?(Module)
      prop_type
    end
  end

  # The T::Enum subclass behind +prop_type+, or nil.
  sig { params(prop_type: T.untyped).returns(T.untyped) }
  def enum_class_for(prop_type)
    raw = raw_prop_type(prop_type)
    raw if raw.is_a?(Class) && raw < T::Enum
  end

  # Placeholder for +prop_type+: the first serialized value of an enum, the
  # entry in +examples+ for a scalar class, or "example".
  sig { params(prop_type: T.untyped, examples: T::Hash[T.untyped, T.untyped]).returns(T.untyped) }
  def example_value_for(prop_type, examples)
    enum_class = enum_class_for(prop_type)
    return (enum_class.values.first&.serialize || "example") if enum_class

    examples.fetch(raw_prop_type(prop_type), "example")
  end

  # Converts an LM-provided value to the prop's type. Enum values are
  # deserialized; nil is passed through for nilable enums. Float and Integer
  # values are coerced with to_f / to_i; anything else is returned unchanged.
  sig { params(prop_type: T.untyped, value: T.untyped).returns(T.untyped) }
  def coerce_output_value(prop_type, value)
    enum_class = enum_class_for(prop_type)
    if enum_class
      return value.nil? ? nil : enum_class.deserialize(value)
    end

    raw = raw_prop_type(prop_type)
    if raw == Float
      value.to_f
    elsif raw == Integer
      value.to_i
    else
      value
    end
  end
end
|
180
|
+
end
|