dspy 0.3.1 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +69 -382
- data/lib/dspy/chain_of_thought.rb +57 -0
- data/lib/dspy/evaluate.rb +554 -0
- data/lib/dspy/example.rb +203 -0
- data/lib/dspy/few_shot_example.rb +81 -0
- data/lib/dspy/instrumentation.rb +97 -8
- data/lib/dspy/lm/adapter_factory.rb +6 -8
- data/lib/dspy/lm.rb +5 -7
- data/lib/dspy/predict.rb +32 -34
- data/lib/dspy/prompt.rb +222 -0
- data/lib/dspy/propose/grounded_proposer.rb +560 -0
- data/lib/dspy/registry/registry_manager.rb +504 -0
- data/lib/dspy/registry/signature_registry.rb +725 -0
- data/lib/dspy/storage/program_storage.rb +442 -0
- data/lib/dspy/storage/storage_manager.rb +331 -0
- data/lib/dspy/subscribers/langfuse_subscriber.rb +669 -0
- data/lib/dspy/subscribers/logger_subscriber.rb +120 -0
- data/lib/dspy/subscribers/newrelic_subscriber.rb +686 -0
- data/lib/dspy/subscribers/otel_subscriber.rb +538 -0
- data/lib/dspy/teleprompt/data_handler.rb +107 -0
- data/lib/dspy/teleprompt/mipro_v2.rb +790 -0
- data/lib/dspy/teleprompt/simple_optimizer.rb +497 -0
- data/lib/dspy/teleprompt/teleprompter.rb +336 -0
- data/lib/dspy/teleprompt/utils.rb +380 -0
- data/lib/dspy/version.rb +5 -0
- data/lib/dspy.rb +16 -0
- metadata +29 -12
- data/lib/dspy/lm/adapters/ruby_llm_adapter.rb +0 -81
@@ -0,0 +1,554 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'sorbet-runtime'
|
4
|
+
require_relative 'instrumentation'
|
5
|
+
require_relative 'example'
|
6
|
+
|
7
|
+
module DSPy
|
8
|
+
# Core evaluation framework for DSPy programs
|
9
|
+
# Supports single evaluations, batch evaluations, and optimization workflows
|
10
|
+
class Evaluate
|
11
|
+
extend T::Sig
|
12
|
+
|
13
|
+
# Outcome of running the program on one example: the inputs, what the program
# produced, the scoring metrics, and the pass/fail verdict.
class EvaluationResult
  extend T::Sig

  # Keys emitted by #to_h, in output order.
  HASH_FIELDS = %i[example prediction trace metrics passed].freeze
  private_constant :HASH_FIELDS

  # The example that was evaluated.
  sig { returns(T.untyped) }
  attr_reader :example

  # The program's prediction for the example (may be nil).
  sig { returns(T.untyped) }
  attr_reader :prediction

  # Opaque trace value supplied by the caller.
  sig { returns(T.untyped) }
  attr_reader :trace

  # Metric values recorded for this example.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  attr_reader :metrics

  # Whether the example passed.
  sig { returns(T::Boolean) }
  attr_reader :passed

  sig do
    params(
      example: T.untyped,
      prediction: T.untyped,
      trace: T.untyped,
      metrics: T::Hash[Symbol, T.untyped],
      passed: T::Boolean
    ).void
  end
  def initialize(example:, prediction:, trace:, metrics:, passed:)
    @example, @prediction, @trace, @metrics, @passed = example, prediction, trace, metrics, passed
  end

  # Plain-Hash view of this result keyed by field name.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  def to_h
    HASH_FIELDS.to_h { |field| [field, instance_variable_get(:"@#{field}")] }
  end
