dspy 0.3.1 → 0.5.0

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
data/lib/dspy/re_act.rb CHANGED
@@ -7,6 +7,8 @@ require_relative 'signature'
7
7
  require_relative 'chain_of_thought'
8
8
  require 'json'
9
9
  require_relative 'instrumentation'
10
+ require_relative 'mixins/struct_builder'
11
+ require_relative 'mixins/instrumentation_helpers'
10
12
 
11
13
  module DSPy
12
14
  # Define a simple struct for history entries with proper type annotations
@@ -82,6 +84,8 @@ module DSPy
82
84
  # ReAct Agent using Sorbet signatures
83
85
  class ReAct < Predict
84
86
  extend T::Sig
87
+ include Mixins::StructBuilder
88
+ include Mixins::InstrumentationHelpers
85
89
 
86
90
  FINISH_ACTION = "finish"
87
91
  sig { returns(T.class_of(DSPy::Signature)) }
@@ -97,7 +101,7 @@ module DSPy
97
101
  attr_reader :max_iterations
98
102
 
99
103
 
100
- sig { params(signature_class: T.class_of(DSPy::Signature), tools: T::Array[T.untyped], max_iterations: Integer).void }
104
+ sig { params(signature_class: T.class_of(DSPy::Signature), tools: T::Array[DSPy::Tools::Base], max_iterations: Integer).void }
101
105
  def initialize(signature_class, tools: [], max_iterations: 5)
102
106
  @original_signature_class = signature_class
103
107
  @tools = T.let({}, T::Hash[String, T.untyped])
@@ -137,206 +141,271 @@ module DSPy
137
141
  sig { params(kwargs: T.untyped).returns(T.untyped).override }
138
142
  def forward(**kwargs)
139
143
  lm = config.lm || DSPy.config.lm
140
- # Prepare instrumentation payload
141
- input_fields = kwargs.keys.map(&:to_s)
142
144
  available_tools = @tools.keys
143
145
 
144
146
  # Instrument the entire ReAct agent lifecycle
