dspy 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,253 @@
1
+ module DSPy
2
+ # Define the signature for ReAct reasoning
3
# Signature used for the "think" step of the ReAct loop: given the question,
# the history so far and the tool list, the LM returns a thought, an action
# name and the input for that action.
class Thought < DSPy::Signature
  description "Generate a thought about what to do next to answer the question."

  input do
    required(:question)
      .value(:string)
      .meta(description: 'The question to answer')
    required(:history)
      .value(:array)
      .meta(description: 'Previous thoughts and actions, including observations from tools. ' \
                         'The agent MUST use information from the history to inform its actions and final answer. ' \
                         'Each entry is a hash representing a step in the reasoning process.')
    required(:available_tools)
      .value(:string)
      .meta(description: 'List of available tools and their descriptions. ' \
                         'The agent MUST choose an action from this list or use "finish".')
  end

  output do
    required(:thought)
      .value(:string)
      .meta(description: 'Reasoning about what to do next, considering the history and observations.')
    required(:action)
      .value(:string)
      .meta(description: 'The action to take. MUST be one of the tool names listed in `available_tools` input, ' \
                         'or the literal string "finish" to provide the final answer.')
    required(:action_input)
      .value(:string)
      .meta(description: 'Input for the chosen action. If action is "finish", this field MUST contain the final answer to the original question. ' \
                         'This answer MUST be directly taken from the relevant Observation in the history if available. ' \
                         'For example, if an observation showed "Observation: 100.0", and you are finishing, this field MUST be "100.0". ' \
                         'Do not leave empty if finishing with an observed answer.')
  end
end
18
+
19
+ # Define the signature for observing tool results
20
# Signature used after a tool call: the LM interprets the latest observation
# and states whether the loop should "continue" or "finish".
class ReActObservation < DSPy::Signature
  description "Process the observation from a tool and decide what to do next."

  input do
    required(:question)
      .value(:string)
      .meta(description: 'The original question')
    required(:history)
      .value(:array)
      .meta(description: 'Previous thoughts, actions, and observations. ' \
                         'Each entry is a hash representing a step in the reasoning process.')
    required(:observation)
      .value(:string)
      .meta(description: 'The result from the last action')
  end

  output do
    required(:interpretation)
      .value(:string)
      .meta(description: 'Interpretation of the observation')
    required(:next_step)
      .value(:string)
      .meta(description: 'What to do next: "continue" or "finish"')
  end
end
34
+
35
+ # ReAct Agent Module
36
+ class ReAct < DSPy::Module
37
+ attr_reader :signature_class, :internal_output_schema, :tools, :max_iterations
38
+
39
+ # Defines the structure for each entry in the ReAct history
40
# One step of the ReAct trace. Built with keyword arguments; members that are
# not supplied are nil.
HistoryEntry = Struct.new(:step, :thought, :action, :action_input, :observation, keyword_init: true) do
  # Symbol-keyed hash of every member, in declaration order.
  def to_h
    each_pair.to_h
  end
end
51
+
52
# Builds a ReAct agent around the caller's signature.
#
# signature_class - the user's signature class; its output schema is extended
#                   with the ReAct bookkeeping fields below.
# tools           - objects responding to `name`, `description` and `call`;
#                   indexed by lowercased name for lookup.
# max_iterations  - upper bound on thought/action rounds in `forward`.
def initialize(signature_class, tools: [], max_iterations: 5)
  super()
  @signature_class = signature_class
  @max_iterations = max_iterations
  @tools = tools.to_h { |tool| [tool.name.downcase, tool] }
  @thought_generator = DSPy::ChainOfThought.new(Thought)
  @observation_processor = DSPy::Predict.new(ReActObservation)

  # Fields ReAct appends to whatever the user's signature outputs.
  react_fields_schema = Dry::Schema.JSON do
    optional(:history).array(:hash) do
      required(:step).value(:integer)
      optional(:thought).value(:string)
      optional(:action).value(:string)
      optional(:action_input).maybe(:string)
      optional(:observation).maybe(:string)
    end
    optional(:iterations).value(:integer).meta(description: 'Number of iterations taken by the ReAct agent.')
  end

  # User output schema + ReAct fields, used to validate the final result.
  @internal_output_schema = Dry::Schema.JSON(
    parent: [signature_class.output_schema, react_fields_schema]
  )
