dspy 0.3.1 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/dspy/predict.rb CHANGED
@@ -3,6 +3,10 @@
3
3
  require 'sorbet-runtime'
4
4
  require_relative 'module'
5
5
  require_relative 'instrumentation'
6
+ require_relative 'prompt'
7
+ require_relative 'mixins/struct_builder'
8
+ require_relative 'mixins/type_coercion'
9
+ require_relative 'mixins/instrumentation_helpers'
6
10
 
7
11
  module DSPy
8
12
  # Exception raised when prediction fails validation
@@ -21,56 +25,56 @@ module DSPy
21
25
 
22
26
  class Predict < DSPy::Module
23
27
  extend T::Sig
28
+ include Mixins::StructBuilder
29
+ include Mixins::TypeCoercion
30
+ include Mixins::InstrumentationHelpers
24
31
 
25
32
  sig { returns(T.class_of(Signature)) }
26
33
  attr_reader :signature_class
27
34
 
35
+ sig { returns(Prompt) }
36
+ attr_reader :prompt
37
+
28
38
  sig { params(signature_class: T.class_of(Signature)).void }
29
39
  def initialize(signature_class)
30
40
  super()
31
41
  @signature_class = signature_class
42
+ @prompt = Prompt.from_signature(signature_class)
32
43
  end
33
44
 
45
+ # Backward compatibility methods - delegate to prompt object
34
46
  sig { returns(String) }
35
47
  def system_signature
36
- <<-PROMPT
37
- Your input schema fields are:
38
- ```json
39
- #{JSON.generate(@signature_class.input_json_schema)}
40
- ```
41
- Your output schema fields are:
42
- ```json
43
- #{JSON.generate(@signature_class.output_json_schema)}
44
- ````
45
-
46
- All interactions will be structured in the following way, with the appropriate values filled in.
47
-
48
- ## Input values
49
- ```json
50
- {input_values}
51
- ```
52
- ## Output values
53
- Respond exclusively with the output schema fields in the json block below.
54
- ```json
55
- {output_values}
56
- ```
57
-
58
- In adhering to this structure, your objective is: #{@signature_class.description}
59
-
60
- PROMPT
48
+ @prompt.render_system_prompt
61
49
  end
62
50
 
63
51
  sig { params(input_values: T::Hash[Symbol, T.untyped]).returns(String) }
64
52
  def user_signature(input_values)
65
- <<-PROMPT
66
- ## Input Values
67
- ```json
68
- #{JSON.generate(input_values)}
69
- ```
70
-
71
- Respond with the corresponding output schema fields wrapped in a ```json ``` block,
72
- starting with the heading `## Output values`.
73
- PROMPT
53
+ @prompt.render_user_prompt(input_values)
54
+ end
55
+
56
+ # New prompt-based interface for optimization
57
+ sig { params(new_prompt: Prompt).returns(Predict) }
58
# Returns a copy of this predictor that uses +new_prompt+.
#
# A fresh instance of the receiver's class is built from the same signature
# class and its @prompt is then overwritten, so the receiver is not modified.
def with_prompt(new_prompt)
  self.class.new(@signature_class).tap do |copy|
    copy.instance_variable_set(:@prompt, new_prompt)
  end
end
64
+
65
+ sig { params(instruction: String).returns(Predict) }
66
# Returns a new predictor whose prompt is the current prompt with its
# instruction swapped for +instruction+. The receiver is left unchanged.
def with_instruction(instruction)
  updated_prompt = @prompt.with_instruction(instruction)
  with_prompt(updated_prompt)
end
69
+
70
+ sig { params(examples: T::Array[FewShotExample]).returns(Predict) }
71
# Returns a new predictor whose prompt is derived from the current one via
# Prompt#with_examples(+examples+). The receiver is left unchanged.
def with_examples(examples)
  updated_prompt = @prompt.with_examples(examples)
  with_prompt(updated_prompt)
