dspy 0.29.1 → 0.30.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. checksums.yaml +4 -4
  2. data/LICENSE +45 -0
  3. data/README.md +159 -95
  4. data/lib/dspy/callbacks.rb +93 -19
  5. data/lib/dspy/context.rb +101 -5
  6. data/lib/dspy/errors.rb +19 -1
  7. data/lib/dspy/{datasets.rb → evals/version.rb} +2 -3
  8. data/lib/dspy/{evaluate.rb → evals.rb} +373 -110
  9. data/lib/dspy/mixins/instruction_updatable.rb +22 -0
  10. data/lib/dspy/module.rb +213 -17
  11. data/lib/dspy/observability.rb +40 -182
  12. data/lib/dspy/predict.rb +10 -2
  13. data/lib/dspy/propose/dataset_summary_generator.rb +28 -18
  14. data/lib/dspy/re_act.rb +21 -0
  15. data/lib/dspy/schema/sorbet_json_schema.rb +302 -0
  16. data/lib/dspy/schema/version.rb +7 -0
  17. data/lib/dspy/schema.rb +4 -0
  18. data/lib/dspy/structured_outputs_prompt.rb +48 -0
  19. data/lib/dspy/support/warning_filters.rb +27 -0
  20. data/lib/dspy/teleprompt/gepa.rb +9 -588
  21. data/lib/dspy/teleprompt/instruction_updates.rb +94 -0
  22. data/lib/dspy/teleprompt/teleprompter.rb +6 -6
  23. data/lib/dspy/teleprompt/utils.rb +5 -65
  24. data/lib/dspy/type_system/sorbet_json_schema.rb +2 -299
  25. data/lib/dspy/version.rb +1 -1
  26. data/lib/dspy.rb +39 -7
  27. metadata +18 -61
  28. data/lib/dspy/code_act.rb +0 -477
  29. data/lib/dspy/datasets/ade.rb +0 -90
  30. data/lib/dspy/observability/async_span_processor.rb +0 -250
  31. data/lib/dspy/observability/observation_type.rb +0 -65
  32. data/lib/dspy/optimizers/gaussian_process.rb +0 -141
  33. data/lib/dspy/teleprompt/mipro_v2.rb +0 -1672
  34. data/lib/gepa/api.rb +0 -61
  35. data/lib/gepa/core/engine.rb +0 -226
  36. data/lib/gepa/core/evaluation_batch.rb +0 -26
  37. data/lib/gepa/core/result.rb +0 -92
  38. data/lib/gepa/core/state.rb +0 -231
  39. data/lib/gepa/logging/experiment_tracker.rb +0 -54
  40. data/lib/gepa/logging/logger.rb +0 -57
  41. data/lib/gepa/logging.rb +0 -9
  42. data/lib/gepa/proposer/base.rb +0 -27
  43. data/lib/gepa/proposer/merge_proposer.rb +0 -424
  44. data/lib/gepa/proposer/reflective_mutation/base.rb +0 -48
  45. data/lib/gepa/proposer/reflective_mutation/reflective_mutation.rb +0 -188
  46. data/lib/gepa/strategies/batch_sampler.rb +0 -91
  47. data/lib/gepa/strategies/candidate_selector.rb +0 -97
  48. data/lib/gepa/strategies/component_selector.rb +0 -57
  49. data/lib/gepa/strategies/instruction_proposal.rb +0 -120
  50. data/lib/gepa/telemetry.rb +0 -122
  51. data/lib/gepa/utils/pareto.rb +0 -119
  52. data/lib/gepa.rb +0 -21
data/lib/dspy/module.rb CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  require 'sorbet-runtime'
4
4
  require 'dry-configurable'
5
+ require 'securerandom'
5
6
  require_relative 'context'
6
7
  require_relative 'callbacks'
7
8
 
@@ -12,10 +13,84 @@ module DSPy
12
13
  include Dry::Configurable
13
14
  include DSPy::Callbacks
14
15
 