end
|
60
|
+
|
61
|
+
# Batch evaluation results with aggregated metrics.
#
# Holds the per-example results together with summary counts derived from them.
# The stored collections are frozen copies; the caller's own array and hash are
# left untouched (and unfrozen).
class BatchEvaluationResult
  extend T::Sig

  # Per-example results, in evaluation order (frozen copy).
  sig { returns(T::Array[EvaluationResult]) }
  attr_reader :results

  # Aggregated metrics for the batch (frozen copy).
  sig { returns(T::Hash[Symbol, T.untyped]) }
  attr_reader :aggregated_metrics

  # Number of results in the batch.
  sig { returns(Integer) }
  attr_reader :total_examples

  # Number of results whose +passed+ flag is true.
  sig { returns(Integer) }
  attr_reader :passed_examples

  # passed_examples / total_examples, or 0.0 for an empty batch.
  sig { returns(Float) }
  attr_reader :pass_rate

  # @param results [Array<EvaluationResult>] per-example results
  # @param aggregated_metrics [Hash{Symbol=>Object}] batch-level metrics
  sig do
    params(
      results: T::Array[EvaluationResult],
      aggregated_metrics: T::Hash[Symbol, T.untyped]
    ).void
  end
  def initialize(results:, aggregated_metrics:)
    # Freeze private copies: calling #freeze on the arguments directly would
    # freeze the caller's collections as a side effect.
    @results = results.dup.freeze
    @aggregated_metrics = aggregated_metrics.dup.freeze
    @total_examples = @results.length
    @passed_examples = @results.count(&:passed)
    @pass_rate = @total_examples > 0 ? @passed_examples.to_f / @total_examples : 0.0
  end

  # Plain-Hash view of the batch, including each result's own #to_h.
  sig { returns(T::Hash[Symbol, T.untyped]) }
  def to_h
    {
      total_examples: @total_examples,
      passed_examples: @passed_examples,
      pass_rate: @pass_rate,
      aggregated_metrics: @aggregated_metrics,
      results: @results.map(&:to_h)
    }
  end
end
|
105
|
+
|
106
|
+
# The program under evaluation; #call invokes it as program.call(**inputs).
sig { returns(T.untyped) }
attr_reader :program

# Scoring callable, invoked as metric.call(example, prediction). It may return a
# truthy/falsy verdict or a Hash of metrics (as DSPy::Metrics.numeric_difference
# and DSPy::Metrics.composite_and do), so its return type is untyped.
sig { returns(T.nilable(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T.untyped))) }
attr_reader :metric

# Thread count; it is reported in instrumentation, but #evaluate iterates the
# devset sequentially.
sig { returns(T.nilable(Integer)) }
attr_reader :num_threads

# Error budget #evaluate uses to decide when to stop early.
sig { returns(T.nilable(Integer)) }
attr_reader :max_errors

# Whether error results carry the first 10 backtrace lines under :traceback.
sig { returns(T::Boolean) }
attr_reader :provide_traceback

# @param program [Object] the program to evaluate (called with keyword inputs)
# @param metric [Proc, nil] optional scoring callable (see #metric)
# @param num_threads [Integer, nil] nil falls back to 1
# @param max_errors [Integer, nil] nil falls back to 5
# @param provide_traceback [Boolean] include backtraces in error results
sig do
  params(
    program: T.untyped,
    metric: T.nilable(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T.untyped)),
    num_threads: T.nilable(Integer),
    max_errors: T.nilable(Integer),
    provide_traceback: T::Boolean
  ).void
end
def initialize(program, metric: nil, num_threads: 1, max_errors: 5, provide_traceback: true)
  @program = program
  @metric = metric
  @num_threads = num_threads || 1
  @max_errors = max_errors || 5
  @provide_traceback = provide_traceback
end
|
137
|
+
|
138
|
+
# Evaluate the program on a single example.
#
# Runs the program on the example's inputs and, when a metric is configured,
# scores the prediction with it. Exceptions raised by the program or the metric
# are captured in the returned result's metrics (:error, plus :traceback when
# provide_traceback is set) instead of propagating.
#
# @param example [Object] DSPy::Example, Hash, or example-like object
# @param trace [Object, nil] value copied verbatim onto the result
# @return [EvaluationResult]
sig { params(example: T.untyped, trace: T.nilable(T.untyped)).returns(EvaluationResult) }
def call(example, trace: nil)
  Instrumentation.instrument('dspy.evaluation.example', {
    program_class: @program.class.name,
    has_metric: !@metric.nil?
  }) do
    begin
      # Extract input from example - support both hash and object formats
      input_values = extract_input_values(example)

      # Run prediction
      prediction = @program.call(**input_values)

      # Calculate metrics if provided
      metrics = {}
      passed = true

      if @metric
        begin
          metric_result = @metric.call(example, prediction)
          if metric_result.is_a?(Hash)
            metrics = metric_result
            # Honour an explicit verdict (including false/nil); default to
            # passing only when the metric reported no :passed key at all.
            # Coerce to a strict Boolean, as EvaluationResult's sig requires.
            passed =
              if metrics.key?(:passed)
                !!metrics[:passed]
              elsif metrics.key?('passed')
                !!metrics['passed']
              else
                true
              end
          else
            passed = !!metric_result
            metrics[:passed] = passed
          end
        rescue => e
          passed = false
          metrics[:error] = e.message
          metrics[:passed] = false
          metrics[:traceback] = e.backtrace&.first(10) || [] if @provide_traceback
        end
      end

      EvaluationResult.new(
        example: example,
        prediction: prediction,
        trace: trace,
        metrics: metrics,
        passed: passed
      )
    rescue => e
      # Return failed evaluation result
      error_metrics = {
        error: e.message,
        passed: false
      }

      if @provide_traceback
        error_metrics[:traceback] = e.backtrace&.first(10) || []
      end

      EvaluationResult.new(
        example: example,
        prediction: nil,
        trace: trace,
        metrics: error_metrics,
        passed: false
      )
    end
  end