end
75
+
76
# Runs the ReAct loop for the given inputs and returns an immutable result
# object carrying the user's primary output field plus :history and
# :iterations.
#
# The first input field of the user's signature is treated as the question,
# and the first output field receives the final answer (both by convention).
#
# Raises DSPy::PredictionInvalidError when the inputs or the assembled output
# fail schema validation.
def forward(**input_values)
  input_validation_result = @signature_class.input_schema.call(input_values)
  unless input_validation_result.success?
    raise DSPy::PredictionInvalidError.new(input_validation_result.errors)
  end

  question_field_name = @signature_class.input_schema.key_map.first.name.to_sym
  question = input_values[question_field_name]

  history = [] # HistoryEntry objects, oldest first
  available_tools_desc = @tools.map { |name, tool| "- #{name}: #{tool.description}" }.join("\n")

  final_answer = nil
  iterations_count = 0

  @max_iterations.times do |i|
    iterations_count = i + 1

    thought_result = @thought_generator.call(
      question: question,
      history: history.map(&:to_h),
      available_tools: available_tools_desc
    )
    thought = thought_result.thought
    action = thought_result.action
    current_action_input = thought_result.action_input

    if action.downcase == "finish"
      final_answer = resolve_finish_input(
        current_action_input,
        history,
        override_status: "Finish action had empty input. Overriding with last observation.",
        missing_status: "Finish action had empty input, no prior Observation found in history."
      )
      DSPy.logger.info(module: "ReAct", status: "Finishing loop after thought", action: action, final_answer: final_answer, question: question)
      history << HistoryEntry.new(step: iterations_count, thought: thought, action: action, action_input: final_answer)
      break
    end

    # Run the chosen tool and record the completed step.
    observation_text = execute_action(action, current_action_input)
    history << HistoryEntry.new(
      step: iterations_count,
      thought: thought,
      action: action,
      action_input: current_action_input,
      observation: observation_text
    )

    obs_result = @observation_processor.call(
      question: question,
      history: history.map(&:to_h),
      observation: observation_text
    )
    next unless obs_result.next_step.downcase == "finish"

    # The observation processor wants to stop: ask for one more thought so the
    # LM can phrase the final answer with the latest observation in history.
    DSPy.logger.info(module: "ReAct", status: "Observation processor suggests finish. Generating final thought.", question: question, history_before_final_thought: history.map(&:to_h))

    final_thought_result = @thought_generator.call(
      question: question,
      history: history.map(&:to_h),
      available_tools: available_tools_desc
    )
    DSPy.logger.info(module: "ReAct", status: "Finishing after observation and final thought", final_action: final_thought_result.action, final_action_input: final_thought_result.action_input, question: question)

    final_thought_action = final_thought_result.action
    final_thought_action_input_val = final_thought_result.action_input

    # NOTE(review): when the final thought picks a tool instead of "finish",
    # its raw action_input still becomes the answer (tool is not executed).
    # Kept as-is; confirm this is intended.
    if final_thought_action.downcase == "finish"
      final_thought_action_input_val = resolve_finish_input(
        final_thought_action_input_val,
        history,
        override_status: "Final thought 'finish' action had empty input. Overriding with last observation.",
        missing_status: "Final thought 'finish' action had empty input, last observation also empty/not found cleanly."
      )
    end

    history << HistoryEntry.new(
      step: iterations_count + 1,
      thought: final_thought_result.thought,
      action: final_thought_action,
      action_input: final_thought_action_input_val
    )

    final_answer = final_thought_action_input_val
    iterations_count += 1 # the extra thought counts as an iteration
    break
  end

  fallback_answer = "Unable to find answer within #{@max_iterations} iterations"
  final_answer ||= fallback_answer
  if final_answer.empty? || final_answer == fallback_answer
    DSPy.logger.info(module: "ReAct", status: "Final answer determined", final_answer: final_answer, question: question)
  end

  # First output field of the user's signature is the answer field.
  user_primary_output_field = @signature_class.output_schema.key_map.first.name.to_sym
  output_data = {
    user_primary_output_field => final_answer,
    history: history.map(&:to_h),
    iterations: iterations_count
  }

  output_validation_result = @internal_output_schema.call(output_data)
  unless output_validation_result.success?
    DSPy.logger.error(module: "ReAct", status: "Internal output validation failed", errors: output_validation_result.errors.to_h, data: output_data)
    raise DSPy::PredictionInvalidError.new(output_validation_result.errors)
  end

  # Sorted keys give the result object a stable attribute order.
  validated_output = output_validation_result.to_h
  Data.define(*validated_output.keys.sort).new(**validated_output)
