dspy 0.4.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+ require_relative '../instrumentation'
6
+
7
module DSPy
  module Mixins
    # Shared instrumentation helper methods for DSPy modules.
    #
    # NOTE(review): relies on the including class providing `lm` (an object that
    # responds to `model` and `provider`); `lm` is not defined in this module.
    module InstrumentationHelpers
      extend T::Sig

      # File basenames of the modules treated as "higher-level" (ChainOfThought and
      # ReAct) by smart event consolidation. Frames are matched on their exact
      # basename rather than on a substring of the full path, so an unrelated
      # directory or spec file whose path merely contains "react" or
      # "chain_of_thought" does not suppress nested events.
      HIGHER_LEVEL_MODULE_FILES = T.let(%w[chain_of_thought.rb re_act.rb react.rb].freeze, T::Array[String])

      # Event names that are still emitted while a higher-level module is on the
      # call stack. \A and \z anchor the whole string (^ and $ only anchor a line).
      HIGHER_LEVEL_EVENT_PATTERN = T.let(/\Adspy\.(chain_of_thought|react)\z/, Regexp)

      # Number of caller frames inspected when looking for a higher-level module.
      NESTED_CONTEXT_STACK_DEPTH = T.let(30, Integer)

      private

      # Builds the payload shared by every prediction event.
      #
      # @param signature_class [Class] the DSPy::Signature subclass being run
      # @param input_values [Hash{Symbol => Object}] caller inputs; only their keys are recorded
      # @return [Hash{Symbol => Object}] keys :signature_class, :model, :provider, :input_fields
      sig { params(signature_class: T.class_of(DSPy::Signature), input_values: T::Hash[Symbol, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
      def prepare_base_instrumentation_payload(signature_class, input_values)
        {
          signature_class: signature_class.name,
          model: lm.model,
          provider: lm.provider,
          input_fields: input_values.keys.map(&:to_s)
        }
      end

      # Runs the block inside `Instrumentation.instrument(event_name, payload)`, or
      # runs it bare when smart consolidation suppresses the event.
      #
      # @param event_name [String] event to emit, e.g. 'dspy.predict'
      # @param signature_class [Class] the DSPy::Signature subclass being run
      # @param input_values [Hash{Symbol => Object}] caller inputs (their keys go into the payload)
      # @param additional_payload [Hash{Symbol => Object}] merged over the base payload;
      #   its values win on key clashes
      # @return [Object] whatever the block returns
      sig do
        params(
          event_name: String,
          signature_class: T.class_of(DSPy::Signature),
          input_values: T::Hash[Symbol, T.untyped],
          additional_payload: T::Hash[Symbol, T.untyped],
          blk: T.proc.returns(T.untyped)
        ).returns(T.untyped)
      end
      def instrument_prediction(event_name, signature_class, input_values, additional_payload = {}, &blk)
        # Decide first, so a suppressed event does not pay for building its payload.
        return blk.call unless should_emit_event?(event_name)

        full_payload = prepare_base_instrumentation_payload(signature_class, input_values).merge(additional_payload)
        Instrumentation.instrument(event_name, full_payload) { blk.call }
      end

      # Emits a 'dspy.prediction.validation_error' event.
      #
      # @param signature_class [Class] the DSPy::Signature subclass being run
      # @param validation_type [String] e.g. 'input'; also used, as a symbol, as the
      #   key inside :validation_errors
      # @param error_message [String] the validation failure message
      sig { params(signature_class: T.class_of(DSPy::Signature), validation_type: String, error_message: String).void }
      def emit_validation_error(signature_class, validation_type, error_message)
        Instrumentation.emit('dspy.prediction.validation_error', {
          signature_class: signature_class.name,
          validation_type: validation_type,
          validation_errors: { validation_type.to_sym => error_message }
        })
      end

      # Emits a 'dspy.prediction.complete' event.
      #
      # @param signature_class [Class] the DSPy::Signature subclass being run
      # @param success [Boolean] recorded under :success
      # @param additional_data [Hash{Symbol => Object}] merged into the payload; its
      #   values win over :signature_class and :success on key clashes
      sig { params(signature_class: T.class_of(DSPy::Signature), success: T::Boolean, additional_data: T::Hash[Symbol, T.untyped]).void }
      def emit_prediction_complete(signature_class, success, additional_data = {})
        Instrumentation.emit('dspy.prediction.complete', {
          signature_class: signature_class.name,
          success: success
        }.merge(additional_data))
      end

      # Smart consolidation rule.
      #
      # Outside a higher-level module (see #is_nested_context?), every event is
      # emitted. Inside one, only the higher-level modules' own events
      # (dspy.chain_of_thought / dspy.react) are emitted; nested events are skipped.
      #
      # @param event_name [String]
      # @return [Boolean] true if the event should be emitted
      sig { params(event_name: String).returns(T::Boolean) }
      def should_emit_event?(event_name)
        return true unless is_nested_context?

        HIGHER_LEVEL_EVENT_PATTERN.match?(event_name)
      end

      # Stack heuristic for "top-level" calls.
      #
      # Returns true when at most one of the (up to) 20 frames above this call looks
      # instrumentation-related: its label contains 'instrument' or its path
      # contains 'instrumentation'. Returns false if no caller frames are available.
      #
      # NOTE(review): `event_name` is not consulted. Frames from
      # instrumentation_helpers.rb itself satisfy the path test.
      # #should_emit_event? does not use this method.
      sig { params(event_name: String).returns(T::Boolean) }
      def is_top_level_event?(event_name)
        # Named `frames` so the local does not shadow Kernel#caller_locations.
        frames = caller_locations(1, 20)
        return false if frames.nil?

        # 'instrument' is a substring of 'instrument_prediction', so a single label
        # test covers both label checks.
        instrumentation_frames = frames.count do |loc|
          loc.label.include?('instrument') || loc.path.include?('instrumentation')
        end

        # More than one instrumentation frame means this call is nested.
        instrumentation_frames <= 1
      end

      # @return [Boolean] the negation of #is_top_level_event?
      sig { returns(T::Boolean) }
      def is_nested_call?
        !is_top_level_event?('')
      end

      # Reports whether this call is running underneath a higher-level module.
      #
      # True when a ChainOfThought or ReAct source file (HIGHER_LEVEL_MODULE_FILES)
      # appears within the nearest NESTED_CONTEXT_STACK_DEPTH caller frames.
      sig { returns(T::Boolean) }
      def is_nested_context?
        frames = caller_locations(1, NESTED_CONTEXT_STACK_DEPTH)
        return false if frames.nil?

        frames.any? { |loc| HIGHER_LEVEL_MODULE_FILES.include?(File.basename(loc.path)) }
      end
    end
  end
end
@@ -0,0 +1,133 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+
6
module DSPy
  module Mixins
    # Shared module for building enhanced structs with input/output properties.
    module StructBuilder
      extend T::Sig

      private

      # Builds a new T::Struct subclass declaring the props from every group in
      # `property_sources`, followed by `additional_fields`, and mixing in
      # StructSerialization. A new class is created on every call.
      #
      # @param property_sources [Hash{Symbol => Hash{Symbol => Object}}] named groups of
      #   prop definitions (e.g. T::Struct.props); the group names are ignored
      # @param additional_fields [Hash{Symbol => Object}] extra prop definitions
      # @return [Class] the new T::Struct subclass
      sig { params(property_sources: T::Hash[Symbol, T::Hash[Symbol, T.untyped]], additional_fields: T::Hash[Symbol, T.untyped]).returns(T.class_of(T::Struct)) }
      def build_enhanced_struct(property_sources, additional_fields = {})
        # The class body below runs with self set to the new class, so capture the
        # builder to reach its private extraction helpers.
        builder = self

        Class.new(T::Struct) do
          extend T::Sig

          # Declares one const prop on the class being built. The lambda is created
          # inside the class body, so `const` resolves to the new class.
          declare_prop = lambda do |name, prop|
            type = builder.send(:extract_type_from_prop, prop)
            options = builder.send(:extract_options_from_prop, prop)

            # key? rather than truthiness, so that falsy defaults such as
            # `default: false` are kept instead of making the prop required.
            if options.key?(:default)
              const name, type, default: options[:default]
            elsif options.key?(:factory)
              const name, type, factory: options[:factory]
            else
              const name, type
            end
          end

          property_sources.each_value do |props|
            props.each { |name, prop| declare_prop.call(name, prop) }
          end

          additional_fields.each { |name, field_config| declare_prop.call(name, field_config) }

          include StructSerialization
        end
      end

      # Declares every prop in a props hash (e.g. from T::Struct.props).
      sig { params(props: T::Hash[Symbol, T.untyped]).void }
      def build_properties_from_hash(props)
        props.each { |name, prop| build_single_property(name, prop) }
      end

      # Declares a single const prop with its type and default/factory option.
      #
      # NOTE(review): calls `const` on self, so this only works when StructBuilder
      # is extended into a T::Struct / T::Props class.
      sig { params(name: Symbol, prop: T.untyped).void }
      def build_single_property(name, prop)
        type = extract_type_from_prop(prop)
        options = extract_options_from_prop(prop)

        # key? so that falsy defaults such as `default: false` are honoured.
        if options.key?(:default)
          const name, type, default: options[:default]
        elsif options.key?(:factory)
          const name, type, factory: options[:factory]
        else
          const name, type
        end
      end

      # Extracts the type from a property configuration.
      #
      # Accepts a rules Hash, a [Type, description] Array, or a bare type.
      # For a Hash, :type_object is preferred when present.
      # NOTE(review): T::Props appears to store the nil-stripped type under :type
      # (which is why TypeCoercion handles T::Types::Simple) while :type_object
      # keeps T.nilable. Preferring it keeps nilable props optional in the rebuilt
      # struct. Confirm against sorbet-runtime's T::Props::Decorator#prop_defined.
      sig { params(prop: T.untyped).returns(T.untyped) }
      def extract_type_from_prop(prop)
        case prop
        when Hash
          prop[:type_object] || prop[:type]
        when Array
          # Handle [Type, description] format
          prop.first
        else
          prop
        end
      end

      # Extracts the options (e.g. :default, :factory) from a property configuration.
      #
      # Type and metadata keys are stripped. Non-Hash configurations have no options.
      sig { params(prop: T.untyped).returns(T::Hash[Symbol, T.untyped]) }
      def extract_options_from_prop(prop)
        case prop
        when Hash
          prop.except(:type, :type_object, :accessor_key, :sensitivity, :redaction)
        else
          {}
        end
      end
    end

    # Serialization mixin for structs built by StructBuilder#build_enhanced_struct.
    module StructSerialization
      extend T::Sig

      # Returns every prop's value as a hash, layered over `@input_values` when the
      # object has set that instance variable. Prop values win on key clashes.
      sig { returns(T::Hash[Symbol, T.untyped]) }
      def to_h
        input_values_hash.merge(output_properties_hash)
      end

      private

      # Returns the `@input_values` hash if that ivar is defined and non-nil, else {}.
      sig { returns(T::Hash[Symbol, T.untyped]) }
      def input_values_hash
        return {} unless instance_variable_defined?(:@input_values)

        instance_variable_get(:@input_values) || {}
      end

      # Maps each declared prop of this struct's class to its current value.
      sig { returns(T::Hash[Symbol, T.untyped]) }
      def output_properties_hash
        self.class.props.keys.to_h { |key| [key, send(key)] }
      end
    end
  end
end
@@ -0,0 +1,67 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+
6
module DSPy
  module Mixins
    # Shared module for type coercion logic across DSPy modules.
    #
    # Converts raw LM output values to the types declared on output props:
    # - T::Enum types via `deserialize`
    # - Float via `to_f`
    # - Integer via `to_i`
    # Anything else, and any nil value, is passed through unchanged.
    module TypeCoercion
      extend T::Sig

      private

      # Coerces every output attribute to the type declared for it.
      #
      # @param output_attributes [Hash{Symbol => Object}] raw values keyed by prop name
      # @param output_props [Hash{Symbol => Hash}] props hash; each entry's :type is
      #   consulted. Attributes without a matching prop are returned unchanged.
      # @return [Hash{Symbol => Object}] same keys, coerced values
      sig { params(output_attributes: T::Hash[Symbol, T.untyped], output_props: T::Hash[Symbol, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
      def coerce_output_attributes(output_attributes, output_props)
        output_attributes.to_h do |key, value|
          [key, coerce_value_to_type(value, output_props[key]&.dig(:type))]
        end
      end

      # Coerces a single value to match its expected type.
      #
      # Raw classes (e.g. `Float`) and T::Types::Simple wrappers are both recognised.
      # nil is returned untouched, so a nilable prop keeps nil instead of becoming
      # 0.0 / 0 or being handed to an enum's `deserialize`.
      # NOTE: `to_f` / `to_i` are lenient; non-numeric strings become 0.0 / 0.
      sig { params(value: T.untyped, prop_type: T.untyped).returns(T.untyped) }
      def coerce_value_to_type(value, prop_type)
        return value if prop_type.nil? || value.nil?

        # Explicit == comparisons instead of `case ... when Float`: `when` calls
        # Float === prop_type, which asks whether prop_type is an *instance* of
        # Float, and that is false for the class Float itself.
        if enum_type?(prop_type)
          extract_enum_class(prop_type).deserialize(value)
        elsif prop_type == Float || simple_type_match?(prop_type, Float)
          value.to_f
        elsif prop_type == Integer || simple_type_match?(prop_type, Integer)
          value.to_i
        else
          value
        end
      end

      # Checks whether a type is a T::Enum subclass, raw or wrapped in T::Types::Simple.
      sig { params(type: T.untyped).returns(T::Boolean) }
      def enum_type?(type)
        # Module#< returns nil for unrelated modules. Force a strict boolean so the
        # T::Boolean return sig is not violated at runtime.
        !!((type.is_a?(Class) && type < T::Enum) ||
           (type.is_a?(T::Types::Simple) && type.raw_type < T::Enum))
      end

      # Extracts the enum class from a raw or T::Types::Simple-wrapped type.
      sig { params(prop_type: T.untyped).returns(T.class_of(T::Enum)) }
      def extract_enum_class(prop_type)
        if prop_type.is_a?(Class) && prop_type < T::Enum
          prop_type
        elsif prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum
          prop_type.raw_type
        else
          T.cast(prop_type, T.class_of(T::Enum))
        end
      end

      # Checks whether `type` is a T::Types::Simple wrapping exactly `target_type`.
      sig { params(type: T.untyped, target_type: T.untyped).returns(T::Boolean) }
      def simple_type_match?(type, target_type)
        type.is_a?(T::Types::Simple) && type.raw_type == target_type
      end
    end
  end
end
data/lib/dspy/predict.rb CHANGED
@@ -4,6 +4,9 @@ require 'sorbet-runtime'
4
4
  require_relative 'module'
5
5
  require_relative 'instrumentation'
6
6
  require_relative 'prompt'
7
+ require_relative 'mixins/struct_builder'
8
+ require_relative 'mixins/type_coercion'
9
+ require_relative 'mixins/instrumentation_helpers'
7
10
 
8
11
  module DSPy
9
12
  # Exception raised when prediction fails validation
@@ -22,6 +25,9 @@ module DSPy
22
25
 
23
26
  class Predict < DSPy::Module
24
27
  extend T::Sig
28
+ include Mixins::StructBuilder
29
+ include Mixins::TypeCoercion
30
+ include Mixins::InstrumentationHelpers
25
31
 
26
32
  sig { returns(T.class_of(Signature)) }
27
33
  attr_reader :signature_class
@@ -79,112 +85,63 @@ module DSPy
79
85
 
80
86
  sig { params(input_values: T.untyped).returns(T.untyped) }
81
87
  def forward_untyped(**input_values)
82
- # Prepare instrumentation payload
83
- input_fields = input_values.keys.map(&:to_s)
84
-
85
- Instrumentation.instrument('dspy.predict', {
86
- signature_class: @signature_class.name,
87
- model: lm.model,
88
- provider: lm.provider,
89
- input_fields: input_fields
90
- }) do
88
+ instrument_prediction('dspy.predict', @signature_class, input_values) do
91
89
  # Validate input
92
- begin
93
- _input_struct = @signature_class.input_struct_class.new(**input_values)
94
- rescue ArgumentError => e
95
- # Emit validation error event
96
- Instrumentation.emit('dspy.predict.validation_error', {
97
- signature_class: @signature_class.name,
98
- validation_type: 'input',
99
- validation_errors: { input: e.message }
100
- })
101
- raise PredictionInvalidError.new({ input: e.message })
102
- end
103
-
104
- # Call LM
90
+ validate_input_struct(input_values)
91
+
92
+ # Call LM and process response
105
93
  output_attributes = lm.chat(self, input_values)
106
-
107
- output_attributes = output_attributes.transform_keys(&:to_sym)
108
-
109
- output_props = @signature_class.output_struct_class.props
110
- output_attributes = output_attributes.map do |key, value|
111
- prop_type = output_props[key][:type] if output_props[key]
112
- if prop_type
113
- # Check if it's an enum (can be raw Class or T::Types::Simple)
114
- enum_class = if prop_type.is_a?(Class) && prop_type < T::Enum
115
- prop_type
116
- elsif prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum
117
- prop_type.raw_type
118
- end
119
-
120
- if enum_class
121
- [key, enum_class.deserialize(value)]
122
- elsif prop_type == Float || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Float)
123
- [key, value.to_f]
124
- elsif prop_type == Integer || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Integer)
125
- [key, value.to_i]
126
- else
127
- [key, value]
128
- end
129
- else
130
- [key, value]
131
- end
132
- end.to_h
133
-
134
- # Create combined struct with both input and output values
135
- begin
136
- combined_struct = create_combined_struct_class
137
- all_attributes = input_values.merge(output_attributes)
138
- combined_struct.new(**all_attributes)
139
- rescue ArgumentError => e
140
- raise PredictionInvalidError.new({ output: e.message })
141
- rescue TypeError => e
142
- raise PredictionInvalidError.new({ output: e.message })
143
- end
94
+ processed_output = process_lm_output(output_attributes)
95
+
96
+ # Create combined result struct
97
+ create_prediction_result(input_values, processed_output)
144
98
  end
145
99
  end
146
100
 
147
101
  private
148
102
 
103
+ # Validates input using signature struct
104
+ sig { params(input_values: T::Hash[Symbol, T.untyped]).void }
105
+ def validate_input_struct(input_values)
106
+ @signature_class.input_struct_class.new(**input_values)
107
+ rescue ArgumentError => e
108
+ emit_validation_error(@signature_class, 'input', e.message)
109
+ raise PredictionInvalidError.new({ input: e.message })
110
+ end
111
+
112
# Normalises the LM's raw output and coerces it to the declared output types.
#
# Keys are converted to symbols, then each value is coerced against the matching
# prop of the signature's output struct.
sig { params(output_attributes: T::Hash[T.untyped, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
def process_lm_output(output_attributes)
  coerce_output_attributes(
    output_attributes.transform_keys(&:to_sym),
    @signature_class.output_struct_class.props
  )
end
120
+
121
# Instantiates the combined input+output struct for this prediction.
#
# @param input_values [Hash{Symbol => Object}] the caller's inputs
# @param output_attributes [Hash{Symbol => Object}] the coerced LM outputs; merged over
#   input_values, so an output wins if both use the same key
# @return [T::Struct] an instance of the class returned by create_combined_struct_class
# @raise [PredictionInvalidError] carrying { output: message } when building the class
#   or calling its constructor raises ArgumentError or TypeError
sig { params(input_values: T::Hash[Symbol, T.untyped], output_attributes: T::Hash[Symbol, T.untyped]).returns(T.untyped) }
def create_prediction_result(input_values, output_attributes)
  combined_struct = create_combined_struct_class
  combined_struct.new(**input_values.merge(output_attributes))
# A method-level rescue replaces the redundant begin/end. Both error classes share
# one clause because they were handled identically.
rescue ArgumentError, TypeError => e
  raise PredictionInvalidError.new({ output: e.message })
end
134
+
135
# Builds a T::Struct subclass carrying every input prop and every output prop of
# the signature. Delegates to build_enhanced_struct with the input props first,
# then the output props.
sig { returns(T.class_of(T::Struct)) }
def create_combined_struct_class
  build_enhanced_struct({
    input: @signature_class.input_struct_class.props,
    output: @signature_class.output_struct_class.props
  })
end
189
146
  end
190
147
  end