dspy 0.3.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,554 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'sorbet-runtime'
4
+ require_relative 'instrumentation'
5
+ require_relative 'example'
6
+
7
+ module DSPy
8
+ # Core evaluation framework for DSPy programs
9
+ # Supports single evaluations, batch evaluations, and optimization workflows
10
+ class Evaluate
11
+ extend T::Sig
12
+
13
# Result of evaluating a single example.
#
# Plain value holder produced by Evaluate#call and collected by
# Evaluate#evaluate into a BatchEvaluationResult.
class EvaluationResult
  extend T::Sig

  # The example exactly as it was handed to Evaluate#call.
  sig { returns(T.untyped) }
  attr_reader :example

  # The program's return value (Evaluate#call stores nil here when the
  # evaluation raised).
  sig { returns(T.untyped) }
  attr_reader :prediction

  # Trace value supplied to Evaluate#call, stored verbatim.
  sig { returns(T.untyped) }
  attr_reader :trace

  # Metric values for this example (for instance :passed, :error, :traceback).
  sig { returns(T::Hash[Symbol, T.untyped]) }
  attr_reader :metrics

  # Whether this example counts as passed.
  sig { returns(T::Boolean) }
  attr_reader :passed

  sig do
    params(
      example: T.untyped,
      prediction: T.untyped,
      trace: T.untyped,
      metrics: T::Hash[Symbol, T.untyped],
      passed: T::Boolean
    ).void
  end
  def initialize(example:, prediction:, trace:, metrics:, passed:)
    @example = example
    @prediction = prediction
    @trace = trace
    @metrics = metrics
    @passed = passed
  end

  # Hash view of the result with one entry per attribute, in declaration order.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  def to_h
    {
      example: example,
      prediction: prediction,
      trace: trace,
      metrics: metrics,
      passed: passed
    }
  end
end
60
+
61
# Batch evaluation results with aggregated metrics.
#
# Holds the per-example results of a batch run together with summary counts.
# The stored collections are frozen copies, so the caller's own array and hash
# are left untouched (and mutable).
class BatchEvaluationResult
  extend T::Sig

  # Per-example results in evaluation order (frozen copy of the input array).
  sig { returns(T::Array[EvaluationResult]) }
  attr_reader :results

  # Aggregated metrics as supplied by the caller (frozen copy).
  sig { returns(T::Hash[Symbol, T.untyped]) }
  attr_reader :aggregated_metrics

  # Number of results held.
  sig { returns(Integer) }
  attr_reader :total_examples

  # Number of results whose +passed+ flag is true.
  sig { returns(Integer) }
  attr_reader :passed_examples

  # passed_examples / total_examples, or 0.0 when there are no results.
  sig { returns(Float) }
  attr_reader :pass_rate

  # @param results [Array<EvaluationResult>] per-example results
  # @param aggregated_metrics [Hash{Symbol => Object}] summary metrics
  sig do
    params(
      results: T::Array[EvaluationResult],
      aggregated_metrics: T::Hash[Symbol, T.untyped]
    ).void
  end
  def initialize(results:, aggregated_metrics:)
    # Freeze private copies: calling #freeze on the arguments themselves would
    # silently freeze collections the caller still owns.
    @results = results.dup.freeze
    @aggregated_metrics = aggregated_metrics.dup.freeze
    @total_examples = @results.length
    @passed_examples = @results.count(&:passed)
    @pass_rate = @total_examples.positive? ? @passed_examples.fdiv(@total_examples) : 0.0
  end

  # Hash view including every per-example result converted with #to_h.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  def to_h
    {
      total_examples: @total_examples,
      passed_examples: @passed_examples,
      pass_rate: @pass_rate,
      aggregated_metrics: @aggregated_metrics,
      results: @results.map(&:to_h)
    }
  end
end
105
+
106
# The program under evaluation; #call invokes it with keyword arguments.
sig { returns(T.untyped) }
attr_reader :program