16
+ class SubcriptionScope < T::Enum
17
+ enums do
18
+ Descendants = new('descendants')
19
+ SelfOnly = new('self')
20
+ end
21
+ end
22
+
23
+ DEFAULT_MODULE_SUBSCRIPTION_SCOPE = SubcriptionScope::Descendants
24
+
25
+ module ForwardOverrideHooks
26
+ def method_added(method_name)
27
+ super
28
+
29
+ return unless method_name == :forward
30
+ return if self == DSPy::Module
31
+ return if @_wrapping_forward
32
+
33
+ @_wrapping_forward = true
34
+
35
+ original = instance_method(:forward)
36
+ define_method(:forward) do |*args, **kwargs, &block|
37
+ instrument_forward_call(args, kwargs) do
38
+ original.bind(self).call(*args, **kwargs, &block)
39
+ end
40
+ end
41
+ ensure
42
+ @_wrapping_forward = false
43
+ end
44
+ end
45
+
46
+ class << self
47
+ def inherited(subclass)
48
+ super
49
+ specs_copy = module_subscription_specs.map(&:dup)
50
+ subclass.instance_variable_set(:@module_subscription_specs, specs_copy)
51
+ subclass.extend(ForwardOverrideHooks)
52
+ end
53
+
54
+ def subscribe(pattern, handler = nil, scope: DEFAULT_MODULE_SUBSCRIPTION_SCOPE, &block)
55
+ scope = normalize_scope(scope)
56
+ raise ArgumentError, 'Provide a handler method or block' if handler.nil? && block.nil?
57
+
58
+ module_subscription_specs << {
59
+ pattern: pattern,
60
+ handler: handler,
61
+ block: block,
62
+ scope: scope
63
+ }
64
+ end
65
+
66
+ def module_subscription_specs
67
+ @module_subscription_specs ||= []
68
+ end
69
+
70
+ private
71
+
72
+ def validate_subscription_scope!(scope)
73
+ T.must(scope)
74
+ end
75
+
76
+ def normalize_scope(scope)
77
+ return scope if scope.is_a?(SubcriptionScope)
78
+
79
+ case scope
80
+ when :descendants
81
+ SubcriptionScope::Descendants
82
+ when :self
83
+ SubcriptionScope::SelfOnly
84
+ else
85
+ raise ArgumentError, "Unsupported subscription scope: #{scope.inspect}"
86
+ end
87
+ end
88
+ end
89
+
15
90
  # Per-instance LM configuration
16
91
  setting :lm, default: nil
17
92
 
18
- # Define callback hooks for forward method
93
+ # Enable callback hooks for forward method
19
94
  create_before_callback :forward
20
95
  create_after_callback :forward
21
96
  create_around_callback :forward
@@ -29,23 +104,8 @@ module DSPy
29
104
  .returns(T.type_parameter(:O))
30
105
  end
31
106
  def forward(**input_values)
32
- # Create span for this module's execution
33
- observation_type = DSPy::ObservationType.for_module_class(self.class)
34
- DSPy::Context.with_span(
35
- operation: "#{self.class.name}.forward",
36
- **observation_type.langfuse_attributes,
37
- 'langfuse.observation.input' => input_values.to_json,
38
- 'dspy.module' => self.class.name
39
- ) do |span|
107
+ instrument_forward_call([], input_values) do
40
108
  result = forward_untyped(**input_values)
41
-
42
- # Add output to span
43
- if span && result
44
- output_json = result.respond_to?(:to_h) ? result.to_h.to_json : result.to_json rescue result.to_s
45
- span.set_attribute('langfuse.observation.output', output_json)
46
- end
47
-
48
- # Cast the result of forward_untyped to the expected output type
49
109
  T.cast(result, T.type_parameter(:O))
50
110
  end
51
111
  end
@@ -116,5 +176,141 @@ module DSPy
116
176
  def predictors
117
177
  named_predictors.map { |(_, predictor)| predictor }
118
178
  end
