dspy 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/dspy/lm.rb CHANGED
@@ -1,43 +1,108 @@
 # frozen_string_literal: true
-require 'ruby_llm'
+
+# Load adapter infrastructure
+require_relative 'lm/errors'
+require_relative 'lm/response'
+require_relative 'lm/adapter'
+require_relative 'lm/adapter_factory'
+
+# Load instrumentation
+require_relative 'instrumentation'
+require_relative 'instrumentation/token_tracker'
+
+# Load adapters
+require_relative 'lm/adapters/openai_adapter'
+require_relative 'lm/adapters/anthropic_adapter'
+require_relative 'lm/adapters/ruby_llm_adapter'

 module DSPy
   class LM
-    attr_reader :model_id, :api_key, :model, :provider
+    attr_reader :model_id, :api_key, :model, :provider, :adapter

     def initialize(model_id, api_key: nil)
       @model_id = model_id
       @api_key = api_key
-      # Configure RubyLLM with the API key if provided
-      if model_id.start_with?('openai/')
-        RubyLLM.configure do |config|
-          config.openai_api_key = api_key
-        end
-        @provider = :openai
-        @model = model_id.split('/').last
-      elsif model_id.start_with?('anthropic/')
-        RubyLLM.configure do |config|
-          config.anthropic_api_key = api_key
-        end
-        @provider = :anthropic
-        @model = model_id.split('/').last
-      else
-        raise ArgumentError, "Unsupported model provider: #{model_id}"
-      end
+
+      # Parse provider and model from model_id
+      @provider, @model = parse_model_id(model_id)
+
+      # Create appropriate adapter
+      @adapter = AdapterFactory.create(model_id, api_key: api_key)
     end

     def chat(inference_module, input_values, &block)
       signature_class = inference_module.signature_class
-      chat = RubyLLM.chat(model: model)
+
+      # Build messages from inference module
+      messages = build_messages(inference_module, input_values)
+
+      # Calculate input size for monitoring
+      input_text = messages.map { |m| m[:content] }.join(' ')
+      input_size = input_text.length
+
+      # Instrument LM request
+      response = Instrumentation.instrument('dspy.lm.request', {
+        gen_ai_operation_name: 'chat',
+        gen_ai_system: provider,
+        gen_ai_request_model: model,
+        signature_class: signature_class.name,
+        provider: provider,
+        adapter_class: adapter.class.name,
+        input_size: input_size
+      }) do
+        adapter.chat(messages: messages, &block)
+      end
+
+      # Extract actual token usage from response (more accurate than estimation)
+      token_usage = Instrumentation::TokenTracker.extract_token_usage(response, provider)
+
+      # Emit token usage event if available
+      if token_usage.any?
+        Instrumentation.emit('dspy.lm.tokens', token_usage.merge({
+          gen_ai_system: provider,
+          gen_ai_request_model: model,
+          signature_class: signature_class.name
+        }))
+      end
+
+      # Instrument response parsing
+      parsed_result = Instrumentation.instrument('dspy.lm.response.parsed', {
+        signature_class: signature_class.name,
+        provider: provider,
+        response_length: response.content&.length || 0
+      }) do
+        parse_response(response, input_values, signature_class)
+      end
+
+      parsed_result
+    end
+
+    private
+
+    def parse_model_id(model_id)
+      if model_id.include?('/')
+        provider, model = model_id.split('/', 2)
+        [provider, model]
+      else
+        # Legacy format: assume ruby_llm for backward compatibility
+        ['ruby_llm', model_id]
+      end
+    end
+
+    def build_messages(inference_module, input_values)
+      messages = []
+
+      # Add system message
       system_prompt = inference_module.system_signature
+      messages << { role: 'system', content: system_prompt } if system_prompt
+
+      # Add user message
       user_prompt = inference_module.user_signature(input_values)
-      chat.add_message role: :system, content: system_prompt
-      chat.ask(user_prompt, &block)
-
-      parse_response(chat.messages.last, input_values, signature_class)
+      messages << { role: 'user', content: user_prompt }
+
+      messages
     end

-    private
     def parse_response(response, input_values, signature_class)
       # Try to parse the response as JSON
       content = response.content
@@ -52,22 +117,9 @@ module DSPy
       begin
         json_payload = JSON.parse(content)

-        # Handle different signature types
-        if signature_class < DSPy::SorbetSignature
-          # For Sorbet signatures, just return the parsed JSON
-          # The SorbetPredict will handle validation
-          json_payload
-        else
-          # Original dry-schema based handling
-          output = signature_class.output_schema.call(json_payload)
-
-          result_schema = Dry::Schema.JSON(parent: [signature_class.input_schema, signature_class.output_schema])
-          result = output.to_h.merge(input_values)
-          # create an instance with input and output schema
-          poro_result = result_schema.call(result)
-
-          poro_result.to_h
-        end
+        # For Sorbet signatures, just return the parsed JSON
+        # The Predict will handle validation
+        json_payload
       rescue JSON::ParserError
         raise "Failed to parse LLM response as JSON: #{content}"
       end
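
The rewritten DSPy::LM no longer configures RubyLLM directly. It splits the model_id into a provider prefix and a model name, asks AdapterFactory for a matching adapter, and wraps every request in Instrumentation events ('dspy.lm.request', 'dspy.lm.tokens', 'dspy.lm.response.parsed'). The snippet below is a standalone sketch of the provider/model naming convention that the new private parse_model_id method introduces; the model strings are illustrative examples, not part of this diff.

# Standalone sketch of the model_id convention used by the new DSPy::LM.
def parse_model_id(model_id)
  if model_id.include?('/')
    model_id.split('/', 2)      # "openai/gpt-4o-mini" => ["openai", "gpt-4o-mini"]
  else
    ['ruby_llm', model_id]      # legacy, prefix-less ids fall back to the RubyLLM adapter
  end
end

p parse_model_id('openai/gpt-4o-mini')       # => ["openai", "gpt-4o-mini"]
p parse_model_id('anthropic/claude-3-haiku') # => ["anthropic", "claude-3-haiku"]
p parse_model_id('gpt-3.5-turbo')            # => ["ruby_llm", "gpt-3.5-turbo"]

AdapterFactory.create presumably maps the provider string onto one of the adapters required at the top of the file; the factory and adapter classes themselves are outside this diff.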
data/lib/dspy/module.rb CHANGED
@@ -1,13 +1,58 @@
 # frozen_string_literal: true

+require 'sorbet-runtime'
+require 'dry-configurable'
+
 module DSPy
   class Module
-    def forward(...)
-      raise NotImplementedError, "Subclasses must implement forward method"
+    extend T::Sig
+    extend T::Generic
+    include Dry::Configurable
+
+    # Per-instance LM configuration
+    setting :lm, default: nil
+
+    # The main forward method that users will call is generic and type parameterized
+    sig do
+      type_parameters(:I, :O)
+        .params(
+          input_values: T.type_parameter(:I)
+        )
+        .returns(T.type_parameter(:O))
+    end
+    def forward(**input_values)
+      # Cast the result of forward_untyped to the expected output type
+      T.cast(forward_untyped(**input_values), T.type_parameter(:O))
+    end
+
+    # The implementation method that subclasses must override
+    sig { params(input_values: T.untyped).returns(T.untyped) }
+    def forward_untyped(**input_values)
+      raise NotImplementedError, "Subclasses must implement forward_untyped method"
     end
-
-    def call(...)
-      forward(...)
+
+    # The main call method that users will call is generic and type parameterized
+    sig do
+      type_parameters(:I, :O)
+        .params(
+          input_values: T.type_parameter(:I)
+        )
+        .returns(T.type_parameter(:O))
+    end
+    def call(**input_values)
+      forward(**input_values)
+    end
+
+    # The implementation method for call
+    sig { params(input_values: T.untyped).returns(T.untyped) }
+    def call_untyped(**input_values)
+      forward_untyped(**input_values)
+    end
+
+    # Get the configured LM for this instance, falling back to global
+    sig { returns(T.untyped) }
+    def lm
+      config.lm || DSPy.config.lm
     end
   end
-end
+end
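
DSPy::Module now separates the typed entry points from the untyped implementation hooks: forward and call carry generic Sorbet signatures and delegate to forward_untyped and call_untyped, which subclasses override, while Dry::Configurable gives each instance its own lm setting that falls back to the global DSPy.config.lm. The following is a minimal sketch of a subclass under the new contract, assuming only what this diff shows; the class and field names are invented for illustration.

require 'dspy'

class Reverse < DSPy::Module
  extend T::Sig

  # Subclasses override forward_untyped; the typed forward/call wrappers
  # defined in DSPy::Module delegate here and cast the result.
  sig { params(input_values: T.untyped).returns(T.untyped) }
  def forward_untyped(**input_values)
    { text: input_values[:text].to_s.reverse }
  end
end

mod = Reverse.new
mod.call(text: 'hello')   # => { text: "olleh" }

# Per-instance LM override via dry-configurable; instances without one
# fall back to DSPy.config.lm through the new #lm reader.
mod.config.lm = DSPy::LM.new('openai/gpt-4o-mini', api_key: ENV['OPENAI_API_KEY'])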
data/lib/dspy/predict.rb CHANGED
@@ -1,35 +1,53 @@
 # frozen_string_literal: true

+require 'sorbet-runtime'
+require_relative 'module'
+require_relative 'instrumentation'
+
 module DSPy
-  class PredictionInvalidError < RuntimeError
-    attr_accessor :errors
+  # Exception raised when prediction fails validation
+  class PredictionInvalidError < StandardError
+    extend T::Sig
+
+    sig { params(errors: T::Hash[T.untyped, T.untyped]).void }
     def initialize(errors)
       @errors = errors
-      super("Prediction invalid: #{errors.to_h}")
+      super("Prediction validation failed: #{errors}")
     end
+
+    sig { returns(T::Hash[T.untyped, T.untyped]) }
+    attr_reader :errors
   end
+
   class Predict < DSPy::Module
+    extend T::Sig
+
+    sig { returns(T.class_of(Signature)) }
     attr_reader :signature_class

+    sig { params(signature_class: T.class_of(Signature)).void }
     def initialize(signature_class)
+      super()
       @signature_class = signature_class
     end

+    sig { returns(String) }
    def system_signature
      <<-PROMPT
        Your input schema fields are:
        ```json
-        #{JSON.generate(@signature_class.input_schema.json_schema)}
+        #{JSON.generate(@signature_class.input_json_schema)}
        ```
        Your output schema fields are:
        ```json
-        #{JSON.generate(@signature_class.output_schema.json_schema)}
+        #{JSON.generate(@signature_class.output_json_schema)}
        ````
+
        All interactions will be structured in the following way, with the appropriate values filled in.

        ## Input values
        ```json
-        {input_values}
+        {input_values}
        ```
        ## Output values
        Respond exclusively with the output schema fields in the json block below.
@@ -42,6 +60,7 @@ module DSPy
      PROMPT
    end

+    sig { params(input_values: T::Hash[Symbol, T.untyped]).returns(String) }
    def user_signature(input_values)
      <<-PROMPT
        ## Input Values
@@ -54,19 +73,120 @@ module DSPy
      PROMPT
    end

-    def lm
-      DSPy.config.lm
+    sig { override.params(kwargs: T.untyped).returns(T.type_parameter(:O)) }
+    def forward(**kwargs)
+      @last_input_values = kwargs.clone
+      T.cast(forward_untyped(**kwargs), T.type_parameter(:O))
    end

-    def forward(**input_values)
-      DSPy.logger.info( module: self.class.to_s, **input_values)
-      result = @signature_class.input_schema.call(input_values)
-      if result.success?
+    sig { params(input_values: T.untyped).returns(T.untyped) }
+    def forward_untyped(**input_values)
+      # Prepare instrumentation payload
+      input_fields = input_values.keys.map(&:to_s)
+
+      Instrumentation.instrument('dspy.predict', {
+        signature_class: @signature_class.name,
+        model: lm.model,
+        provider: lm.provider,
+        input_fields: input_fields
+      }) do
+        # Validate input
+        begin
+          _input_struct = @signature_class.input_struct_class.new(**input_values)
+        rescue ArgumentError => e
+          # Emit validation error event
+          Instrumentation.emit('dspy.predict.validation_error', {
+            signature_class: @signature_class.name,
+            validation_type: 'input',
+            validation_errors: { input: e.message }
+          })
+          raise PredictionInvalidError.new({ input: e.message })
+        end
+
+        # Call LM
         output_attributes = lm.chat(self, input_values)
-        poro_class = Data.define(*output_attributes.keys)
-        return poro_class.new(*output_attributes.values)
+
+        output_attributes = output_attributes.transform_keys(&:to_sym)
+
+        output_props = @signature_class.output_struct_class.props
+        output_attributes = output_attributes.map do |key, value|
+          prop_type = output_props[key][:type] if output_props[key]
+          if prop_type
+            # Check if it's an enum (can be raw Class or T::Types::Simple)
+            enum_class = if prop_type.is_a?(Class) && prop_type < T::Enum
+                           prop_type
+                         elsif prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum
+                           prop_type.raw_type
+                         end
+
+            if enum_class
+              [key, enum_class.deserialize(value)]
+            elsif prop_type == Float || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Float)
+              [key, value.to_f]
+            elsif prop_type == Integer || (prop_type.is_a?(T::Types::Simple) && prop_type.raw_type == Integer)
+              [key, value.to_i]
+            else
+              [key, value]
+            end
+          else
+            [key, value]
+          end
+        end.to_h
+
+        # Create combined struct with both input and output values
+        begin
+          combined_struct = create_combined_struct_class
+          all_attributes = input_values.merge(output_attributes)
+          combined_struct.new(**all_attributes)
+        rescue ArgumentError => e
+          raise PredictionInvalidError.new({ output: e.message })
+        rescue TypeError => e
+          raise PredictionInvalidError.new({ output: e.message })
+        end
+      end
+    end
+
+    private
+
+    sig { returns(T.class_of(T::Struct)) }
+    def create_combined_struct_class
+      input_props = @signature_class.input_struct_class.props
+      output_props = @signature_class.output_struct_class.props
+
+      # Create a new struct class that combines input and output fields
+      Class.new(T::Struct) do
+        extend T::Sig
+
+        # Add input fields
+        input_props.each do |name, prop_info|
+          if prop_info[:rules]&.any? { |rule| rule.is_a?(T::Props::NilableRules) }
+            prop name, prop_info[:type], default: prop_info[:default]
+          else
+            const name, prop_info[:type], default: prop_info[:default]
+          end
+        end
+
+        # Add output fields
+        output_props.each do |name, prop_info|
+          if prop_info[:rules]&.any? { |rule| rule.is_a?(T::Props::NilableRules) }
+            prop name, prop_info[:type], default: prop_info[:default]
+          else
+            const name, prop_info[:type], default: prop_info[:default]
+          end
+        end
+
+        # Add to_h method to serialize the struct to a hash
+        define_method :to_h do
+          hash = {}
+
+          # Add all properties
+          self.class.props.keys.each do |key|
+            hash[key] = self.send(key)
+          end
+
+          hash
+        end
       end
-      raise PredictionInvalidError.new(result.errors)
     end
   end
 end
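
Predict now validates inputs by instantiating the signature's input T::Struct, coerces enum, Float, and Integer outputs returned by the LM, and returns a combined struct carrying both input and output fields, raising PredictionInvalidError (after emitting 'dspy.predict.validation_error') when construction fails. The usage sketch below assumes a DSPy::Signature DSL (description plus input/output blocks) that populates input_struct_class and output_struct_class as this diff uses them; that DSL lives outside this diff, and the signature, field names, and values here are hypothetical.

# Hypothetical signature; assumes the DSPy::Signature DSL (not in this diff)
# builds input_struct_class / output_struct_class from these blocks.
class Classify < DSPy::Signature
  description "Classify the sentiment of a sentence."

  class Sentiment < T::Enum
    enums do
      Positive = new('positive')
      Negative = new('negative')
      Neutral  = new('neutral')
    end
  end

  input do
    const :sentence, String
  end

  output do
    const :sentiment, Sentiment   # raw LM string is run through Sentiment.deserialize
    const :confidence, Float      # raw LM value is coerced with to_f
  end
end

classify = DSPy::Predict.new(Classify)
result = classify.call(sentence: "This is great!")

result.sentence    # input fields are carried into the combined struct
result.sentiment   # a Classify::Sentiment member, not a String
result.confidence  # a Float

# An unknown or missing input field raises DSPy::PredictionInvalidError
# after a 'dspy.predict.validation_error' event is emitted.
classify.call(wrong_field: "oops")  # => raises PredictionInvalidError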