# Optional metric, invoked as metric.call(example, prediction). #call accepts
# either a boolean-ish result or a Hash of metric values (as produced by e.g.
# Metrics.numeric_difference), so the proc's return type is left untyped.
sig { returns(T.nilable(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T.untyped))) }
attr_reader :metric

# Requested thread count. It is stored and reported in the batch
# instrumentation payload; #evaluate itself processes examples one at a time.
sig { returns(T.nilable(Integer)) }
attr_reader :num_threads

# Error threshold #evaluate uses to stop a batch early.
sig { returns(T.nilable(Integer)) }
attr_reader :max_errors

# When true, error results from #call include the first 10 backtrace lines.
sig { returns(T::Boolean) }
attr_reader :provide_traceback

# @param program [Object] object responding to #call(**inputs)
# @param metric [Proc, nil] scoring proc taking (example, prediction)
# @param num_threads [Integer, nil] requested thread count (nil -> 1)
# @param max_errors [Integer, nil] early-stop threshold for #evaluate (nil -> 5)
# @param provide_traceback [Boolean] attach backtraces to error results
sig do
  params(
    program: T.untyped,
    metric: T.nilable(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T.untyped)),
    num_threads: T.nilable(Integer),
    max_errors: T.nilable(Integer),
    provide_traceback: T::Boolean
  ).void
end
def initialize(program, metric: nil, num_threads: 1, max_errors: 5, provide_traceback: true)
  @program = program
  @metric = metric
  # An explicit nil falls back to the same default as an omitted argument.
  @num_threads = num_threads || 1
  @max_errors = max_errors || 5
  @provide_traceback = provide_traceback
end
137
+
138
# Evaluate the program on a single example.
#
# Inputs are taken from +example+ via #extract_input_values and splatted as
# keyword arguments into the program's #call. When a metric is configured it
# is invoked as metric.call(example, prediction):
# * a Hash result becomes the metrics; its :passed (or 'passed') entry decides
#   the outcome, and a missing entry counts as a pass;
# * any other result is converted to a boolean and stored under :passed.
# An exception from the metric marks the example failed and records the
# message under :error. Any other StandardError raised during evaluation yields
# a failed result with a nil prediction (plus a truncated backtrace when
# provide_traceback is set) instead of propagating.
#
# @param example [Object] DSPy::Example, Hash, or example-like object
# @param trace [Object, nil] stored verbatim on the result
# @return [EvaluationResult]
sig { params(example: T.untyped, trace: T.nilable(T.untyped)).returns(EvaluationResult) }
def call(example, trace: nil)
  Instrumentation.instrument('dspy.evaluation.example', {
    program_class: @program.class.name,
    has_metric: !@metric.nil?
  }) do
    begin
      # Extract input from example - support both hash and object formats
      input_values = extract_input_values(example)

      # Run prediction
      prediction = @program.call(**input_values)

      # Calculate metrics if provided
      metrics = {}
      passed = true

      if @metric
        begin
          metric_result = @metric.call(example, prediction)
          if metric_result.is_a?(Hash)
            metrics = metric_result
            # Look the flag up by key so an explicit `passed: false` is
            # honoured (an `a || b || true` chain would turn it into a pass),
            # and coerce it to a real boolean because EvaluationResult's sig
            # requires T::Boolean.
            passed =
              if metrics.key?(:passed)
                !!metrics[:passed]
              elsif metrics.key?('passed')
                !!metrics['passed']
              else
                true
              end
          else
            passed = !!metric_result
            metrics[:passed] = passed
          end
        rescue => e
          passed = false
          metrics[:error] = e.message
          metrics[:passed] = false
        end
      end

      EvaluationResult.new(
        example: example,
        prediction: prediction,
        trace: trace,
        metrics: metrics,
        passed: passed
      )
    rescue => e
      # Return failed evaluation result
      error_metrics = {
        error: e.message,
        passed: false
      }

      if @provide_traceback
        error_metrics[:traceback] = e.backtrace&.first(10) || []
      end

      EvaluationResult.new(
        example: example,
        prediction: nil,
        trace: trace,
        metrics: error_metrics,
        passed: false
      )
    end
  end