end
|
201
|
+
|
202
|
+
# Evaluate the program on every example in +devset+, one at a time.
#
# Evaluation stops early once +max_errors+ examples have *errored*, meaning
# their result carries an :error metric (the program or metric raised, or the
# metric itself reported an error). Examples that simply fail the metric are a
# normal outcome and do not count toward that limit, so the reported pass rate
# covers the whole devset unless real errors occur.
#
# @param devset [Array] examples accepted by #call
# @param display_progress [Boolean] print progress lines to stdout
# @param display_table [Boolean] print a summary table at the end
# @param return_outputs [Boolean] accepted for API compatibility; not used here
# @return [BatchEvaluationResult]
sig do
  params(
    devset: T::Array[T.untyped],
    display_progress: T::Boolean,
    display_table: T::Boolean,
    return_outputs: T::Boolean
  ).returns(BatchEvaluationResult)
end
def evaluate(devset, display_progress: true, display_table: false, return_outputs: true)
  Instrumentation.instrument('dspy.evaluation.batch', {
    program_class: @program.class.name,
    num_examples: devset.length,
    has_metric: !@metric.nil?,
    num_threads: @num_threads
  }) do
    results = []
    errors = 0
    passed_count = 0

    puts "Evaluating #{devset.length} examples..." if display_progress

    devset.each_with_index do |example, index|
      if errors >= @max_errors
        puts "Stopping early after #{index} examples: max_errors (#{@max_errors}) reached" if display_progress
        break
      end

      begin
        result = call(example)
        results << result
        passed_count += 1 if result.passed
        # Only genuine errors count toward max_errors; a plain metric failure
        # must not truncate the run.
        errors += 1 if result.metrics.key?(:error)

        if display_progress && ((index + 1) % 10).zero?
          puts "Processed #{index + 1}/#{devset.length} examples (#{passed_count} passed)"
        end
      rescue => e
        errors += 1
        puts "Error processing example #{index}: #{e.message}" if display_progress

        # Create error result, mirroring the shape #call produces on failure.
        error_metrics = { error: e.message, passed: false }
        error_metrics[:traceback] = e.backtrace&.first(10) || [] if @provide_traceback
        results << EvaluationResult.new(
          example: example,
          prediction: nil,
          trace: nil,
          metrics: error_metrics,
          passed: false
        )
      end
    end

    # Aggregate metrics
    aggregated_metrics = aggregate_metrics(results)

    batch_result = BatchEvaluationResult.new(
      results: results,
      aggregated_metrics: aggregated_metrics
    )

    display_results_table(batch_result) if display_table

    # Emit batch completion event
    Instrumentation.emit('dspy.evaluation.batch_complete', {
      program_class: @program.class.name,
      total_examples: batch_result.total_examples,
      passed_examples: batch_result.passed_examples,
      pass_rate: batch_result.pass_rate,
      aggregated_metrics: aggregated_metrics
    })

    if display_progress
      puts "Evaluation complete: #{batch_result.passed_examples}/#{batch_result.total_examples} passed (#{(batch_result.pass_rate * 100).round(1)}%)"
    end

    batch_result
  end
end
|
284
|
+
|
285
|
+
private
|
286
|
+
|
287
|
+
# Normalize an example into the keyword arguments passed to the program.
#
# Accepted shapes, checked in this order:
# - DSPy::Example -> its #input_values
# - Hash with an :input / 'input' key -> that value (Hash keys symbolized)
# - any other Hash -> the whole hash (symbolized when the first key is a String)
# - object responding to #input_values, then #input, then #to_h
# - otherwise, the object's instance variables keyed by name without '@'
#
# @raise [ArgumentError] when the object exposes none of the above
sig { params(example: T.untyped).returns(T::Hash[Symbol, T.untyped]) }
def extract_input_values(example)
  symbolize = ->(data) { data.is_a?(Hash) ? data.transform_keys(&:to_sym) : data }

  return example.input_values if example.is_a?(DSPy::Example)

  if example.is_a?(Hash)
    return symbolize.call(example[:input]) if example.key?(:input)
    return symbolize.call(example['input']) if example.key?('input')

    # Legacy format: the entire hash is the input.
    return example.keys.first.is_a?(String) ? example.transform_keys(&:to_sym) : example
  end

  return example.input_values if example.respond_to?(:input_values)
  return symbolize.call(example.input) if example.respond_to?(:input)

  if example.respond_to?(:to_h)
    as_hash = example.to_h
    return symbolize.call(as_hash[:input]) if as_hash.key?(:input)
    return symbolize.call(as_hash['input']) if as_hash.key?('input')

    return symbolize.call(as_hash)
  end

  unless example.respond_to?(:instance_variables)
    raise ArgumentError, "Cannot extract input values from example: #{example.class}"
  end

  # Last resort: introspect the object's instance variables.
  example.instance_variables.each_with_object({}) do |ivar, inputs|
    inputs[ivar.to_s.delete('@').to_sym] = example.instance_variable_get(ivar)
  end