145
- result = Instrumentation.instrument('dspy.react', {
146
- signature_class: @original_signature_class.name,
147
- model: lm.model,
148
- provider: lm.provider,
149
- input_fields: input_fields,
147
+ result = instrument_prediction('dspy.react', @original_signature_class, kwargs, {
150
148
  max_iterations: @max_iterations,
151
149
  available_tools: available_tools
152
150
  }) do
153
- # Validate input using Sorbet struct validation
151
+ # Validate input and extract question
154
152
  input_struct = @original_signature_class.input_struct_class.new(**kwargs)
155
-
156
- # Get the question (assume first field is the question for now)
157
153
  question = T.cast(input_struct.serialize.values.first, String)
158
154
 
159
- history = T.let([], T::Array[HistoryEntry])
160
- available_tools_desc = @tools.map { |name, tool| JSON.parse(tool.schema) }
161
-
162
- final_answer = T.let(nil, T.nilable(String))
163
- iterations_count = 0
164
- last_observation = T.let(nil, T.nilable(String))
165
- tools_used = []
166
-
167
- while @max_iterations.nil? || iterations_count < @max_iterations
168
- iterations_count += 1
169
-
170
- # Instrument each iteration
171
- iteration_result = Instrumentation.instrument('dspy.react.iteration', {
172
- iteration: iterations_count,
173
- max_iterations: @max_iterations,
174
- history_length: history.length,
175
- tools_used_so_far: tools_used.uniq
176
- }) do
177
- # Get next thought from LM
178
- thought_obj = @thought_generator.forward(
179
- question: question,
180
- history: history,
181
- available_tools: available_tools_desc
182
- )
183
- step = iterations_count
184
- thought = thought_obj.thought
185
- action = thought_obj.action
186
- action_input = thought_obj.action_input
187
-
188
- # Break if finish action
189
- if action&.downcase == 'finish'
190
- final_answer = handle_finish_action(action_input, last_observation, step, thought, action, history)
191
- break
192
- end
193
-
194
- # Track tools used
195
- tools_used << action.downcase if action && @tools[action.downcase]
196
-
197
- # Execute action
198
- observation = if action && @tools[action.downcase]
199
- # Instrument tool call
200
- Instrumentation.instrument('dspy.react.tool_call', {
201
- iteration: iterations_count,
202
- tool_name: action.downcase,
203
- tool_input: action_input
204
- }) do
205
- execute_action(action, action_input)
206
- end
207
- else
208
- "Unknown action: #{action}. Available actions: #{@tools.keys.join(', ')}, finish"
209
- end
210
-
211
- last_observation = observation
212
-
213
- # Add to history
214
- history << HistoryEntry.new(
215
- step: step,
216
- thought: thought,
217
- action: action,
218
- action_input: action_input,
219
- observation: observation
220
- )
221
-
222
- # Process observation to decide next step
223
- if observation && !observation.include?("Unknown action")
224
- observation_result = @observation_processor.forward(
225
- question: question,
226
- history: history,
227
- observation: observation
228
- )
229
-
230
- # If observation processor suggests finishing, generate final thought
231
- if observation_result.next_step == NextStep::Finish
232
- final_thought = @thought_generator.forward(
233
- question: question,
234
- history: history,
235
- available_tools: available_tools_desc
236
- )
237
-
238
- # Force finish action if observation processor suggests it
239
- if final_thought.action&.downcase != 'finish'
240
- forced_answer = if observation_result.interpretation && !observation_result.interpretation.empty?
241
- observation_result.interpretation
242
- else
243
- observation
244
- end
245
- final_answer = handle_finish_action(forced_answer, last_observation, step + 1, final_thought.thought, 'finish', history)
246
- else
247
- final_answer = handle_finish_action(final_thought.action_input, last_observation, step + 1, final_thought.thought, final_thought.action, history)
248
- end
249
- break
250
- end
251
- end
252
-
253
- # Emit iteration complete event
254
- Instrumentation.emit('dspy.react.iteration_complete', {
255
- iteration: iterations_count,
256
- thought: thought,
257
- action: action,
258
- action_input: action_input,
259
- observation: observation,
260
- tools_used: tools_used.uniq
261
- })
262
- end
263
-
264
- # Check if max iterations reached
265
- if iterations_count >= @max_iterations && final_answer.nil?
266
- Instrumentation.emit('dspy.react.max_iterations', {
267
- iteration_count: iterations_count,
268
- max_iterations: @max_iterations,
269
- tools_used: tools_used.uniq,
270
- final_history_length: history.length
271
- })
272
- end
273
- end
155
+ # Execute ReAct reasoning loop
156
+ reasoning_result = execute_react_reasoning_loop(question)
274
157
 
275
158
  # Create enhanced output with all ReAct data
276
- output_field_name = @original_signature_class.output_struct_class.props.keys.first
277
- output_data = kwargs.merge({
278
- history: history.map(&:to_h),
279
- iterations: iterations_count,
280
- tools_used: tools_used.uniq
281
- })
282
- output_data[output_field_name] = final_answer || "No answer reached within #{@max_iterations} iterations"
283
- enhanced_output = @enhanced_output_struct.new(**output_data)
284
-
285
- enhanced_output
159
+ create_enhanced_result(kwargs, reasoning_result)
286
160
  end
287
-
161
+
288
162
  result
289
163
  end
290
164
 
291
165
  private
292
166
 
167
+ # Executes the main ReAct reasoning loop
168
+ sig { params(question: String).returns(T::Hash[Symbol, T.untyped]) }
169
+ def execute_react_reasoning_loop(question)
170
+ history = T.let([], T::Array[HistoryEntry])
171
+ available_tools_desc = @tools.map { |name, tool| JSON.parse(tool.schema) }
172
+ final_answer = T.let(nil, T.nilable(String))
173
+ iterations_count = 0
174
+ last_observation = T.let(nil, T.nilable(String))
175
+ tools_used = []
176
+
177
+ while should_continue_iteration?(iterations_count, final_answer)
178
+ iterations_count += 1
179
+
180
+ iteration_result = execute_single_iteration(
181
+ question, history, available_tools_desc, iterations_count, tools_used, last_observation
182
+ )
183
+
184
+ if iteration_result[:should_finish]
185
+ final_answer = iteration_result[:final_answer]
186
+ break
187
+ end
188
+
189
+ history = iteration_result[:history]
190
+ tools_used = iteration_result[:tools_used]
191
+ last_observation = iteration_result[:last_observation]
192
+ end
193
+
194
+ handle_max_iterations_if_needed(iterations_count, final_answer, tools_used, history)
195
+
196
+ {
197
+ history: history,
198
+ iterations: iterations_count,
199
+ tools_used: tools_used.uniq,
200
+ final_answer: final_answer || default_no_answer_message
201
+ }
202
+ end
203
+
204
+ # Executes a single iteration of the ReAct loop
205
+ sig { params(question: String, history: T::Array[HistoryEntry], available_tools_desc: T::Array[T::Hash[String, T.untyped]], iteration: Integer, tools_used: T::Array[String], last_observation: T.nilable(String)).returns(T::Hash[Symbol, T.untyped]) }
206
+ def execute_single_iteration(question, history, available_tools_desc, iteration, tools_used, last_observation)
207
+ # Instrument each iteration
208
+ Instrumentation.instrument('dspy.react.iteration', {
209
+ iteration: iteration,
210
+ max_iterations: @max_iterations,
211
+ history_length: history.length,
212
+ tools_used_so_far: tools_used.uniq
213
+ }) do
214
+ # Generate thought and action
215
+ thought_obj = @thought_generator.forward(
216
+ question: question,
217
+ history: history,
218
+ available_tools: available_tools_desc
219
+ )
220
+
221
+ # Process thought result
222
+ if finish_action?(thought_obj.action)
223
+ final_answer = handle_finish_action(
224
+ thought_obj.action_input, last_observation, iteration,
225
+ thought_obj.thought, thought_obj.action, history
226
+ )
227
+ return { should_finish: true, final_answer: final_answer }
228
+ end
229
+
230
+ # Execute tool action
231
+ observation = execute_tool_with_instrumentation(
232
+ thought_obj.action, thought_obj.action_input, iteration
233
+ )
234
+
235
+ # Track tools used
236
+ tools_used << thought_obj.action.downcase if valid_tool?(thought_obj.action)
237
+
238
+ # Add to history
239
+ history << create_history_entry(
240
+ iteration, thought_obj.thought, thought_obj.action,
241
+ thought_obj.action_input, observation
242
+ )
243
+
244
+ # Process observation and decide next step
245
+ observation_decision = process_observation_and_decide_next_step(
246
+ question, history, observation, available_tools_desc, iteration
247
+ )
248
+
249
+ if observation_decision[:should_finish]
250
+ return { should_finish: true, final_answer: observation_decision[:final_answer] }
251
+ end
252
+
253
+ emit_iteration_complete_event(
254
+ iteration, thought_obj.thought, thought_obj.action,
255
+ thought_obj.action_input, observation, tools_used
256
+ )
257
+
258
+ {
259
+ should_finish: false,
260
+ history: history,
261
+ tools_used: tools_used,
262
+ last_observation: observation
263
+ }
264
+ end
265
+ end
266
+
267
+ # Creates enhanced output struct with ReAct-specific fields
293
268
  sig { params(signature_class: T.class_of(DSPy::Signature)).returns(T.class_of(T::Struct)) }
294
269
  def create_enhanced_output_struct(signature_class)
295
- # Get original input and output props
296
270
  input_props = signature_class.input_struct_class.props
297
271
  output_props = signature_class.output_struct_class.props
298
272
 
299
- # Create new struct class with input, output, and ReAct fields
300
- Class.new(T::Struct) do
301
- # Add all input fields
302
- input_props.each do |name, prop|
303
- # Extract the type and other options
304
- type = prop[:type]
305
- options = prop.except(:type, :type_object, :accessor_key, :sensitivity, :redaction)
306
-
307
- # Handle default values
308
- if options[:default]
309
- const name, type, default: options[:default]
310
- elsif options[:factory]
311
- const name, type, factory: options[:factory]
312
- else
313
- const name, type
314
- end
315
- end
273
+ build_enhanced_struct(
274
+ { input: input_props, output: output_props },
275
+ {
276
+ history: [T::Array[T::Hash[Symbol, T.untyped]], "ReAct execution history"],
277
+ iterations: [Integer, "Number of iterations executed"],
278
+ tools_used: [T::Array[String], "List of tools used during execution"]
279
+ }
280
+ )
281
+ end
282
+
283
+ # Creates enhanced result struct
284
+ sig { params(input_kwargs: T::Hash[Symbol, T.untyped], reasoning_result: T::Hash[Symbol, T.untyped]).returns(T.untyped) }
285
+ def create_enhanced_result(input_kwargs, reasoning_result)
286
+ output_field_name = @original_signature_class.output_struct_class.props.keys.first
287
+
288
+ output_data = input_kwargs.merge({
289
+ history: reasoning_result[:history].map(&:to_h),
290
+ iterations: reasoning_result[:iterations],
291
+ tools_used: reasoning_result[:tools_used]
292
+ })
293
+ output_data[output_field_name] = reasoning_result[:final_answer]
294
+
295
+ @enhanced_output_struct.new(**output_data)
296
+ end
316
297
 
317
- # Add all output fields
318
- output_props.each do |name, prop|
319
- # Extract the type and other options
320
- type = prop[:type]
321
- options = prop.except(:type, :type_object, :accessor_key, :sensitivity, :redaction)
322
-
323
- # Handle default values
324
- if options[:default]
325
- const name, type, default: options[:default]
326
- elsif options[:factory]
327
- const name, type, factory: options[:factory]
328
- else
329
- const name, type
330
- end
298
+ # Helper methods for ReAct logic
299
+ sig { params(iterations_count: Integer, final_answer: T.nilable(String)).returns(T::Boolean) }
300
+ def should_continue_iteration?(iterations_count, final_answer)
301
+ final_answer.nil? && (@max_iterations.nil? || iterations_count < @max_iterations)
302
+ end
303
+
304
+ sig { params(action: T.nilable(String)).returns(T::Boolean) }
305
+ def finish_action?(action)
306
+ action&.downcase == FINISH_ACTION
307
+ end
308
+
309
+ sig { params(action: T.nilable(String)).returns(T::Boolean) }
310
+ def valid_tool?(action)
311
+ !!(action && @tools[action.downcase])
312
+ end
313
+
314
+ sig { params(action: T.nilable(String), action_input: T.untyped, iteration: Integer).returns(String) }
315
+ def execute_tool_with_instrumentation(action, action_input, iteration)
316
+ if action && @tools[action.downcase]
317
+ Instrumentation.instrument('dspy.react.tool_call', {
318
+ iteration: iteration,
319
+ tool_name: action.downcase,
320
+ tool_input: action_input
321
+ }) do
322
+ execute_action(action, action_input)
331
323
  end
324
+ else
325
+ "Unknown action: #{action}. Available actions: #{@tools.keys.join(', ')}, finish"
326
+ end
327
+ end
328
+
329
+ sig { params(step: Integer, thought: String, action: String, action_input: T.untyped, observation: String).returns(HistoryEntry) }
330
+ def create_history_entry(step, thought, action, action_input, observation)
331
+ HistoryEntry.new(
332
+ step: step,
333
+ thought: thought,
334
+ action: action,
335
+ action_input: action_input,
336
+ observation: observation
337
+ )
338
+ end
339
+
340
+ sig { params(question: String, history: T::Array[HistoryEntry], observation: String, available_tools_desc: T::Array[T::Hash[String, T.untyped]], iteration: Integer).returns(T::Hash[Symbol, T.untyped]) }
341
+ def process_observation_and_decide_next_step(question, history, observation, available_tools_desc, iteration)
342
+ return { should_finish: false } if observation.include?("Unknown action")
343
+
344
+ observation_result = @observation_processor.forward(
345
+ question: question,
346
+ history: history,
347
+ observation: observation
348
+ )
349
+
350
+ return { should_finish: false } unless observation_result.next_step == NextStep::Finish
351
+
352
+ final_answer = generate_forced_final_answer(
353
+ question, history, available_tools_desc, observation_result, iteration
354
+ )
355
+
356
+ { should_finish: true, final_answer: final_answer }
357
+ end
358
+
359
+ sig { params(question: String, history: T::Array[HistoryEntry], available_tools_desc: T::Array[T::Hash[String, T.untyped]], observation_result: T.untyped, iteration: Integer).returns(String) }
360
+ def generate_forced_final_answer(question, history, available_tools_desc, observation_result, iteration)
361
+ final_thought = @thought_generator.forward(
362
+ question: question,
363
+ history: history,
364
+ available_tools: available_tools_desc
365
+ )
332
366
 
333
- # Add ReAct-specific fields
334
- prop :history, T::Array[T::Hash[Symbol, T.untyped]]
335
- prop :iterations, Integer
336
- prop :tools_used, T::Array[String]
367
+ if final_thought.action&.downcase != FINISH_ACTION
368
+ forced_answer = if observation_result.interpretation && !observation_result.interpretation.empty?
369
+ observation_result.interpretation
370
+ else
371
+ history.last&.observation || "No answer available"
372
+ end
373
+ handle_finish_action(forced_answer, history.last&.observation, iteration + 1, final_thought.thought, FINISH_ACTION, history)
374
+ else
375
+ handle_finish_action(final_thought.action_input, history.last&.observation, iteration + 1, final_thought.thought, final_thought.action, history)
337
376
  end
338
377
  end
339
378
 
379
+ sig { params(iteration: Integer, thought: String, action: String, action_input: T.untyped, observation: String, tools_used: T::Array[String]).void }
380
+ def emit_iteration_complete_event(iteration, thought, action, action_input, observation, tools_used)
381
+ Instrumentation.emit('dspy.react.iteration_complete', {
382
+ iteration: iteration,
383
+ thought: thought,
384
+ action: action,
385
+ action_input: action_input,
386
+ observation: observation,
387
+ tools_used: tools_used.uniq
388
+ })
389
+ end
390
+
391
+ sig { params(iterations_count: Integer, final_answer: T.nilable(String), tools_used: T::Array[String], history: T::Array[HistoryEntry]).void }
392
+ def handle_max_iterations_if_needed(iterations_count, final_answer, tools_used, history)
393
+ if iterations_count >= @max_iterations && final_answer.nil?
394
+ Instrumentation.emit('dspy.react.max_iterations', {
395
+ iteration_count: iterations_count,
396
+ max_iterations: @max_iterations,
397
+ tools_used: tools_used.uniq,
398
+ final_history_length: history.length
399
+ })
400
+ end
401
+ end
402
+
403
+ sig { returns(String) }
404
+ def default_no_answer_message
405
+ "No answer reached within #{@max_iterations} iterations"
406
+ end
407
+
408
+ # Tool execution method
340
409
  sig { params(action: String, action_input: T.untyped).returns(String) }
341
410
  def execute_action(action, action_input)
342
411
  tool_name = action.downcase