end
201
+
202
# Evaluate the program on every example in +devset+, one at a time.
#
# Each example goes through #call. The run stops early once +max_errors+
# examples have errored, meaning their result carries an :error metric (which
# #call records when the program or the metric raises). Examples that merely
# fail the metric do NOT count towards that limit. Counting them would cut
# ordinary runs short and bias pass_rate towards the first few examples.
#
# @param devset [Array] examples accepted by #call
# @param display_progress [Boolean] print progress lines to stdout
# @param display_table [Boolean] print a summary table when finished
# @param return_outputs [Boolean] currently unused; accepted for API compatibility
# @return [BatchEvaluationResult] results for the examples actually evaluated
sig do
  params(
    devset: T::Array[T.untyped],
    display_progress: T::Boolean,
    display_table: T::Boolean,
    return_outputs: T::Boolean
  ).returns(BatchEvaluationResult)
end
def evaluate(devset, display_progress: true, display_table: false, return_outputs: true)
  Instrumentation.instrument('dspy.evaluation.batch', {
    program_class: @program.class.name,
    num_examples: devset.length,
    has_metric: !@metric.nil?,
    num_threads: @num_threads
  }) do
    results = []
    errors = 0
    passed_count = 0

    puts "Evaluating #{devset.length} examples..." if display_progress

    devset.each_with_index do |example, index|
      if errors >= @max_errors
        # Make the truncation visible instead of silently skipping the rest.
        puts "Stopping early after #{index} examples: max_errors (#{@max_errors}) reached" if display_progress
        break
      end

      begin
        result = call(example)
        results << result
        passed_count += 1 if result.passed
        # Only genuine errors count towards max_errors; wrong answers don't.
        errors += 1 if result.metrics.key?(:error)

        if display_progress && ((index + 1) % 10).zero?
          puts "Processed #{index + 1}/#{devset.length} examples (#{passed_count} passed)"
        end
      rescue => e
        errors += 1
        puts "Error processing example #{index}: #{e.message}" if display_progress

        # Record the failure so the example still appears in the results.
        results << EvaluationResult.new(
          example: example,
          prediction: nil,
          trace: nil,
          metrics: { error: e.message, passed: false },
          passed: false
        )
      end
    end

    aggregated_metrics = aggregate_metrics(results)

    batch_result = BatchEvaluationResult.new(
      results: results,
      aggregated_metrics: aggregated_metrics
    )

    display_results_table(batch_result) if display_table

    # Emit batch completion event
    Instrumentation.emit('dspy.evaluation.batch_complete', {
      program_class: @program.class.name,
      total_examples: batch_result.total_examples,
      passed_examples: batch_result.passed_examples,
      pass_rate: batch_result.pass_rate,
      aggregated_metrics: aggregated_metrics
    })

    if display_progress
      puts "Evaluation complete: #{batch_result.passed_examples}/#{batch_result.total_examples} passed (#{(batch_result.pass_rate * 100).round(1)}%)"
    end

    batch_result
  end
end
284
+
285
private