179
+
180
+ def instrument_forward_call(call_args, call_kwargs)
181
+ ensure_module_subscriptions!
182
+
183
+ DSPy::Context.with_module(self) do
184
+ observation_type = DSPy::ObservationType.for_module_class(self.class)
185
+ span_attributes = observation_type.langfuse_attributes.merge(
186
+ 'langfuse.observation.input' => serialize_module_input(call_args, call_kwargs),
187
+ 'dspy.module' => self.class.name
188
+ )
189
+
190
+ DSPy::Context.with_span(
191
+ operation: "#{self.class.name}.forward",
192
+ **span_attributes
193
+ ) do |span|
194
+ yield.tap do |result|
195
+ if span && result
196
+ span.set_attribute('langfuse.observation.output', serialize_module_output(result))
197
+ end
198
+ end
199
+ end
200
+ end
201
+ end
202
+
203
+ def serialize_module_input(call_args, call_kwargs)
204
+ payload = if call_kwargs && !call_kwargs.empty?
205
+ call_kwargs
206
+ elsif call_args && !call_args.empty?
207
+ call_args
208
+ else
209
+ {}
210
+ end
211
+
212
+ payload.to_json
213
+ rescue StandardError
214
+ payload.to_s
215
+ end
216
+
217
+ def serialize_module_output(result)
218
+ if result.respond_to?(:to_h)
219
+ result.to_h.to_json
220
+ else
221
+ result.to_json
222
+ end
223
+ rescue StandardError
224
+ result.to_s
225
+ end
226
+
227
+ private :instrument_forward_call, :serialize_module_input, :serialize_module_output
228
+
229
+ sig { returns(String) }
230
+ def module_scope_id
231
+ @module_scope_id ||= SecureRandom.uuid
232
+ end
233
+
234
+ sig { returns(T.nilable(String)) }
235
+ def module_scope_label
236
+ @module_scope_label
237
+ end
238
+
239
+ sig { params(label: T.nilable(String)).void }
240
+ def module_scope_label=(label)
241
+ @module_scope_label = label
242
+ end
243
+
244
+ sig { returns(T::Array[String]) }
245
+ def registered_module_subscriptions
246
+ Array(@module_subscription_ids).dup
247
+ end
248
+
249
+ sig { void }
250
+ def unsubscribe_module_events
251
+ Array(@module_subscription_ids).each { |id| DSPy.events.unsubscribe(id) }
252
+ @module_subscription_ids = []
253
+ @module_subscriptions_registered = false
254
+ end
255
+
256
+ private
257
+
258
+ def ensure_module_subscriptions!
259
+ return if @module_subscriptions_registered
260
+
261
+ specs = self.class.module_subscription_specs
262
+ if specs.empty?
263
+ @module_subscriptions_registered = true
264
+ return
265
+ end
266
+
267
+ @module_subscription_ids ||= []
268
+ specs.each do |spec|
269
+ callback = build_subscription_callback(spec)
270
+ subscription_id = DSPy.events.subscribe(spec[:pattern], &callback)
271
+ @module_subscription_ids << subscription_id
272
+ end
273
+
274
+ @module_subscriptions_registered = true
275
+ end
276
+
277
+ def build_subscription_callback(spec)
278
+ scope = spec[:scope] || DEFAULT_MODULE_SUBSCRIPTION_SCOPE
279
+ handler = spec[:handler]
280
+ block = spec[:block]
281
+
282
+ proc do |event_name, attributes|
283
+ next unless module_event_within_scope?(attributes, scope)
284
+
285
+ if handler
286
+ send(handler, event_name, attributes)
287
+ else
288
+ instance_exec(event_name, attributes, &block)
289
+ end
290
+ end
291
+ end
292
+
293
+ def module_event_within_scope?(attributes, scope)
294
+ metadata = extract_module_metadata(attributes)
295
+ return false unless metadata
296
+
297
+ case scope
298
+ when SubcriptionScope::SelfOnly
299
+ metadata[:leaf_id] == module_scope_id
300
+ else
301
+ metadata[:path_ids].include?(module_scope_id)
302
+ end
303
+ end
304
+
305
+ def extract_module_metadata(attributes)
306
+ path = attributes[:module_path] || attributes['module_path']
307
+ leaf = attributes[:module_leaf] || attributes['module_leaf']
308
+ return nil unless path.is_a?(Array)
309
+
310
+ {
311
+ path_ids: path.map { |entry| entry[:id] || entry['id'] }.compact,
312
+ leaf_id: leaf&.dig(:id) || leaf&.dig('id')
313
+ }
314
+ end
119
315
  end
120
316
  end
data/lib/dspy/observability.rb CHANGED
@@ -1,196 +1,54 @@
1
+ # typed: false
1
2
  # frozen_string_literal: true
2
3
 