end
|
343
|
+
|
344
|
+
# Pull the expected (gold) outputs from an example, mirroring the formats
# accepted for inputs. Hash payloads are returned with symbolized keys.
# Returns nil when no separate expected section can be found.
# Not called anywhere in this file.
sig { params(example: T.untyped).returns(T.nilable(T::Hash[Symbol, T.untyped])) }
def extract_expected_values(example)
  symbolize = ->(data) { data.is_a?(Hash) ? data.transform_keys(&:to_sym) : data }

  if example.is_a?(DSPy::Example)
    example.expected_values
  elsif example.is_a?(Hash)
    # A Hash without an expected section (legacy format) yields nil.
    if example.key?(:expected)
      symbolize.call(example[:expected])
    elsif example.key?('expected')
      symbolize.call(example['expected'])
    end
  elsif example.respond_to?(:expected_values)
    example.expected_values
  elsif example.respond_to?(:expected)
    symbolize.call(example.expected)
  end
end
|
370
|
+
|
371
|
+
# Summarize per-example results into a single Hash.
#
# The summary holds:
# - total/passed/failed counts;
# - avg/min/max for every Numeric metric value (keys :error, :traceback and
#   :passed are skipped);
# - the overall pass rate.
# Returns {} for an empty result list.
sig { params(results: T::Array[EvaluationResult]).returns(T::Hash[Symbol, T.untyped]) }
def aggregate_metrics(results)
  return {} if results.empty?

  total = results.length
  passed = results.count(&:passed)
  summary = {
    total_examples: total,
    passed_examples: passed,
    failed_examples: total - passed
  }

  skipped_keys = %i[error traceback passed]
  samples = Hash.new { |store, key| store[key] = [] }
  results.each do |result|
    result.metrics.each do |key, value|
      samples[key] << value if value.is_a?(Numeric) && !skipped_keys.include?(key)
    end
  end

  samples.each do |key, values|
    summary[:"#{key}_avg"] = values.sum.to_f / values.length
    summary[:"#{key}_min"] = values.min
    summary[:"#{key}_max"] = values.max
  end

  summary[:pass_rate] = total > 0 ? passed.to_f / total : 0.0
  summary
end
|
408
|
+
|
409
|
+
# Print a human-readable summary of +batch_result+ to stdout.
#
# The output contains the headline counts and pass rate, then every aggregated
# metric except those headline figures. Floats are rounded to 3 places.
sig { params(batch_result: BatchEvaluationResult).void }
def display_results_table(batch_result)
  rule = "=" * 50
  headline_keys = %i[total_examples passed_examples failed_examples pass_rate]
  total = batch_result.total_examples
  passed = batch_result.passed_examples

  puts "\nEvaluation Results:"
  puts rule
  puts "Total Examples: #{total}"
  puts "Passed: #{passed}"
  puts "Failed: #{total - passed}"
  puts "Pass Rate: #{(batch_result.pass_rate * 100).round(1)}%"

  extra = batch_result.aggregated_metrics
  unless extra.empty?
    puts "\nAggregated Metrics:"
    extra.reject { |key, _| headline_keys.include?(key) }.each do |key, value|
      shown = value.is_a?(Float) ? value.round(3) : value
      puts " #{key}: #{shown}"
    end
  end

  puts rule