end

# Returns the input to use for a "finish" action. A non-blank input is
# returned unchanged; a nil/blank one is replaced by the most recent non-blank
# observation in history (stripped) when there is one, otherwise returned
# as-is after logging a warning. The two status strings are the log messages
# for the override and the nothing-to-fall-back-on cases.
private def resolve_finish_input(action_input, history, override_status:, missing_status:)
  return action_input unless action_input.nil? || action_input.strip.empty?

  last_entry_with_observation = history.reverse_each.find do |entry|
    entry.observation && !entry.observation.strip.empty?
  end

  if last_entry_with_observation
    derived_input = last_entry_with_observation.observation.strip
    DSPy.logger.info(module: "ReAct", status: override_status, original_input: action_input, derived_input: derived_input)
    derived_input
  else
    DSPy.logger.warn(module: "ReAct", status: missing_status, original_input: action_input)
    action_input
  end
end
236
+
237
+ private
238
+
239
# Looks up the tool named by `action` (case-insensitively) and calls it with
# `action_input`.
#
# Always returns a String: the tool's result converted with `to_s`, or an
# "Error: ..." message when the tool is unknown or raises. The conversion
# matters because the observation is later `strip`ped and validated as a
# string in the history.
def execute_action(action, action_input)
  tool = @tools[action.to_s.downcase]

  if tool.nil?
    return "Error: Unknown tool '#{action}'. Available tools: #{@tools.keys.join(', ')}"
  end

  begin
    tool.call(action_input).to_s
  rescue => e
    "Error executing #{action}: #{e.message}"
  end
end
252
+ end
253
+ end
@@ -0,0 +1,26 @@
1
+ # frozen_string_literal: true
2
+
3
+ module DSPy
4
# Base class for schema-backed signatures. Subclasses declare a description
# plus input/output schemas through the class-level DSL below.
class Signature
  class << self
    attr_reader :input_schema
    attr_accessor :output_schema

    # With an argument: stores and returns the description.
    # Without: returns the stored description (nil when never set).
    def description(text = nil)
      return @description unless text

      @description = text
    end

    # Defines the input schema from the given schema DSL block.
    def input(&block)
      @input_schema = Dry::Schema::JSON(&block)
    end

    # Defines the output schema from the given schema DSL block.
    def output(&block)
      @output_schema = Dry::Schema::JSON(&block)
    end
  end
end
26
+ end
@@ -0,0 +1,91 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+ require_relative 'sorbet_predict'
6
+ require_relative 'sorbet_signature'
7
+
8
+ module DSPy
9
+ # Enhances prediction by encouraging step-by-step reasoning
10
+ # before providing a final answer using Sorbet signatures.
11
+ class SorbetChainOfThought < SorbetPredict
12
+ extend T::Sig
13
+
14
+ FieldDescriptor = DSPy::SorbetSignature::FieldDescriptor
15
+
16
+ sig { params(signature_class: T.class_of(DSPy::SorbetSignature)).void }
17
# Wraps `signature_class` in an anonymous signature whose description asks
# for step-by-step thinking and whose output struct has an extra :reasoning
# String field, then hands that signature to the parent predictor.
def initialize(signature_class)
  @original_signature = signature_class

  reasoning_output_struct = create_enhanced_output_struct(signature_class)
  inherited_input_descriptors = signature_class.instance_variable_get(:@input_field_descriptors) || {}
  inherited_output_descriptors = signature_class.instance_variable_get(:@output_field_descriptors) || {}

  enhanced_signature = Class.new(DSPy::SorbetSignature) do
    description "#{signature_class.description} Think step by step."

    # Inputs are reused unchanged from the wrapped signature.
    @input_struct_class = signature_class.input_struct_class
    @input_field_descriptors = inherited_input_descriptors

    # Outputs: original descriptors (copied) plus the reasoning field.
    @output_struct_class = reasoning_output_struct
    @output_field_descriptors = inherited_output_descriptors.merge(
      reasoning: FieldDescriptor.new(String, nil)
    )

    singleton_class.attr_reader :input_struct_class, :output_struct_class
  end

  super(enhanced_signature)