3
- require 'base64'
4
- require_relative 'observability/async_span_processor'
5
-
6
- module DSPy
7
- class Observability
8
- class << self
9
- attr_reader :enabled, :tracer, :endpoint
10
-
11
- def configure!
12
- @enabled = false
13
-
14
- # Check for explicit disable flag first
15
- if ENV['DSPY_DISABLE_OBSERVABILITY'] == 'true'
16
- DSPy.log('observability.disabled', reason: 'Explicitly disabled via DSPY_DISABLE_OBSERVABILITY')
17
- return
18
- end
19
-
20
- # Check for required Langfuse environment variables
21
- public_key = ENV['LANGFUSE_PUBLIC_KEY']
22
- secret_key = ENV['LANGFUSE_SECRET_KEY']
23
-
24
- # Skip OTLP configuration in test environment UNLESS Langfuse credentials are explicitly provided
25
- # This allows observability tests to run while protecting general tests from network calls
26
- if (ENV['RACK_ENV'] == 'test' || ENV['RAILS_ENV'] == 'test' || defined?(RSpec)) && !(public_key && secret_key)
27
- DSPy.log('observability.disabled', reason: 'Test environment detected - OTLP disabled')
28
- return
29
- end
30
-
31
- unless public_key && secret_key
32
- return
4
+ begin
5
+ require 'dspy/o11y'
6
+ rescue LoadError
7
+ require 'sorbet-runtime'
8
+
9
+ module DSPy
10
+ class Observability
11
+ class << self
12
+ def register_configurator(*); end
13
+ def configure!(*); false; end
14
+ def enabled?; false; end
15
+ def enable!(*); false; end
16
+ def disable!(*); nil; end
17
+ def start_span(*); nil; end
18
+ def finish_span(*); nil; end
19
+ def flush!; nil; end
20
+ def reset!; nil; end
21
+ def require_dependency(lib)
22
+ require lib
23
+ rescue LoadError
24
+ raise
33
25
  end
34
-
35
- # Determine endpoint based on host
36
- host = ENV['LANGFUSE_HOST'] || 'https://cloud.langfuse.com'
37
- @endpoint = "#{host}/api/public/otel/v1/traces"
38
-
39
- begin
40
- # Load OpenTelemetry gems
41
- require 'opentelemetry/sdk'
42
- require 'opentelemetry/exporter/otlp'
43
-
44
- patch_frozen_ssl_context_for_otlp!
45
-
46
- # Generate Basic Auth header
47
- auth_string = Base64.strict_encode64("#{public_key}:#{secret_key}")
48
-
49
- # Configure OpenTelemetry SDK
50
- OpenTelemetry::SDK.configure do |config|
51
- config.service_name = 'dspy-ruby'
52
- config.service_version = DSPy::VERSION
53
-
54
- # Add OTLP exporter for Langfuse using AsyncSpanProcessor
55
- exporter = OpenTelemetry::Exporter::OTLP::Exporter.new(
56
- endpoint: @endpoint,
57
- headers: {
58
- 'Authorization' => "Basic #{auth_string}",
59
- 'Content-Type' => 'application/x-protobuf'
60
- },
61
- compression: 'gzip'
62
- )
63
-
64
- # Configure AsyncSpanProcessor with environment variables
65
- async_config = {
66
- queue_size: (ENV['DSPY_TELEMETRY_QUEUE_SIZE'] || AsyncSpanProcessor::DEFAULT_QUEUE_SIZE).to_i,
67
- export_interval: (ENV['DSPY_TELEMETRY_EXPORT_INTERVAL'] || AsyncSpanProcessor::DEFAULT_EXPORT_INTERVAL).to_f,
68
- export_batch_size: (ENV['DSPY_TELEMETRY_BATCH_SIZE'] || AsyncSpanProcessor::DEFAULT_EXPORT_BATCH_SIZE).to_i,
69
- shutdown_timeout: (ENV['DSPY_TELEMETRY_SHUTDOWN_TIMEOUT'] || AsyncSpanProcessor::DEFAULT_SHUTDOWN_TIMEOUT).to_f
70
- }
71
-
72
- config.add_span_processor(
73
- AsyncSpanProcessor.new(exporter, **async_config)
74
- )
75
-
76
- # Add resource attributes
77
- config.resource = OpenTelemetry::SDK::Resources::Resource.create({
78
- 'service.name' => 'dspy-ruby',
79
- 'service.version' => DSPy::VERSION,
80
- 'telemetry.sdk.name' => 'opentelemetry',
81
- 'telemetry.sdk.language' => 'ruby'
82
- })
83
- end
84
-
85
- # Create tracer
86
- @tracer = OpenTelemetry.tracer_provider.tracer('dspy', DSPy::VERSION)
87
- @enabled = true
88
-
89
- rescue LoadError => e
90
- DSPy.log('observability.disabled', reason: 'OpenTelemetry gems not available')
91
- rescue StandardError => e
92
- DSPy.log('observability.error', error: e.message, class: e.class.name)
93
- end
94
- end
95
-
96
- def enabled?
97
- @enabled == true
98
- end
99
-
100
- def tracer
101
- @tracer
102
- end
103
-
104
- def start_span(operation_name, attributes = {})
105
- return nil unless enabled? && tracer
106
-
107
- # Convert attribute keys to strings and filter out nil values
108
- string_attributes = attributes.transform_keys(&:to_s)
109
- .reject { |k, v| v.nil? }
110
- string_attributes['operation.name'] = operation_name
111
-
112
- tracer.start_span(
113
- operation_name,
114
- kind: :internal,
115
- attributes: string_attributes
116
- )
117
- rescue StandardError => e
118
- DSPy.log('observability.span_error', error: e.message, operation: operation_name)
119
- nil
120
26
  end