end
74
+
75
+ sig { params(examples: T::Array[FewShotExample]).returns(Predict) }
76
# Returns a new predictor whose prompt is derived from the current one via
# Prompt#add_examples(+examples+). The receiver is left unchanged.
def add_examples(examples)
  extended_prompt = @prompt.add_examples(examples)
  with_prompt(extended_prompt)
end
75
79
 
76
80
  sig { override.params(kwargs: T.untyped).returns(T.type_parameter(:O)) }
@@ -81,112 +85,63 @@ module DSPy
81
85
 
82
86
  sig { params(input_values: T.untyped).returns(T.untyped) }
83
87
  def forward_untyped(**input_values)
84
- # Prepare instrumentation payload
85
- input_fields = input_values.keys.map(&:to_s)
86
-
87
- Instrumentation.instrument('dspy.predict', {
88
- signature_class: @signature_class.name,
89
- model: lm.model,
90
- provider: lm.provider,
91
- input_fields: input_fields
92
- }) do
88
+ instrument_prediction('dspy.predict', @signature_class, input_values) do
93
89
  # Validate input
94
- begin
95
- _input_struct = @signature_class.input_struct_class.new(**input_values)
96
- rescue ArgumentError => e
97
- # Emit validation error event
98
- Instrumentation.emit('dspy.predict.validation_error', {
99
- signature_class: @signature_class.name,
100
- validation_type: 'input',
101
- validation_errors: { input: e.message }
102
- })
103
- raise PredictionInvalidError.new({ input: e.message })
104
- end
105
-
106
- # Call LM
90
+ validate_input_struct(input_values)
91
+
92
+ # Call LM and process response
107
93
  output_attributes = lm.chat(self, input_values)
108
-
109
- output_attributes = output_attributes.transform_keys(&:to_sym)
110
-
111
- output_props = @signature_class.output_struct_class.props
112
- output_attributes = output_attributes.map do |key, value|
113
- prop_type = output_props[key][:type] if output_props[key]
114
- if prop_type
115
- # Check if it's an enum (can be raw Class or T::Types::Simple)
116
- enum_class = if prop_type.is_a?(Class) && prop_type < T::Enum
117
- prop_type
118
- elsif prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum
119
- prop_type.raw_type
120
- end
121
-
122
- if enum_class
123
- [key, enum_class.deserialize(value)]
124
- elsif prop_type == Float || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Float)
125
- [key, value.to_f]
126
- elsif prop_type == Integer || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Integer)
127
- [key, value.to_i]
128
- else
129
- [key, value]
130
- end
131
- else
132
- [key, value]
133
- end
134
- end.to_h
135
-
136
- # Create combined struct with both input and output values
137
- begin
138
- combined_struct = create_combined_struct_class
139
- all_attributes = input_values.merge(output_attributes)
140
- combined_struct.new(**all_attributes)
141
- rescue ArgumentError => e
142
- raise PredictionInvalidError.new({ output: e.message })
143
- rescue TypeError => e
144
- raise PredictionInvalidError.new({ output: e.message })
145
- end
94
+ processed_output = process_lm_output(output_attributes)
95
+
96
+ # Create combined result struct
97
+ create_prediction_result(input_values, processed_output)
146
98
  end
147
99
  end
148
100
 
149
101
  private
150
102
 
103
+ # Validates input using signature struct
104
+ sig { params(input_values: T::Hash[Symbol, T.untyped]).void }
105
# Checks +input_values+ by instantiating the signature's input struct class
# with them.
#
# If that raises ArgumentError, a validation-error event of type 'input' is
# emitted and the failure is re-raised as PredictionInvalidError carrying
# the original message under the :input key.
def validate_input_struct(input_values)
  begin
    input_struct_class = @signature_class.input_struct_class
    input_struct_class.new(**input_values)
  rescue ArgumentError => error
    emit_validation_error(@signature_class, 'input', error.message)
    raise PredictionInvalidError.new({ input: error.message })
  end