end
53
+
54
+ private
55
+
56
+ sig { params(signature_class: T.class_of(DSPy::SorbetSignature)).returns(T.class_of(T::Struct)) }
57
# Builds a new T::Struct class containing every output prop of
# `signature_class` (as `const`s, carrying over a declared default or
# factory) plus a required `reasoning` String.
def create_enhanced_output_struct(signature_class)
  original_props = signature_class.output_struct_class.props

  Class.new(T::Struct) do
    original_props.each do |name, prop|
      type = prop[:type]

      # key? rather than truthiness: a declared default of `false` or `nil`
      # must be carried over, otherwise the prop silently becomes required.
      if prop.key?(:default)
        const name, type, default: prop[:default]
      elsif prop[:factory]
        const name, type, factory: prop[:factory]
      else
        const name, type
      end
    end

    const :reasoning, String
  end
end
83
+
84
+ sig { override.returns(T::Hash[Symbol, T.untyped]) }
85
# Parent's example output with a sample :reasoning value added.
def generate_example_output
  super.tap do |example|
    example[:reasoning] = "Let me think through this step by step..."
  end
end
90
+ end
91
+ end
@@ -0,0 +1,47 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'sorbet-runtime'
4
+
5
+ module DSPy
6
# Base class for Sorbet-typed modules. Subclasses implement
# `forward_untyped`; `forward`/`call` are the typed entry points that
# delegate to it.
class SorbetModule
  extend T::Sig
  extend T::Generic

  # Typed entry point: delegates to forward_untyped and casts the result to
  # the caller's output type parameter.
  sig do
    type_parameters(:I, :O)
      .params(input_values: T.type_parameter(:I))
      .returns(T.type_parameter(:O))
  end
  def forward(**input_values)
    untyped_result = forward_untyped(**input_values)
    T.cast(untyped_result, T.type_parameter(:O))
  end

  # Implementation hook; raises NotImplementedError unless overridden.
  sig { params(input_values: T.untyped).returns(T.untyped) }
  def forward_untyped(**input_values)
    raise NotImplementedError, "Subclasses must implement forward_untyped method"
  end

  # Typed alias for `forward`.
  sig do
    type_parameters(:I, :O)
      .params(input_values: T.type_parameter(:I))
      .returns(T.type_parameter(:O))
  end
  def call(**input_values)
    forward(**input_values)
  end

  # Untyped alias for `forward_untyped`.
  sig { params(input_values: T.untyped).returns(T.untyped) }
  def call_untyped(**input_values)
    forward_untyped(**input_values)
  end
end
47
+ end
@@ -0,0 +1,180 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'sorbet-runtime'
4
+ require_relative 'sorbet_module'
5
+
6
+ module DSPy
7
+ class SorbetPredict < DSPy::SorbetModule
8
+ extend T::Sig
9
+
10
+ sig { returns(T.class_of(SorbetSignature)) }
11
+ attr_reader :signature_class
12
+
13
+ sig { params(signature_class: T.class_of(SorbetSignature)).void }
14
# Stores the signature class whose input/output structs drive prompting and
# validation in this predictor.
def initialize(signature_class)
  @signature_class = signature_class
end
17
+
18
+ sig { returns(String) }
19
# Builds the system prompt: the input/output JSON schemas of the signature,
# one example interaction generated from the struct props, the expected
# response layout, and the signature's description as the objective.
#
# The fence closing the output schema block is three backticks (it used to be
# four, which left the block unbalanced).
def system_signature
  <<~PROMPT
    Your input schema fields are:
    ```json
    #{JSON.generate(@signature_class.input_json_schema)}
    ```
    Your output schema fields are:
    ```json
    #{JSON.generate(@signature_class.output_json_schema)}
    ```

    For example, based on the schemas above, a valid interaction would be:
    ## Input values
    ```json
    #{JSON.generate(generate_example_input)}
    ```
    ## Output values
    ```json
    #{JSON.generate(generate_example_output)}
    ```

    All interactions will be structured in the following way, with the appropriate values filled in.

    ## Input values
    ```json
    {input_values}
    ```
    ## Output values
    Respond exclusively with the output schema fields in the json block below.
    ```json
    {output_values}
    ```

    In adhering to this structure, your objective is: #{@signature_class.description}

  PROMPT