27
+ end
121
28
 
122
- def finish_span(span)
123
- return unless span
124
-
125
- span.finish
126
- rescue StandardError => e
127
- DSPy.log('observability.span_finish_error', error: e.message)
29
+ class ObservationType < T::Enum
30
+ enums do
31
+ Generation = new('generation')
32
+ Agent = new('agent')
33
+ Tool = new('tool')
34
+ Chain = new('chain')
35
+ Retriever = new('retriever')
36
+ Embedding = new('embedding')
37
+ Evaluator = new('evaluator')
38
+ Span = new('span')
39
+ Event = new('event')
128
40
  end
129
41
 
130
- def flush!
131
- return unless enabled?
132
-
133
- # Force flush any pending spans
134
- OpenTelemetry.tracer_provider.force_flush
135
- rescue StandardError => e
136
- DSPy.log('observability.flush_error', error: e.message)
42
+ def self.for_module_class(_module_class)
43
+ Span
137
44
  end
138
45
 
139
- def reset!
140
- @enabled = false
141
-
142
- # Shutdown OpenTelemetry if it's configured
143
- if defined?(OpenTelemetry) && OpenTelemetry.tracer_provider
144
- begin
145
- OpenTelemetry.tracer_provider.shutdown(timeout: 1.0)
146
- rescue => e
147
- # Ignore shutdown errors in tests - log them but don't fail
148
- DSPy.log('observability.shutdown_error', error: e.message) if respond_to?(:log)
149
- end
150
- end
151
-
152
- @tracer = nil
153
- @endpoint = nil
46
+ def langfuse_attribute
47
+ ['langfuse.observation.type', serialize]
154
48
  end
155
49
 
156
- private
157
-
158
- def patch_frozen_ssl_context_for_otlp!
159
- return unless defined?(OpenTelemetry::Exporter::OTLP::Exporter)
160
-
161
- ssl_context_frozen = begin
162
- http = Net::HTTP.new('example.com', 443)
163
- http.use_ssl = true
164
- http.ssl_context&.frozen?
165
- rescue StandardError
166
- false
167
- end
168
-
169
- return unless ssl_context_frozen
170
-
171
- exporter = OpenTelemetry::Exporter::OTLP::Exporter
172
- return if exporter.instance_variable_defined?(:@_dspy_ssl_patch_applied)
173
-
174
- exporter.class_eval do
175
- define_method(:http_connection) do |uri, ssl_verify_mode, certificate_file, client_certificate_file, client_key_file|
176
- http = Net::HTTP.new(uri.host, uri.port)
177
- use_ssl = uri.scheme == 'https'
178
- http.use_ssl = use_ssl
179
-
180
- if use_ssl && http.ssl_context&.frozen?
181
- http.instance_variable_set(:@ssl_context, OpenSSL::SSL::SSLContext.new)
182
- end
183
-
184
- http.verify_mode = ssl_verify_mode
185
- http.ca_file = certificate_file unless certificate_file.nil?
186
- http.cert = OpenSSL::X509::Certificate.new(File.read(client_certificate_file)) unless client_certificate_file.nil?
187
- http.key = OpenSSL::PKey::RSA.new(File.read(client_key_file)) unless client_key_file.nil?
188
- http.keep_alive_timeout = KEEP_ALIVE_TIMEOUT
189
- http
190
- end
191
- end
192
-
193
- exporter.instance_variable_set(:@_dspy_ssl_patch_applied, true)
50
+ def langfuse_attributes
51
+ { 'langfuse.observation.type' => serialize }
194
52
  end