# Build the keyword arguments for the program from +example+.
#
# Supported shapes, checked in this order:
# * DSPy::Example                -> example.input_values
# * Hash with :input / 'input'   -> that entry (Symbol key wins)
# * any other Hash (legacy)      -> the whole hash
# * object responding to #input_values -> that value
# * object responding to #input  -> that value
# * object responding to #to_h   -> the converted hash, handled like a Hash
# * anything else                -> its instance variables, keyed by name
#                                   without the leading '@'
# Hash results (other than from #input_values) get Symbol keys.
#
# @param example [Object]
# @return [Hash{Symbol => Object}]
# @raise [ArgumentError] if +example+ does not respond to #instance_variables
sig { params(example: T.untyped).returns(T::Hash[Symbol, T.untyped]) }
def extract_input_values(example)
  case example
  when DSPy::Example
    # Preferred format: DSPy::Example object with type safety
    example.input_values
  when Hash
    input_values_from_hash(example)
  when ->(ex) { ex.respond_to?(:input_values) }
    # Object with input_values method (Example-like)
    example.input_values
  when ->(ex) { ex.respond_to?(:input) }
    symbolize_keys_if_hash(example.input)
  when ->(ex) { ex.respond_to?(:to_h) }
    input_values_from_hash(example.to_h)
  else
    # Last resort: introspect instance variables (@question -> :question).
    # Ordinary objects define #instance_variables, so the raise is a guard for
    # exotic receivers only.
    unless example.respond_to?(:instance_variables)
      raise ArgumentError, "Cannot extract input values from example: #{example.class}"
    end

    example.instance_variables.to_h do |var|
      [var.to_s.delete('@').to_sym, example.instance_variable_get(var)]
    end
  end
end

# Select the :input / 'input' entry of +hash+ when present (Symbol key wins),
# otherwise treat the whole hash as the input. Hash results get Symbol keys.
sig { params(hash: T.untyped).returns(T.untyped) }
def input_values_from_hash(hash)
  if hash.key?(:input)
    symbolize_keys_if_hash(hash[:input])
  elsif hash.key?('input')
    symbolize_keys_if_hash(hash['input'])
  else
    # Legacy format: the whole hash is the input. Symbolize every key, not
    # only when the first one is a String, so mixed-key hashes splat cleanly.
    symbolize_keys_if_hash(hash)
  end
end

# Return +data+ with keys converted via #to_sym when it is a Hash. Keys that
# cannot be converted, and non-Hash values, are returned unchanged.
sig { params(data: T.untyped).returns(T.untyped) }
def symbolize_keys_if_hash(data)
  return data unless data.is_a?(Hash)

  data.transform_keys { |key| key.respond_to?(:to_sym) ? key.to_sym : key }
end
343
+
344
# Pull the expected (gold) values out of +example+, mirroring the formats
# accepted by #extract_input_values:
# * DSPy::Example        -> example.expected_values
# * Hash                 -> its :expected / 'expected' entry (Symbol key wins),
#                           or nil when neither key is present (legacy format)
# * #expected_values / #expected responders -> that value
# * anything else        -> nil
# Hash values read from :expected / #expected are returned with Symbol keys.
# NOTE(review): nothing in this file currently calls this method.
#
# @param example [Object]
# @return [Hash{Symbol => Object}, nil]
sig { params(example: T.untyped).returns(T.nilable(T::Hash[Symbol, T.untyped])) }
def extract_expected_values(example)
  symbolize = ->(data) { data.is_a?(Hash) ? data.transform_keys(&:to_sym) : data }

  return example.expected_values if example.is_a?(DSPy::Example)

  if example.is_a?(Hash)
    found_key = [:expected, 'expected'].find { |candidate| example.key?(candidate) }
    return found_key && symbolize.call(example[found_key])
  end

  return example.expected_values if example.respond_to?(:expected_values)
  return symbolize.call(example.expected) if example.respond_to?(:expected)

  nil