end
56
+
57
+ sig { returns(T::Hash[Symbol, T.untyped]) }
58
# Produces one sample value per input prop for the example shown in the
# system prompt.
#
# A prop's declared type may be a plain Class (e.g. String) or a
# T::Types::Simple wrapper; both are resolved to the underlying class so that
# String/Integer/Float props get a value of the right kind. Anything else
# falls back to "example".
def generate_example_input
  @signature_class.input_struct_class.props.to_h do |name, prop|
    declared_type = prop[:type]
    raw_type = declared_type.is_a?(T::Types::Simple) ? declared_type.raw_type : declared_type

    sample = case raw_type.to_s
             when "String" then "example text"
             when "Integer" then 42
             when "Float" then 3.14
             else "example"
             end

    [name, sample]
  end
end
75
+
76
+ sig { returns(T::Hash[Symbol, T.untyped]) }
77
# Produces one sample value per output prop for the example shown in the
# system prompt.
#
# A prop's declared type may be a plain Class or a T::Types::Simple wrapper;
# both are resolved to the underlying class. T::Enum subclasses yield their
# first value serialized; String/Integer/Float get a value of the right kind;
# anything else falls back to "example".
def generate_example_output
  @signature_class.output_struct_class.props.to_h do |name, prop|
    declared_type = prop[:type]
    raw_type = declared_type.is_a?(T::Types::Simple) ? declared_type.raw_type : declared_type

    sample = if raw_type.is_a?(Class) && raw_type < T::Enum
               raw_type.values.first.serialize
             else
               case raw_type.to_s
               when "String" then "example result"
               when "Integer" then 1
               when "Float" then 0.95
               else "example"
               end
             end

    [name, sample]
  end
end
99
+
100
+ sig { params(input_values: T::Hash[Symbol, T.untyped]).returns(String) }
101
# Builds the user message: the input values as a fenced JSON block followed
# by the instruction describing how the reply must be laid out.
def user_signature(input_values)
  serialized_input = JSON.generate(input_values)

  <<~PROMPT
    ## Input Values
    ```json
    #{serialized_input}
    ```

    Respond with the corresponding output schema fields wrapped in a ```json ``` block,
    starting with the heading `## Output values`.
  PROMPT
end
112
+
113
+ sig { returns(DSPy::LM) }
114
# The language model taken from the global DSPy configuration.
def lm = DSPy.config.lm
117
+
118
+ sig { params(input_values: T.untyped).returns(T.untyped) }
119
# Validates the inputs against the signature's input struct, asks the LM for
# output attributes, coerces them to the declared output prop types and
# returns an instance of the signature's output struct.
#
# Raises PredictionInvalidError with an :input entry when the inputs cannot
# build the input struct, and with an :output entry when the LM output cannot
# be coerced (e.g. an unknown enum value) or cannot build the output struct.
def forward_untyped(**input_values)
  DSPy.logger.info(module: self.class.to_s, **input_values)

  # Instantiating the struct is the validation; the instance is discarded.
  begin
    @signature_class.input_struct_class.new(**input_values)
  rescue ArgumentError, TypeError => e
    raise PredictionInvalidError.new({ input: e.message })
  end

  output_attributes = lm.chat(self, input_values)

  DSPy.logger.info("LM returned: #{output_attributes.inspect}")
  DSPy.logger.info("Output attributes class: #{output_attributes.class}")

  output_props = @signature_class.output_struct_class.props

  begin
    coerced_attributes = output_attributes.transform_keys(&:to_sym).to_h do |key, value|
      [key, coerce_output_value(output_props.dig(key, :type), value)]
    end

    @signature_class.output_struct_class.new(**coerced_attributes)
  rescue ArgumentError, TypeError, KeyError => e
    # KeyError covers enum deserialization of a value the enum does not know.
    raise PredictionInvalidError.new({ output: e.message })
  end
end

# Converts one LM-provided value to the declared prop type. The type may be a
# plain Class or a T::Types::Simple wrapper. T::Enum subclasses are
# deserialized, Float/Integer use to_f/to_i, everything else (including props
# unknown to the struct, where prop_type is nil) is returned unchanged.
sig { params(prop_type: T.untyped, value: T.untyped).returns(T.untyped) }
private def coerce_output_value(prop_type, value)
  return value unless prop_type

  raw_type = prop_type.is_a?(T::Types::Simple) ? prop_type.raw_type : prop_type
  return value unless raw_type.is_a?(Class)

  if raw_type < T::Enum
    raw_type.deserialize(value)
  elsif raw_type == Float
    value.to_f
  elsif raw_type == Integer
    value.to_i
  else
    value
  end
end
179
+ end
180
+ end