195
53
  end
196
54
  end
data/lib/dspy/predict.rb CHANGED
@@ -6,6 +6,7 @@ require_relative 'prompt'
6
6
  require_relative 'utils/serialization'
7
7
  require_relative 'mixins/struct_builder'
8
8
  require_relative 'mixins/type_coercion'
9
+ require_relative 'mixins/instruction_updatable'
9
10
  require_relative 'error_formatter'
10
11
 
11
12
  module DSPy
@@ -46,6 +47,7 @@ module DSPy
46
47
  extend T::Sig
47
48
  include Mixins::StructBuilder
48
49
  include Mixins::TypeCoercion
50
+ include Mixins::InstructionUpdatable
49
51
 
50
52
  sig { returns(T.class_of(Signature)) }
51
53
  attr_reader :signature_class
@@ -120,6 +122,7 @@ module DSPy
120
122
  # Create a new instance with the same signature but updated prompt
121
123
  instance = self.class.new(@signature_class)
122
124
  instance.instance_variable_set(:@prompt, new_prompt)
125
+ instance.instance_variable_set(:@demos, @demos&.map { |demo| demo })
123
126
  instance
124
127
  end
125
128
 
@@ -130,12 +133,17 @@ module DSPy
130
133
 
131
134
  sig { params(examples: T::Array[FewShotExample]).returns(Predict) }
132
135
  def with_examples(examples)
133
- with_prompt(@prompt.with_examples(examples))
136
+ instance = with_prompt(@prompt.with_examples(examples))
137
+ instance.demos = examples.map { |example| example }
138
+ instance
134
139
  end
135
140
 
136
141
  sig { params(examples: T::Array[FewShotExample]).returns(Predict) }
137
142
  def add_examples(examples)
138
- with_prompt(@prompt.add_examples(examples))
143
+ instance = with_prompt(@prompt.add_examples(examples))
144
+ combined = instance.prompt.few_shot_examples
145
+ instance.demos = combined.map { |example| example }
146
+ instance
139
147
  end
140
148
 
141
149
  sig { override.returns(T::Array[[String, DSPy::Module]]) }
data/lib/dspy/propose/dataset_summary_generator.rb CHANGED
@@ -34,7 +34,7 @@ module DSPy
34
34
  "It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative"
35
35
 
36
36
  input do
37
- const :examples, String, description: "Sample data points from the dataset"
37
+ const :examples, T::Array[T::Hash[String, T.untyped]], description: "Sample data points from the dataset"
38
38
  end
39
39
 
40
40
  output do
@@ -50,7 +50,7 @@ module DSPy
50
50
  "It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative"
51
51
 
52
52
  input do
53
- const :examples, String, description: "Sample data points from the dataset"
53
+ const :examples, T::Array[T::Hash[String, T.untyped]], description: "Sample data points from the dataset"
54
54
  const :prior_observations, String, description: "Some prior observations I made about the data"
55
55
  end
56
56
 
@@ -124,9 +124,7 @@ module DSPy
124
124
  upper_lim = [trainset.length, view_data_batch_size].min
125
125
  batch_examples = trainset[0...upper_lim]
126
126
  predictor = DSPy::Predict.new(DatasetDescriptor)
127
- examples_repr = format_examples_for_prompt(batch_examples)
128
-
129
- observation = predictor.call(examples: examples_repr)
127
+ observation = predictor.call(examples: format_examples_for_prompt(batch_examples))
130
128
  observations = observation.observations
131
129
 
