dspy 0.4.0 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,133 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+ require_relative '../instrumentation'
6
+
7
+ module DSPy
8
+ module Mixins
9
+ # Shared instrumentation helper methods for DSPy modules
10
+ module InstrumentationHelpers
11
+ extend T::Sig
12
+
13
+ private
14
+
15
+ # Prepares base instrumentation payload for prediction-based modules
16
+ sig { params(signature_class: T.class_of(DSPy::Signature), input_values: T::Hash[Symbol, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
17
# Assembles the fields shared by every prediction event: the signature's class
# name, the model and provider reported by `lm`, and the input names as strings.
def prepare_base_instrumentation_payload(signature_class, input_values)
  payload = { signature_class: signature_class.name }
  payload[:model] = lm.model
  payload[:provider] = lm.provider
  payload[:input_fields] = input_values.each_key.map(&:to_s)
  payload
end
25
+
26
+ # Instruments a prediction operation with base payload
27
+ sig { params(event_name: String, signature_class: T.class_of(DSPy::Signature), input_values: T::Hash[Symbol, T.untyped], additional_payload: T::Hash[Symbol, T.untyped]).returns(T.untyped) }
28
# Runs the caller's block, wrapped in an instrumentation event when the
# configured trace level allows `event_name`; returns the block's value.
# `additional_payload` entries override base payload entries with the same key.
def instrument_prediction(event_name, signature_class, input_values, additional_payload = {})
  payload = prepare_base_instrumentation_payload(signature_class, input_values).merge(additional_payload)
  level = DSPy.config.instrumentation.trace_level

  # Filtered-out events still execute the block, just without an event around it.
  return yield unless should_emit_event?(event_name, level)

  Instrumentation.instrument(event_name, payload) { yield }
end
44
+
45
+ # Emits a validation error event
46
+ sig { params(signature_class: T.class_of(DSPy::Signature), validation_type: String, error_message: String).void }
47
# Emits a 'dspy.prediction.validation_error' event carrying the signature's
# class name, the validation type, and the message keyed by that type as a symbol.
def emit_validation_error(signature_class, validation_type, error_message)
  event_payload = { signature_class: signature_class.name, validation_type: validation_type }
  event_payload[:validation_errors] = { validation_type.to_sym => error_message }

  Instrumentation.emit('dspy.prediction.validation_error', event_payload)
end
54
+
55
+ # Emits a prediction completion event
56
+ sig { params(signature_class: T.class_of(DSPy::Signature), success: T::Boolean, additional_data: T::Hash[Symbol, T.untyped]).void }
57
# Emits a 'dspy.prediction.complete' event with the signature's class name and
# the success flag; `additional_data` is merged last, so its keys take precedence.
def emit_prediction_complete(signature_class, success, additional_data = {})
  base_payload = { signature_class: signature_class.name, success: success }
  event_payload = base_payload.merge(additional_data)

  Instrumentation.emit('dspy.prediction.complete', event_payload)
end
63
+
64
+ # Determines if an event should be emitted based on trace level
65
+ sig { params(event_name: String, trace_level: Symbol).returns(T::Boolean) }
66
# Decides whether `event_name` is emitted at the given trace level.
#
# - :minimal  -> only 'dspy.chain_of_thought' and 'dspy.react'
# - :standard -> those two always; every other event only when
#                `is_nested_context?` reports we are not inside a higher-level module
# - anything else (including :detailed) -> always emitted
#
# The pattern is anchored with \A/\z rather than ^/$ so that a name containing a
# newline (e.g. "dspy.react\nfoo") cannot pass as a high-level event.
def should_emit_event?(event_name, trace_level)
  high_level_event = /\Adspy\.(chain_of_thought|react)\z/

  case trace_level
  when :minimal
    event_name.match?(high_level_event)
  when :standard
    # High-level events pass regardless of nesting, so the stack inspection
    # is only needed for lower-level events.
    event_name.match?(high_level_event) || !is_nested_context?
  else
    true
  end
end
88
+
89
+ # Determines if this is a top-level event (not nested)
90
+ sig { params(event_name: String).returns(T::Boolean) }
91
# Heuristic: inspects up to 20 caller frames and reports true when at most one
# of them looks instrumentation-related (label contains 'instrument', or path
# contains 'instrumentation'). `event_name` is accepted but not consulted.
def is_top_level_event?(event_name)
  frames = caller_locations(1, 20)
  return false if frames.nil?

  # 'instrument' also covers labels containing 'instrument_prediction'.
  instrumented_frames = frames.count do |frame|
    frame.label.include?('instrument') || frame.path.include?('instrumentation')
  end

  instrumented_frames <= 1
end
106
+
107
+ # Determines if we're in a nested call context
108
+ sig { returns(T::Boolean) }
109
# Negation of `is_top_level_event?`, queried with an empty event name.
def is_nested_call?
  top_level = is_top_level_event?('')
  !top_level
end
112
+
113
+ # Determines if we're in a nested context where higher-level events are being emitted
114
+ sig { returns(T::Boolean) }
115
# Heuristic: true when any of the nearest 30 caller frames comes from a file
# whose path contains 'chain_of_thought', 're_act' or 'react'.
def is_nested_context?
  frames = caller_locations(1, 30)
  return false if frames.nil?

  frames.any? do |frame|
    source_path = frame.path
    %w[chain_of_thought re_act react].any? { |marker| source_path.include?(marker) }
  end
end
130
+
131
+ end
132
+ end
133
+ end
@@ -0,0 +1,133 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+
6
+ module DSPy
7
+ module Mixins
8
+ # Shared module for building enhanced structs with input/output properties
9
+ module StructBuilder
10
+ extend T::Sig
11
+
12
+ private
13
+
14
+ # Builds a new struct class with properties from multiple sources
15
+ sig { params(property_sources: T::Hash[Symbol, T::Hash[Symbol, T.untyped]], additional_fields: T::Hash[Symbol, T.untyped]).returns(T.class_of(T::Struct)) }
16
# Builds an anonymous T::Struct subclass with one `const` per field found in
# `property_sources` (a hash of name => config hashes, visited in order)
# followed by `additional_fields`, then mixes in StructSerialization.
#
# A field's config is read through `extract_type_from_prop` and
# `extract_options_from_prop`. A `:default` option wins over `:factory`.
# Defaults are detected by key presence, so `default: false` and `default: nil`
# are kept instead of silently turning the field into a required one.
def build_enhanced_struct(property_sources, additional_fields = {})
  # The class body runs with `self` set to the new class, so keep a handle on
  # this object to reach its private extraction helpers.
  builder = self
  # One ordered list of [name, config] pairs: every source first, extras last.
  field_configs = property_sources.each_value.flat_map(&:to_a) + additional_fields.to_a

  Class.new(T::Struct) do
    extend T::Sig

    field_configs.each do |name, config|
      type = builder.send(:extract_type_from_prop, config)
      options = builder.send(:extract_options_from_prop, config)

      if options.key?(:default)
        const name, type, default: options[:default]
      elsif options[:factory]
        const name, type, factory: options[:factory]
      else
        const name, type
      end
    end

    include StructSerialization
  end
end
56
+
57
+ # Builds properties from a props hash (from T::Struct.props)
58
+ sig { params(props: T::Hash[Symbol, T.untyped]).void }
59
# Hands every name/config pair in `props` to `build_single_property`, in hash order.
def build_properties_from_hash(props)
  props.each_pair do |field_name, field_config|
    build_single_property(field_name, field_config)
  end
end
62
+
63
+ # Builds a single property with type and options
64
+ sig { params(name: Symbol, prop: T.untyped).void }
65
# Declares one `const` for `name` using the type and options read from `prop`.
# A `:default` option wins over `:factory`. The default is detected by key
# presence, so `default: false` / `default: nil` are passed through instead of
# being treated as "no default".
# NOTE(review): `const` is resolved on the receiver of this method, so this only
# works where `self` responds to `const` -- confirm against callers.
def build_single_property(name, prop)
  type = extract_type_from_prop(prop)
  options = extract_options_from_prop(prop)

  if options.key?(:default)
    const name, type, default: options[:default]
  elsif options[:factory]
    const name, type, factory: options[:factory]
  else
    const name, type
  end
end
77
+
78
+ # Extracts type from property configuration
79
+ sig { params(prop: T.untyped).returns(T.untyped) }
80
# Pulls the type out of a property config: the :type entry of a Hash, the first
# element of an Array ([Type, description] form), or the value itself otherwise.
def extract_type_from_prop(prop)
  return prop[:type] if prop.is_a?(Hash)
  return prop.first if prop.is_a?(Array)

  prop
end
91
+
92
+ # Extracts options from property configuration
93
+ sig { params(prop: T.untyped).returns(T::Hash[Symbol, T.untyped]) }
94
# Returns a Hash config minus its :type, :type_object, :accessor_key,
# :sensitivity and :redaction entries; any non-Hash config yields an empty hash.
def extract_options_from_prop(prop)
  return {} unless prop.is_a?(Hash)

  prop.except(:type, :type_object, :accessor_key, :sensitivity, :redaction)
end
102
+ end
103
+
104
+ # Module for adding serialization capabilities to enhanced structs
105
+ module StructSerialization
106
+ extend T::Sig
107
+
108
+ sig { returns(T::Hash[Symbol, T.untyped]) }
109
# Serializes to a hash: stored input values overlaid with the current property
# values (a property wins when both define the same key).
def to_h
  input_values_hash.merge(output_properties_hash)
end
113
+
114
+ private
115
+
116
+ sig { returns(T::Hash[Symbol, T.untyped]) }
117
# Returns the @input_values instance variable, or an empty hash when it is
# undefined or holds a falsy value.
def input_values_hash
  return {} unless instance_variable_defined?(:@input_values)

  instance_variable_get(:@input_values) || {}
end
124
+
125
+ sig { returns(T::Hash[Symbol, T.untyped]) }
126
# Maps every key of `self.class.props` to the value returned by calling the
# same-named method on this object.
def output_properties_hash
  self.class.props.each_key.to_h { |prop_name| [prop_name, send(prop_name)] }
end
131
+ end
132
+ end
133
+ end
@@ -0,0 +1,67 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'sorbet-runtime'
5
+
6
+ module DSPy
7
+ module Mixins
8
+ # Shared module for type coercion logic across DSPy modules
9
+ module TypeCoercion
10
+ extend T::Sig
11
+
12
+ private
13
+
14
+ # Coerces output attributes to match their expected types
15
+ sig { params(output_attributes: T::Hash[Symbol, T.untyped], output_props: T::Hash[Symbol, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
16
# Returns a new hash in which every value has been passed through
# `coerce_value_to_type` together with the :type recorded for its key in
# `output_props` (nil when the key has no entry there).
def coerce_output_attributes(output_attributes, output_props)
  output_attributes.each_with_object({}) do |(key, value), coerced|
    expected_type = output_props[key]&.dig(:type)
    coerced[key] = coerce_value_to_type(value, expected_type)
  end
end
23
+
24
+ # Coerces a single value to match its expected type
25
+ sig { params(value: T.untyped, prop_type: T.untyped).returns(T.untyped) }
26
# Coerces `value` to `prop_type`:
# - no type given      -> value unchanged
# - enum type          -> the enum class's `deserialize(value)`
# - Float / Integer    -> `value.to_f` / `value.to_i`, whether the type is the
#                         bare class or a T::Types::Simple wrapping it
# - anything else      -> value unchanged
#
# The type is compared with `==` instead of `when Float` / `when Integer`:
# `case` uses `Class#===`, which tests "is an instance of", so the class
# `Float` itself never matched and bare-class prop types were left uncoerced.
def coerce_value_to_type(value, prop_type)
  return value unless prop_type

  if enum_type?(prop_type)
    extract_enum_class(prop_type).deserialize(value)
  elsif prop_type == Float || simple_type_match?(prop_type, Float)
    value.to_f
  elsif prop_type == Integer || simple_type_match?(prop_type, Integer)
    value.to_i
  else
    value
  end
end
40
+
41
+ # Checks if a type is an enum type
42
+ sig { params(type: T.untyped).returns(T::Boolean) }
43
# True when `type` is a T::Enum subclass, either directly or wrapped in a
# T::Types::Simple; false for everything else.
#
# `Class#<` returns nil (not false) for unrelated classes, so the result is
# forced to a strict boolean to honour the declared T::Boolean return type.
def enum_type?(type)
  candidate = type.is_a?(T::Types::Simple) ? type.raw_type : type
  return false unless candidate.is_a?(Class)

  (candidate < T::Enum) == true
end
47
+
48
+ # Extracts the enum class from a type
49
+ sig { params(prop_type: T.untyped).returns(T.class_of(T::Enum)) }
50
# Returns the T::Enum subclass behind `prop_type`: the class itself, or the
# raw type of a T::Types::Simple wrapper. Any other value is handed to T.cast.
def extract_enum_class(prop_type)
  return prop_type if prop_type.is_a?(Class) && prop_type < T::Enum
  return prop_type.raw_type if prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum

  T.cast(prop_type, T.class_of(T::Enum))
end
59
+
60
+ # Checks if a type matches a simple type (like Float, Integer)
61
+ sig { params(type: T.untyped, target_type: T.untyped).returns(T::Boolean) }
62
# True only when `type` is a T::Types::Simple whose raw type equals `target_type`.
def simple_type_match?(type, target_type)
  return false unless type.is_a?(T::Types::Simple)

  type.raw_type == target_type
end
65
+ end
66
+ end
67
+ end
data/lib/dspy/predict.rb CHANGED
@@ -4,6 +4,9 @@ require 'sorbet-runtime'
4
4
  require_relative 'module'
5
5
  require_relative 'instrumentation'
6
6
  require_relative 'prompt'
7
+ require_relative 'mixins/struct_builder'
8
+ require_relative 'mixins/type_coercion'
9
+ require_relative 'mixins/instrumentation_helpers'
7
10
 
8
11
  module DSPy
9
12
  # Exception raised when prediction fails validation
@@ -22,6 +25,9 @@ module DSPy
22
25
 
23
26
  class Predict < DSPy::Module
24
27
  extend T::Sig
28
+ include Mixins::StructBuilder
29
+ include Mixins::TypeCoercion
30
+ include Mixins::InstrumentationHelpers
25
31
 
26
32
  sig { returns(T.class_of(Signature)) }
27
33
  attr_reader :signature_class
@@ -79,112 +85,63 @@ module DSPy
79
85
 
80
86
  sig { params(input_values: T.untyped).returns(T.untyped) }
81
87
  def forward_untyped(**input_values)
82
- # Prepare instrumentation payload
83
- input_fields = input_values.keys.map(&:to_s)
84
-
85
- Instrumentation.instrument('dspy.predict', {
86
- signature_class: @signature_class.name,
87
- model: lm.model,
88
- provider: lm.provider,
89
- input_fields: input_fields
90
- }) do
88
+ instrument_prediction('dspy.predict', @signature_class, input_values) do
91
89
  # Validate input
92
- begin
93
- _input_struct = @signature_class.input_struct_class.new(**input_values)
94
- rescue ArgumentError => e
95
- # Emit validation error event
96
- Instrumentation.emit('dspy.predict.validation_error', {
97
- signature_class: @signature_class.name,
98
- validation_type: 'input',
99
- validation_errors: { input: e.message }
100
- })
101
- raise PredictionInvalidError.new({ input: e.message })
102
- end
103
-
104
- # Call LM
90
+ validate_input_struct(input_values)
91
+
92
+ # Call LM and process response
105
93
  output_attributes = lm.chat(self, input_values)
106
-
107
- output_attributes = output_attributes.transform_keys(&:to_sym)
108
-
109
- output_props = @signature_class.output_struct_class.props
110
- output_attributes = output_attributes.map do |key, value|
111
- prop_type = output_props[key][:type] if output_props[key]
112
- if prop_type
113
- # Check if it's an enum (can be raw Class or T::Types::Simple)
114
- enum_class = if prop_type.is_a?(Class) && prop_type < T::Enum
115
- prop_type
116
- elsif prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum
117
- prop_type.raw_type
118
- end
119
-
120
- if enum_class
121
- [key, enum_class.deserialize(value)]
122
- elsif prop_type == Float || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Float)
123
- [key, value.to_f]
124
- elsif prop_type == Integer || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Integer)
125
- [key, value.to_i]
126
- else
127
- [key, value]
128
- end
129
- else
130
- [key, value]
131
- end
132
- end.to_h
133
-
134
- # Create combined struct with both input and output values
135
- begin
136
- combined_struct = create_combined_struct_class
137
- all_attributes = input_values.merge(output_attributes)
138
- combined_struct.new(**all_attributes)
139
- rescue ArgumentError => e
140
- raise PredictionInvalidError.new({ output: e.message })
141
- rescue TypeError => e
142
- raise PredictionInvalidError.new({ output: e.message })
143
- end
94
+ processed_output = process_lm_output(output_attributes)
95
+
96
+ # Create combined result struct
97
+ create_prediction_result(input_values, processed_output)
144
98
  end
145
99
  end
146
100
 
147
101
  private
148
102
 
103
+ # Validates input using signature struct
104
+ sig { params(input_values: T::Hash[Symbol, T.untyped]).void }
105
# Instantiates the signature's input struct from `input_values`. An
# ArgumentError is reported via `emit_validation_error` and re-raised as
# PredictionInvalidError with the message under :input.
def validate_input_struct(input_values)
  struct_class = @signature_class.input_struct_class
  struct_class.new(**input_values)
rescue ArgumentError => error
  message = error.message
  emit_validation_error(@signature_class, 'input', message)
  raise PredictionInvalidError.new({ input: message })
end
111
+
112
+ # Processes LM output with type coercion
113
+ sig { params(output_attributes: T::Hash[T.untyped, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
114
# Symbolizes the keys of the LM output, then coerces each value against the
# props of the signature's output struct.
def process_lm_output(output_attributes)
  symbolized = output_attributes.transform_keys(&:to_sym)
  expected_props = @signature_class.output_struct_class.props

  coerce_output_attributes(symbolized, expected_props)
end
120
+
121
+ # Creates the final prediction result struct
122
+ sig { params(input_values: T::Hash[Symbol, T.untyped], output_attributes: T::Hash[Symbol, T.untyped]).returns(T.untyped) }
123
# Instantiates the combined struct from the input values overlaid with the
# output attributes. ArgumentError and TypeError are re-raised as
# PredictionInvalidError with the message under :output.
def create_prediction_result(input_values, output_attributes)
  merged_attributes = input_values.merge(output_attributes)
  create_combined_struct_class.new(**merged_attributes)
rescue ArgumentError, TypeError => error
  raise PredictionInvalidError.new({ output: error.message })
end
134
+
135
+ # Creates a combined struct class with input and output properties
149
136
  sig { returns(T.class_of(T::Struct)) }
150
137
  def create_combined_struct_class
151
138
  input_props = @signature_class.input_struct_class.props
152
139
  output_props = @signature_class.output_struct_class.props
153
140
 
154
- # Create a new struct class that combines input and output fields
155
- Class.new(T::Struct) do
156
- extend T::Sig
157
-
158
- # Add input fields
159
- input_props.each do |name, prop_info|
160
- if prop_info[:rules]&.any? { |rule| rule.is_a?(T::Props::NilableRules) }
161
- prop name, prop_info[:type], default: prop_info[:default]
162
- else
163
- const name, prop_info[:type], default: prop_info[:default]
164
- end
165
- end
166
-
167
- # Add output fields
168
- output_props.each do |name, prop_info|
169
- if prop_info[:rules]&.any? { |rule| rule.is_a?(T::Props::NilableRules) }
170
- prop name, prop_info[:type], default: prop_info[:default]
171
- else
172
- const name, prop_info[:type], default: prop_info[:default]
173
- end
174
- end
175
-
176
- # Add to_h method to serialize the struct to a hash
177
- define_method :to_h do
178
- hash = {}
179
-
180
- # Add all properties
181
- self.class.props.keys.each do |key|
182
- hash[key] = self.send(key)
183
- end
184
-
185
- hash
186
- end
187
- end
141
+ build_enhanced_struct({
142
+ input: input_props,
143
+ output: output_props
144
+ })
188
145
  end
189
146
  end
190
147
  end