end
|
429
|
+
end
|
430
|
+
|
431
|
+
# Common metric functions for evaluation.
#
# Each factory returns a Proc that is called as metric.call(example, prediction),
# suitable for DSPy::Evaluate's +metric:+ argument. The procs use `next` (never
# `return`) for early exits: a `return` inside a non-lambda proc tries to return
# from the factory method that created it. That method has already finished by
# call time, so the proc would raise LocalJumpError.
module Metrics
  extend T::Sig

  # Exact match metric - checks if the prediction's +field+ equals the example's.
  #
  # @param field [Symbol] field read from both example and prediction
  # @param case_sensitive [Boolean] when false, both sides are downcased
  # @return [Proc] yields false when either value is missing (nil)
  sig do
    params(
      field: Symbol,
      case_sensitive: T::Boolean
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Boolean))
  end
  def self.exact_match(field: :answer, case_sensitive: true)
    proc do |example, prediction|
      expected = extract_field(example, field)
      actual = extract_field(prediction, field)

      next false if expected.nil? || actual.nil?

      if case_sensitive
        expected.to_s == actual.to_s
      else
        expected.to_s.downcase == actual.to_s.downcase
      end
    end
  end

  # Contains metric - checks if the prediction's +field+ contains the expected
  # value as a substring.
  #
  # @param field [Symbol] field read from both example and prediction
  # @param case_sensitive [Boolean] when false, both sides are downcased
  # @return [Proc] yields false when either value is missing (nil)
  sig do
    params(
      field: Symbol,
      case_sensitive: T::Boolean
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Boolean))
  end
  def self.contains(field: :answer, case_sensitive: false)
    proc do |example, prediction|
      expected = extract_field(example, field)
      actual = extract_field(prediction, field)

      next false if expected.nil? || actual.nil?

      if case_sensitive
        actual.to_s.include?(expected.to_s)
      else
        actual.to_s.downcase.include?(expected.to_s.downcase)
      end
    end
  end

  # Numeric difference metric - checks if the prediction is within +tolerance+
  # of the expected value.
  #
  # @param field [Symbol] field read from both example and prediction
  # @param tolerance [Numeric] maximum allowed absolute difference (Integer ok)
  # @return [Proc] yields a Hash with :passed and either the comparison details
  #   (:difference, :expected, :actual, :tolerance) or an :error message
  sig do
    params(
      field: Symbol,
      tolerance: Numeric
    ).returns(T.proc.params(arg0: T.untyped, arg1: T.untyped).returns(T::Hash[Symbol, T.untyped]))
  end
  def self.numeric_difference(field: :answer, tolerance: 0.01)
    proc do |example, prediction|
      expected = extract_field(example, field)
      actual = extract_field(prediction, field)

      next({ passed: false, error: "Missing values" }) if expected.nil? || actual.nil?

      begin
        expected_num = Float(expected)
        actual_num = Float(actual)
      rescue ArgumentError, TypeError
        # Float() raises ArgumentError for unparsable strings and TypeError for
        # objects with no numeric conversion (e.g. arrays, booleans).
        next({ passed: false, error: "Non-numeric values" })
      end

      difference = (expected_num - actual_num).abs

      {
        passed: difference <= tolerance,
        difference: difference,
        expected: expected_num,
        actual: actual_num,
        tolerance: tolerance
      }
    end
  end

  # Composite metric - combines multiple metrics with AND logic.
  #
  # Every metric is evaluated (no short-circuit). Each sub-result is recorded
  # under :metric_<index>. Non-Hash results are wrapped as { passed: bool }.
  # :passed is true only if all sub-metrics passed.
  def self.composite_and(*metrics)
    proc do |example, prediction|
      results = {}
      all_passed = true

      metrics.each_with_index do |metric, index|
        result = metric.call(example, prediction)

        if result.is_a?(Hash)
          results[:"metric_#{index}"] = result
          passed = !!(result[:passed] || result['passed'])
        else
          passed = !!result
          results[:"metric_#{index}"] = { passed: passed }
        end

        all_passed &&= passed
      end

      results[:passed] = all_passed
      results
    end
  end

  # Extract a field value from an example or prediction.
  #
  # Hash-like objects use the Symbol key when present and fall back to the
  # String key only when it is absent, so stored false values are preserved.
  # Other objects are read via a public reader named +field+, or via #to_h.
  # Internal helper: a bare `private` does not apply to `def self.` methods, so
  # this remains callable as DSPy::Metrics.extract_field.
  sig { params(obj: T.untyped, field: Symbol).returns(T.untyped) }
  def self.extract_field(obj, field)
    case obj
    when Hash
      obj.key?(field) ? obj[field] : obj[field.to_s]
    when ->(o) { o.respond_to?(field) }
      obj.public_send(field)
    when ->(o) { o.respond_to?(:to_h) }
      hash = obj.to_h
      hash.key?(field) ? hash[field] : hash[field.to_s]
    end
  end
end
|
554
|
+
end
|