end
370
+
371
# Summarize a list of EvaluationResults.
#
# Returns {} when +results+ is empty. Otherwise the hash contains, in order:
# :total_examples, :passed_examples, :failed_examples; then, for every metric
# key carrying Numeric values (other than :error, :traceback and :passed), its
# :<key>_avg, :<key>_min and :<key>_max over the results that reported it;
# and finally :pass_rate.
#
# @param results [Array<EvaluationResult>]
# @return [Hash{Symbol => Object}]
sig { params(results: T::Array[EvaluationResult]).returns(T::Hash[Symbol, T.untyped]) }
def aggregate_metrics(results)
  return {} if results.empty?

  total = results.length
  passed = results.count(&:passed)
  summary = {
    total_examples: total,
    passed_examples: passed,
    failed_examples: total - passed
  }

  # Collect numeric samples per metric key, preserving first-seen key order.
  skipped_keys = %i[error traceback passed]
  samples = Hash.new { |hash, key| hash[key] = [] }
  results.each do |result|
    result.metrics.each do |key, value|
      samples[key] << value if value.is_a?(Numeric) && !skipped_keys.include?(key)
    end
  end

  samples.each do |key, values|
    summary[:"#{key}_avg"] = values.sum.to_f / values.length
    summary[:"#{key}_min"] = values.min
    summary[:"#{key}_max"] = values.max
  end

  # +results+ is non-empty here, so the division is safe.
  summary[:pass_rate] = passed.to_f / total
  summary
end
408
+
409
# Print a plain-text summary of +batch_result+ to stdout.
#
# Shows counts and pass rate, then every aggregated metric except the summary
# counts already printed, with Float values rounded to 3 decimal places.
#
# @param batch_result [BatchEvaluationResult]
sig { params(batch_result: BatchEvaluationResult).void }
def display_results_table(batch_result)
  rule = "=" * 50
  total = batch_result.total_examples
  passed = batch_result.passed_examples
  summary_keys = [:total_examples, :passed_examples, :failed_examples, :pass_rate]

  puts "\nEvaluation Results:"
  puts rule
  puts "Total Examples: #{total}"
  puts "Passed: #{passed}"
  puts "Failed: #{total - passed}"
  puts "Pass Rate: #{(batch_result.pass_rate * 100).round(1)}%"

  metrics = batch_result.aggregated_metrics
  unless metrics.empty?
    puts "\nAggregated Metrics:"
    metrics.each do |name, value|
      next if summary_keys.include?(name)

      shown = value.is_a?(Float) ? value.round(3) : value
      puts " #{name}: #{shown}"
    end
  end

  puts rule