132
130
  # Iteratively refine observations with additional batches
@@ -145,11 +143,9 @@ module DSPy
145
143
 
146
144
  predictor = DSPy::Predict.new(DatasetDescriptorWithPriorObservations)
147
145
  batch_examples = trainset[b...upper_lim]
148
- examples_repr = format_examples_for_prompt(batch_examples)
149
-
150
146
  output = predictor.call(
151
147
  prior_observations: observations,
152
- examples: examples_repr
148
+ examples: format_examples_for_prompt(batch_examples)
153
149
  )
154
150
 
155
151
  # Check if LLM indicates observations are complete
@@ -179,31 +175,45 @@ module DSPy
179
175
  end
180
176
  end
181
177
 
182
- sig { params(examples: T::Array[T.untyped]).returns(String) }
178
+ sig { params(examples: T::Array[T.untyped]).returns(T::Array[T::Hash[String, T.untyped]]) }
183
179
  def self.format_examples_for_prompt(examples)
184
180
  serialized_examples = examples.map do |example|
185
181
  case example
186
182
  when DSPy::Example
187
183
  {
188
- signature: example.signature_class.name,
189
- input: DSPy::TypeSerializer.serialize(example.input),
190
- expected: DSPy::TypeSerializer.serialize(example.expected)
184
+ 'signature' => example.signature_class.name || example.signature_class.to_s,
185
+ 'input' => stringify_keys(DSPy::TypeSerializer.serialize(example.input)),
186
+ 'expected' => stringify_keys(DSPy::TypeSerializer.serialize(example.expected))
191
187
  }
192
188
  when DSPy::FewShotExample
193
189
  base = {
194
- input: example.input,
195
- output: example.output
190
+ 'input' => stringify_keys(example.input),
191
+ 'output' => stringify_keys(example.output)
196
192
  }
197
- base[:reasoning] = example.reasoning if example.reasoning
193
+ base['reasoning'] = example.reasoning if example.reasoning
198
194
  base
199
195
  when Hash
200
- example
196
+ stringify_keys(example)
201
197
  else
202
- example.respond_to?(:to_h) ? example.to_h : { value: example }
198
+ stringify_keys(example.respond_to?(:to_h) ? example.to_h : { value: example })
203
199
  end
204
200
  end
205
201
 
206
- JSON.pretty_generate(serialized_examples)
202
+ serialized_examples
203
+ end
204
+
205
+ sig { params(value: T.untyped).returns(T.untyped) }
206
+ def self.stringify_keys(value)
207
+ case value
208
+ when Hash
209
+ value.each_with_object({}) do |(k, v), result|
210
+ result[k.to_s] = stringify_keys(v)
211
+ end
212
+ when Array
213
+ value.map { |item| stringify_keys(item) }
214
+ else
215
+ value
216
+ end
207
217
  end
208
218
  end
209
219
  end
data/lib/dspy/re_act.rb CHANGED
@@ -157,6 +157,27 @@ module DSPy
157
157
  named_predictors.map { |(_, predictor)| predictor }
158
158
  end
159
159
 
160
+ sig { returns(DSPy::Prompt) }
161
+ def prompt
162
+ @thought_generator.prompt
163
+ end
164
+
165
+ sig { params(instruction: String).returns(ReAct).override }
166
+ def with_instruction(instruction)
167
+ clone = self.class.new(@original_signature_class, tools: @tools.values, max_iterations: @max_iterations)
168
+ thought_generator = clone.instance_variable_get(:@thought_generator)
169
+ clone.instance_variable_set(:@thought_generator, thought_generator.with_instruction(instruction))
170
+ clone
171
+ end
172
+
173
+ sig { params(examples: T::Array[DSPy::FewShotExample]).returns(ReAct).override }
174
+ def with_examples(examples)
175
+ clone = self.class.new(@original_signature_class, tools: @tools.values, max_iterations: @max_iterations)
176
+ thought_generator = clone.instance_variable_get(:@thought_generator)
177
+ clone.instance_variable_set(:@thought_generator, thought_generator.with_examples(examples))
178
+ clone
179
+ end
180
+
160
181
  sig { params(kwargs: T.untyped).returns(T.untyped).override }
161
182
  def forward(**kwargs)
162
183
  # Validate input