end
111
+
112
+ # Processes LM output with type coercion
113
+ sig { params(output_attributes: T::Hash[T.untyped, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
114
# Normalizes the raw attribute hash returned by the LM: keys are converted
# to symbols, then the values are handed to coerce_output_attributes
# together with the props of the signature's output struct class.
def process_lm_output(output_attributes)
  symbolized = output_attributes.transform_keys { |key| key.to_sym }
  coerce_output_attributes(symbolized, @signature_class.output_struct_class.props)
end
120
+
121
+ # Creates the final prediction result struct
122
+ sig { params(input_values: T::Hash[Symbol, T.untyped], output_attributes: T::Hash[Symbol, T.untyped]).returns(T.untyped) }
123
# Builds the final prediction object from the inputs and the processed
# outputs.
#
# The two hashes are merged (output values win on key collisions, as with
# Hash#merge) and passed as keyword arguments to the class returned by
# create_combined_struct_class.
#
# Raises PredictionInvalidError with the original message under the :output
# key when construction fails with ArgumentError or TypeError. Both errors
# are handled identically, so a single method-level rescue covers them.
def create_prediction_result(input_values, output_attributes)
  combined_struct = create_combined_struct_class
  combined_struct.new(**input_values.merge(output_attributes))
rescue ArgumentError, TypeError => e
  raise PredictionInvalidError.new({ output: e.message })
end
134
+
135
+ # Creates a combined struct class with input and output properties
151
136
  sig { returns(T.class_of(T::Struct)) }
152
137
  def create_combined_struct_class
153
138
  input_props = @signature_class.input_struct_class.props
154
139
  output_props = @signature_class.output_struct_class.props
155
140
 
156
- # Create a new struct class that combines input and output fields
157
- Class.new(T::Struct) do
158
- extend T::Sig
159
-
160
- # Add input fields
161
- input_props.each do |name, prop_info|
162
- if prop_info[:rules]&.any? { |rule| rule.is_a?(T::Props::NilableRules) }
163
- prop name, prop_info[:type], default: prop_info[:default]
164
- else
165
- const name, prop_info[:type], default: prop_info[:default]
166
- end
167
- end
168
-
169
- # Add output fields
170
- output_props.each do |name, prop_info|
171
- if prop_info[:rules]&.any? { |rule| rule.is_a?(T::Props::NilableRules) }
172
- prop name, prop_info[:type], default: prop_info[:default]
173
- else
174
- const name, prop_info[:type], default: prop_info[:default]
175
- end
176
- end
177
-
178
- # Add to_h method to serialize the struct to a hash
179
- define_method :to_h do
180
- hash = {}
181
-
182
- # Add all properties
183
- self.class.props.keys.each do |key|
184
- hash[key] = self.send(key)
185
- end
186
-
187
- hash
188
- end
189
- end
141
+ build_enhanced_struct({
142
+ input: input_props,
143
+ output: output_props
144
+ })
190
145
  end
191
146
  end
192
147
  end
@@ -0,0 +1,222 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'sorbet-runtime'
4
+ require_relative 'few_shot_example'
5
+
6
module DSPy
  # Value object holding everything needed to render the prompts for a
  # signature: the instruction text, the input/output JSON schemas, and an
  # optional list of few-shot examples.
  #
  # Instances are treated as immutable: the "update" methods return a new
  # Prompt rather than changing the receiver.
  class Prompt
    extend T::Sig

    sig { returns(String) }
    attr_reader :instruction

    sig { returns(T::Array[FewShotExample]) }
    attr_reader :few_shot_examples

    sig { returns(T::Hash[Symbol, T.untyped]) }
    attr_reader :input_schema

    sig { returns(T::Hash[Symbol, T.untyped]) }
    attr_reader :output_schema

    sig { returns(T.nilable(String)) }
    attr_reader :signature_class_name

    sig do
      params(
        instruction: String,
        input_schema: T::Hash[Symbol, T.untyped],
        output_schema: T::Hash[Symbol, T.untyped],
        few_shot_examples: T::Array[FewShotExample],
        signature_class_name: T.nilable(String)
      ).void
    end
    def initialize(instruction:, input_schema:, output_schema:, few_shot_examples: [], signature_class_name: nil)
      @instruction = instruction
      # Freeze shallow copies rather than the caller's own objects: freezing
      # the arguments in place would make a later `examples << x` (or any
      # mutation of a schema hash the caller still owns) raise FrozenError,
      # and without the copy a caller-side mutation would leak into this
      # prompt.
      @few_shot_examples = few_shot_examples.dup.freeze
      @input_schema = input_schema.dup.freeze
      @output_schema = output_schema.dup.freeze
      @signature_class_name = signature_class_name
    end

    # Immutable update methods for optimization

    # Returns a new prompt identical to this one except for its instruction.
    sig { params(new_instruction: String).returns(Prompt) }
    def with_instruction(new_instruction)
      self.class.new(
        instruction: new_instruction,
        input_schema: @input_schema,
        output_schema: @output_schema,
        few_shot_examples: @few_shot_examples,
        signature_class_name: @signature_class_name
      )
    end

    # Returns a new prompt whose few-shot examples are exactly +new_examples+
    # (the current examples are not carried over).
    sig { params(new_examples: T::Array[FewShotExample]).returns(Prompt) }
    def with_examples(new_examples)
      self.class.new(
        instruction: @instruction,
        input_schema: @input_schema,
        output_schema: @output_schema,
        few_shot_examples: new_examples,
        signature_class_name: @signature_class_name
      )
    end

    # Returns a new prompt whose examples are the current ones followed by
    # +new_examples+.
    sig { params(new_examples: T::Array[FewShotExample]).returns(Prompt) }
    def add_examples(new_examples)
      combined_examples = @few_shot_examples + new_examples
      with_examples(combined_examples)
    end

    # Core prompt rendering methods

    # Renders the system prompt: both schemas as pretty-printed JSON, the
    # optional numbered few-shot examples, the input/output placeholders, and
    # finally the instruction. Sections are joined with newlines.
    sig { returns(String) }
    def render_system_prompt
      sections = []

      sections << "Your input schema fields are:"
      sections << "```json"
      sections << JSON.pretty_generate(@input_schema)
      sections << "```"

      sections << "Your output schema fields are:"
      sections << "```json"
      sections << JSON.pretty_generate(@output_schema)
      sections << "```"

      sections << ""
      sections << "All interactions will be structured in the following way, with the appropriate values filled in."

      # Add few-shot examples if present
      if @few_shot_examples.any?
        sections << ""
        sections << "Here are some examples:"
        sections << ""
        @few_shot_examples.each_with_index do |example, index|
          sections << "### Example #{index + 1}"
          sections << example.to_prompt_section
          sections << ""
        end
      end

      # "{input_values}" / "{output_values}" are literal placeholders in the
      # rendered text; nothing is interpolated into them here.
      sections << "## Input values"
      sections << "```json"
      sections << "{input_values}"
      sections << "```"

      sections << "## Output values"
      sections << "Respond exclusively with the output schema fields in the json block below."
      sections << "```json"
      sections << "{output_values}"
      sections << "```"

      sections << ""
      sections << "In adhering to this structure, your objective is: #{@instruction}"

      sections.join("\n")
    end

    # Renders the user prompt: +input_values+ as pretty-printed JSON followed
    # by the response-format reminder.
    sig { params(input_values: T::Hash[Symbol, T.untyped]).returns(String) }
    def render_user_prompt(input_values)
      sections = []

      sections << "## Input Values"
      sections << "```json"
      sections << JSON.pretty_generate(input_values)
      sections << "```"

      sections << ""
      sections << "Respond with the corresponding output schema fields wrapped in a ```json ``` block,"
      sections << "starting with the heading `## Output values`."

      sections.join("\n")
    end

    # Generate messages for LM adapter: a system message followed by a user
    # message, each a hash with :role and :content.
    sig { params(input_values: T::Hash[Symbol, T.untyped]).returns(T::Array[T::Hash[Symbol, String]]) }
    def to_messages(input_values)
      [
        { role: 'system', content: render_system_prompt },
        { role: 'user', content: render_user_prompt(input_values) }
      ]
    end

    # Serialization for persistence and optimization.
    # Examples are serialized through their own #to_h.
    sig { returns(T::Hash[Symbol, T.untyped]) }
    def to_h
      {
        instruction: @instruction,
        few_shot_examples: @few_shot_examples.map(&:to_h),
        input_schema: @input_schema,
        output_schema: @output_schema,
        signature_class_name: @signature_class_name
      }
    end

    # Rebuilds a prompt from a hash shaped like the output of #to_h.
    # Only symbol keys are read; a missing instruction becomes "", missing
    # schemas become {}, and missing examples become [].
    sig { params(hash: T::Hash[Symbol, T.untyped]).returns(Prompt) }
    def self.from_h(hash)
      examples = (hash[:few_shot_examples] || []).map { |ex| FewShotExample.from_h(ex) }

      new(
        instruction: hash[:instruction] || "",
        input_schema: hash[:input_schema] || {},
        output_schema: hash[:output_schema] || {},
        few_shot_examples: examples,
        signature_class_name: hash[:signature_class_name]
      )
    end

    # Create prompt from signature class. The signature's description is used
    # as the instruction, falling back to "Complete this task." when it is
    # nil; the prompt starts with no few-shot examples.
    sig { params(signature_class: T.class_of(Signature)).returns(Prompt) }
    def self.from_signature(signature_class)
      new(
        instruction: signature_class.description || "Complete this task.",
        input_schema: signature_class.input_json_schema,
        output_schema: signature_class.output_json_schema,
        few_shot_examples: [],
        signature_class_name: signature_class.name
      )
    end

    # Comparison and diff methods for optimization

    # Two prompts are equal when instruction, examples and both schemas are
    # equal. signature_class_name is not part of the comparison.
    sig { params(other: T.untyped).returns(T::Boolean) }
    def ==(other)
      return false unless other.is_a?(Prompt)

      @instruction == other.instruction &&
        @few_shot_examples == other.few_shot_examples &&
        @input_schema == other.input_schema &&
        @output_schema == other.output_schema
    end

    # Same notion of equality as #==, so that prompts can be used as Hash
    # keys and de-duplicated with Array#uniq or Set.
    sig { params(other: T.untyped).returns(T::Boolean) }
    def eql?(other)
      self == other
    end

    # Hash code consistent with #== / #eql?.
    #
    # Only the instruction and the collection sizes are mixed in: any two
    # prompts that compare equal necessarily agree on these, whereas hashing
    # the examples or schema contents would depend on how those objects
    # implement #hash, which this class cannot guarantee matches their #==.
    sig { returns(Integer) }
    def hash
      [@instruction, @few_shot_examples.length, @input_schema.size, @output_schema.size].hash
    end

    # Describes what changed going from this prompt to +other+.
    # The result has an :instruction entry when the instructions differ and a
    # :few_shot_examples entry (counts plus added/removed examples) when the
    # example lists differ; schemas are not compared.
    sig { params(other: Prompt).returns(T::Hash[Symbol, T.untyped]) }
    def diff(other)
      changes = {}

      changes[:instruction] = {
        from: @instruction,
        to: other.instruction
      } if @instruction != other.instruction

      changes[:few_shot_examples] = {
        from: @few_shot_examples.length,
        to: other.few_shot_examples.length,
        added: other.few_shot_examples - @few_shot_examples,
        removed: @few_shot_examples - other.few_shot_examples
      } if @few_shot_examples != other.few_shot_examples

      changes
    end

    # Statistics for optimization tracking.
    # character_count is the length of the instruction only; field counts are
    # the number of keys under :properties in each schema (0 when absent).
    sig { returns(T::Hash[Symbol, T.untyped]) }
    def stats
      {
        character_count: @instruction.length,
        example_count: @few_shot_examples.length,
        total_example_chars: @few_shot_examples.sum { |ex| ex.to_prompt_section.length },
        input_fields: @input_schema.dig(:properties)&.keys&.length || 0,
        output_fields: @output_schema.dig(:properties)&.keys&.length || 0
      }
    end
  end
end