end
429
+ end
430
+
431
# Common metric functions for evaluation.
#
# Each factory returns a Proc taking (example, prediction). The proc returns
# either a boolean or a Hash with a :passed entry; Evaluate#call accepts both.
# Field values are read with Metrics.extract_field, which accepts Hashes
# (Symbol or String keys), objects exposing a reader for the field, and
# objects convertible with #to_h.
module Metrics
  extend T::Sig

  # Exact match metric - checks if prediction exactly matches expected output.
  #
  # @param field [Symbol] field read from both the example and the prediction
  # @param case_sensitive [Boolean] when false, compare downcased strings
  # @return [Proc] (example, prediction) -> Boolean; false if either value is nil
  sig do
    params(
      field: Symbol,
      case_sensitive: T::Boolean
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Boolean))
  end
  def self.exact_match(field: :answer, case_sensitive: true)
    proc do |example, prediction|
      expected = extract_field(example, field)
      actual = extract_field(prediction, field)

      # `next`, not `return`: a `return` inside a proc tries to return from
      # exact_match, which has already finished, and raises LocalJumpError.
      next false if expected.nil? || actual.nil?

      if case_sensitive
        expected.to_s == actual.to_s
      else
        expected.to_s.downcase == actual.to_s.downcase
      end
    end
  end

  # Contains metric - checks if prediction contains expected substring.
  #
  # @param field [Symbol] field read from both the example and the prediction
  # @param case_sensitive [Boolean] when false, compare downcased strings
  # @return [Proc] (example, prediction) -> Boolean; false if either value is nil
  sig do
    params(
      field: Symbol,
      case_sensitive: T::Boolean
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Boolean))
  end
  def self.contains(field: :answer, case_sensitive: false)
    proc do |example, prediction|
      expected = extract_field(example, field)
      actual = extract_field(prediction, field)

      # See exact_match: `return` would raise LocalJumpError here.
      next false if expected.nil? || actual.nil?

      if case_sensitive
        actual.to_s.include?(expected.to_s)
      else
        actual.to_s.downcase.include?(expected.to_s.downcase)
      end
    end
  end

  # Numeric difference metric - checks if prediction is within tolerance of
  # expected value.
  #
  # Both values are converted with Kernel#Float. The proc returns
  # { passed:, difference:, expected:, actual:, tolerance: } on success, or
  # { passed: false, error: } when a value is missing or not numeric.
  #
  # @param field [Symbol] field read from both the example and the prediction
  # @param tolerance [Numeric] maximum allowed absolute difference (inclusive)
  # @return [Proc] (example, prediction) -> Hash
  sig do
    params(
      field: Symbol,
      tolerance: Numeric
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Hash[Symbol, T.untyped]))
  end
  def self.numeric_difference(field: :answer, tolerance: 0.01)
    proc do |example, prediction|
      expected = extract_field(example, field)
      actual = extract_field(prediction, field)

      # See exact_match: `return` would raise LocalJumpError here.
      next({ passed: false, error: "Missing values" }) if expected.nil? || actual.nil?

      begin
        expected_num = Float(expected)
        actual_num = Float(actual)
        difference = (expected_num - actual_num).abs
        passed = difference <= tolerance

        {
          passed: passed,
          difference: difference,
          expected: expected_num,
          actual: actual_num,
          tolerance: tolerance
        }
      rescue ArgumentError, TypeError
        # Float("abc") raises ArgumentError; Float([1]) or Float({}) raise TypeError.
        { passed: false, error: "Non-numeric values" }
      end
    end
  end

  # Composite metric - combines multiple metrics with AND logic.
  #
  # Every metric is always called. The resulting Hash stores each metric's
  # outcome under :metric_<index> (non-Hash outcomes are wrapped as
  # { passed: Boolean }) and sets :passed to true only if all of them passed.
  # A Hash outcome passes when its :passed (or, if absent, 'passed') entry is
  # truthy.
  #
  # @param metrics [Array<Proc>] metrics taking (example, prediction)
  # @return [Proc] (example, prediction) -> Hash
  sig do
    params(
      metrics: T.untyped
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Hash[Symbol, T.untyped]))
  end
  def self.composite_and(*metrics)
    proc do |example, prediction|
      results = {}
      all_passed = true

      metrics.each_with_index do |metric, index|
        result = metric.call(example, prediction)

        if result.is_a?(Hash)
          results[:"metric_#{index}"] = result
          # Key lookup so `passed: false` is not overridden by a 'passed'
          # String key; coerce so :passed is always a real boolean.
          flag = result.key?(:passed) ? result[:passed] : result['passed']
          all_passed &&= flag ? true : false
        else
          passed = !!result
          results[:"metric_#{index}"] = { passed: passed }
          all_passed &&= passed
        end
      end

      results[:passed] = all_passed
      results
    end
  end

  # Extract field value from example or prediction.
  #
  # Hash: Symbol key first, falling back to the String key when the Symbol
  # lookup yields nil (a stored `false` is returned as-is). Otherwise the
  # object's public reader for +field+, or its #to_h looked up the same way.
  # Returns nil when none apply.
  #
  # A bare `private` in the module body does not affect `def self.` methods,
  # so this helper has always been publicly callable; it stays public to keep
  # existing callers working.
  sig { params(obj: T.untyped, field: Symbol).returns(T.untyped) }
  def self.extract_field(obj, field)
    case obj
    when Hash
      value = obj[field]
      value.nil? ? obj[field.to_s] : value
    when ->(o) { o.respond_to?(field) }
      obj.public_send(field)
    when ->(o) { o.respond_to?(:to_h) }
      hash = obj.to_h
      value = hash[field]
      value.nil? ? hash[field.to_s] : value
    else
      nil
    end
  end
end
554
+ end