dspy 0.34.2 → 0.34.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +8 -16
- data/lib/dspy/chain_of_thought.rb +3 -2
- data/lib/dspy/context.rb +70 -21
- data/lib/dspy/evals/version.rb +1 -1
- data/lib/dspy/evals.rb +42 -31
- data/lib/dspy/events.rb +2 -3
- data/lib/dspy/example.rb +1 -1
- data/lib/dspy/lm/adapter.rb +39 -0
- data/lib/dspy/lm/json_strategy.rb +28 -67
- data/lib/dspy/lm/message.rb +1 -1
- data/lib/dspy/lm/response.rb +2 -2
- data/lib/dspy/lm/usage.rb +35 -10
- data/lib/dspy/lm.rb +22 -51
- data/lib/dspy/mixins/type_coercion.rb +256 -35
- data/lib/dspy/module.rb +203 -31
- data/lib/dspy/predict.rb +33 -6
- data/lib/dspy/prediction.rb +25 -58
- data/lib/dspy/prompt.rb +52 -76
- data/lib/dspy/propose/dataset_summary_generator.rb +1 -1
- data/lib/dspy/propose/grounded_proposer.rb +3 -3
- data/lib/dspy/re_act.rb +159 -196
- data/lib/dspy/registry/signature_registry.rb +3 -3
- data/lib/dspy/ruby_llm/lm/adapters/ruby_llm_adapter.rb +1 -27
- data/lib/dspy/schema/sorbet_json_schema.rb +7 -6
- data/lib/dspy/schema/version.rb +1 -1
- data/lib/dspy/schema_adapters.rb +1 -1
- data/lib/dspy/signature.rb +4 -5
- data/lib/dspy/storage/program_storage.rb +2 -2
- data/lib/dspy/structured_outputs_prompt.rb +4 -4
- data/lib/dspy/teleprompt/utils.rb +2 -2
- data/lib/dspy/tools/github_cli_toolset.rb +7 -7
- data/lib/dspy/tools/text_processing_toolset.rb +2 -2
- data/lib/dspy/tools/toolset.rb +1 -1
- data/lib/dspy/utils/serialization.rb +2 -6
- data/lib/dspy/version.rb +1 -1
- data/lib/dspy.rb +50 -5
- metadata +7 -26
- data/lib/dspy/events/subscriber_mixin.rb +0 -79
- data/lib/dspy/events/subscribers.rb +0 -43
- data/lib/dspy/memory/embedding_engine.rb +0 -68
- data/lib/dspy/memory/in_memory_store.rb +0 -216
- data/lib/dspy/memory/local_embedding_engine.rb +0 -244
- data/lib/dspy/memory/memory_compactor.rb +0 -298
- data/lib/dspy/memory/memory_manager.rb +0 -266
- data/lib/dspy/memory/memory_record.rb +0 -163
- data/lib/dspy/memory/memory_store.rb +0 -90
- data/lib/dspy/memory.rb +0 -30
- data/lib/dspy/tools/memory_toolset.rb +0 -117
data/lib/dspy/module.rb
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
require 'sorbet-runtime'
|
|
4
4
|
require 'dry-configurable'
|
|
5
5
|
require 'securerandom'
|
|
6
|
+
require 'weakref'
|
|
6
7
|
require_relative 'context'
|
|
7
8
|
require_relative 'callbacks'
|
|
8
9
|
require_relative 'type_serializer'
|
|
@@ -24,15 +25,20 @@ module DSPy
|
|
|
24
25
|
|
|
25
26
|
DEFAULT_MODULE_SUBSCRIPTION_SCOPE = SubcriptionScope::Descendants
|
|
26
27
|
|
|
28
|
+
# Hook to wrap forward methods with instrumentation.
|
|
29
|
+
# Uses a Set-based guard (not boolean) to prevent re-wrapping when
|
|
30
|
+
# other hooks (like Callbacks) also use define_method.
|
|
27
31
|
module ForwardOverrideHooks
|
|
28
32
|
def method_added(method_name)
|
|
29
33
|
super
|
|
30
34
|
|
|
31
35
|
return unless method_name == :forward
|
|
32
36
|
return if self == DSPy::Module
|
|
33
|
-
return if @_wrapping_forward
|
|
34
37
|
|
|
35
|
-
|
|
38
|
+
# Use Set-based guard - persists across hook invocations
|
|
39
|
+
@_forward_instrumented ||= Set.new
|
|
40
|
+
return if @_forward_instrumented.include?(object_id)
|
|
41
|
+
@_forward_instrumented << object_id
|
|
36
42
|
|
|
37
43
|
original = instance_method(:forward)
|
|
38
44
|
define_method(:forward) do |*args, **kwargs, &block|
|
|
@@ -40,8 +46,6 @@ module DSPy
|
|
|
40
46
|
original.bind(self).call(*args, **kwargs, &block)
|
|
41
47
|
end
|
|
42
48
|
end
|
|
43
|
-
ensure
|
|
44
|
-
@_wrapping_forward = false
|
|
45
49
|
end
|
|
46
50
|
end
|
|
47
51
|
|
|
@@ -71,6 +75,35 @@ module DSPy
|
|
|
71
75
|
|
|
72
76
|
private
|
|
73
77
|
|
|
78
|
+
def build_subscription_callback(weakref, subscription_id_ref, spec)
|
|
79
|
+
scope = spec[:scope] || DEFAULT_MODULE_SUBSCRIPTION_SCOPE
|
|
80
|
+
handler = spec[:handler]
|
|
81
|
+
block = spec[:block]
|
|
82
|
+
|
|
83
|
+
->(event_name, attributes) do
|
|
84
|
+
target = begin
|
|
85
|
+
weakref.__getobj__
|
|
86
|
+
rescue WeakRef::RefError
|
|
87
|
+
nil
|
|
88
|
+
end
|
|
89
|
+
|
|
90
|
+
unless target
|
|
91
|
+
subscription_id = subscription_id_ref[:id]
|
|
92
|
+
DSPy.events.unsubscribe(subscription_id) if subscription_id
|
|
93
|
+
DSPy.logger&.debug(event: 'module.subscription.auto_unsubscribe', subscription_id: subscription_id)
|
|
94
|
+
return
|
|
95
|
+
end
|
|
96
|
+
|
|
97
|
+
return unless target.send(:module_event_within_scope?, attributes, scope)
|
|
98
|
+
|
|
99
|
+
if handler
|
|
100
|
+
target.send(handler, event_name, attributes)
|
|
101
|
+
else
|
|
102
|
+
target.instance_exec(event_name, attributes, &block)
|
|
103
|
+
end
|
|
104
|
+
end
|
|
105
|
+
end
|
|
106
|
+
|
|
74
107
|
def validate_subscription_scope!(scope)
|
|
75
108
|
T.must(scope)
|
|
76
109
|
end
|
|
@@ -97,7 +130,8 @@ module DSPy
|
|
|
97
130
|
create_after_callback :forward
|
|
98
131
|
create_around_callback :forward
|
|
99
132
|
|
|
100
|
-
# The main forward method that users will call is generic and type parameterized
|
|
133
|
+
# The main forward method that users will call is generic and type parameterized.
|
|
134
|
+
# Instrument here only when subclasses don't override forward.
|
|
101
135
|
sig do
|
|
102
136
|
type_parameters(:I, :O)
|
|
103
137
|
.params(
|
|
@@ -106,10 +140,14 @@ module DSPy
|
|
|
106
140
|
.returns(T.type_parameter(:O))
|
|
107
141
|
end
|
|
108
142
|
def forward(**input_values)
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
143
|
+
result = if self.class.instance_method(:forward).owner == DSPy::Module
|
|
144
|
+
instrument_forward_call([], input_values) do
|
|
145
|
+
forward_untyped(**input_values)
|
|
146
|
+
end
|
|
147
|
+
else
|
|
148
|
+
forward_untyped(**input_values)
|
|
112
149
|
end
|
|
150
|
+
T.cast(result, T.type_parameter(:O))
|
|
113
151
|
end
|
|
114
152
|
|
|
115
153
|
# The implementation method that subclasses must override
|
|
@@ -223,21 +261,64 @@ module DSPy
|
|
|
223
261
|
def instrument_forward_call(call_args, call_kwargs)
|
|
224
262
|
ensure_module_subscriptions!
|
|
225
263
|
|
|
264
|
+
input_json = serialize_module_input(call_args, call_kwargs)
|
|
265
|
+
root_call = DSPy::Context.current[:span_stack].empty?
|
|
266
|
+
|
|
226
267
|
DSPy::Context.with_module(self) do
|
|
227
268
|
observation_type = DSPy::ObservationType.for_module_class(self.class)
|
|
228
269
|
span_attributes = observation_type.langfuse_attributes.merge(
|
|
229
|
-
'langfuse.observation.input' =>
|
|
270
|
+
'langfuse.observation.input' => input_json,
|
|
230
271
|
'dspy.module' => self.class.name
|
|
231
272
|
)
|
|
273
|
+
operation_name = "#{self.class.name}.forward"
|
|
274
|
+
span_attributes.merge!(root_trace_attributes(call_args, call_kwargs, input_json)) if root_call
|
|
275
|
+
|
|
276
|
+
if self.class.name == 'DSPy::Predict' && respond_to?(:signature_class)
|
|
277
|
+
signature_name = signature_class&.name
|
|
278
|
+
span_attributes['dspy.signature'] = signature_name || 'anonymous'
|
|
279
|
+
span_attributes['dspy.signature_kind'] = infer_signature_kind(signature_name)
|
|
280
|
+
span_attributes['dspy.predictor_label'] = module_scope_label if module_scope_label
|
|
281
|
+
operation_name = "DSPy::Predict(#{signature_name}).forward" if signature_name
|
|
282
|
+
end
|
|
232
283
|
|
|
233
284
|
DSPy::Context.with_span(
|
|
234
|
-
operation:
|
|
285
|
+
operation: operation_name,
|
|
235
286
|
**span_attributes
|
|
236
287
|
) do |span|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
span
|
|
288
|
+
begin
|
|
289
|
+
yield.tap do |result|
|
|
290
|
+
if span && !result.nil?
|
|
291
|
+
span.set_attribute('langfuse.observation.output', serialize_module_output(result))
|
|
292
|
+
span.set_attribute('langfuse.observation.status', 'completed')
|
|
293
|
+
span.set_attribute('dspy.status', 'completed')
|
|
294
|
+
if root_call
|
|
295
|
+
span.set_attribute('langfuse.trace.output', serialize_module_output(result))
|
|
296
|
+
span.set_attribute('langfuse.trace.status', 'completed')
|
|
297
|
+
end
|
|
298
|
+
end
|
|
299
|
+
end
|
|
300
|
+
rescue StandardError => e
|
|
301
|
+
if span
|
|
302
|
+
span.set_attribute('langfuse.observation.output', serialize_module_error_output(e))
|
|
303
|
+
span.set_attribute('langfuse.observation.status', 'error')
|
|
304
|
+
span.set_attribute('dspy.error.class', e.class.name)
|
|
305
|
+
span.set_attribute('dspy.error.message', e.message.to_s[0, 2000]) if e.message
|
|
306
|
+
span.set_attribute('dspy.status', 'error')
|
|
307
|
+
if root_call
|
|
308
|
+
span.set_attribute('langfuse.trace.output', serialize_module_error_output(e))
|
|
309
|
+
span.set_attribute('langfuse.trace.status', 'error')
|
|
310
|
+
end
|
|
311
|
+
if e.respond_to?(:iterations)
|
|
312
|
+
span.set_attribute('dspy.error.iterations', e.iterations.to_i) unless e.iterations.nil?
|
|
313
|
+
end
|
|
314
|
+
if e.respond_to?(:max_iterations)
|
|
315
|
+
span.set_attribute('dspy.error.max_iterations', e.max_iterations.to_i) unless e.max_iterations.nil?
|
|
316
|
+
end
|
|
317
|
+
if e.respond_to?(:tools_used)
|
|
318
|
+
span.set_attribute('dspy.error.tools_used', Array(e.tools_used).map(&:to_s))
|
|
319
|
+
end
|
|
240
320
|
end
|
|
321
|
+
raise
|
|
241
322
|
end
|
|
242
323
|
end
|
|
243
324
|
end
|
|
@@ -265,7 +346,91 @@ module DSPy
|
|
|
265
346
|
result.to_s
|
|
266
347
|
end
|
|
267
348
|
|
|
268
|
-
|
|
349
|
+
def serialize_module_error_output(error)
|
|
350
|
+
payload = {
|
|
351
|
+
error: {
|
|
352
|
+
class: error.class.name,
|
|
353
|
+
message: error.message.to_s
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
if error.respond_to?(:iterations) || error.respond_to?(:max_iterations) || error.respond_to?(:tools_used)
|
|
358
|
+
payload[:react] = {}
|
|
359
|
+
payload[:react][:iterations] = error.iterations if error.respond_to?(:iterations)
|
|
360
|
+
payload[:react][:max_iterations] = error.max_iterations if error.respond_to?(:max_iterations)
|
|
361
|
+
payload[:react][:tools_used] = Array(error.tools_used) if error.respond_to?(:tools_used)
|
|
362
|
+
end
|
|
363
|
+
|
|
364
|
+
serialized = DSPy::TypeSerializer.serialize(payload)
|
|
365
|
+
JSON.generate(serialized)
|
|
366
|
+
rescue StandardError
|
|
367
|
+
"#{error.class}: #{error.message}"
|
|
368
|
+
end
|
|
369
|
+
|
|
370
|
+
def root_trace_attributes(call_args, call_kwargs, input_json)
|
|
371
|
+
metadata = {
|
|
372
|
+
module: self.class.name,
|
|
373
|
+
signature: (respond_to?(:signature_class) ? signature_class&.name : nil),
|
|
374
|
+
signature_kind: (respond_to?(:signature_class) ? infer_signature_kind(signature_class&.name) : nil),
|
|
375
|
+
predictor_label: module_scope_label
|
|
376
|
+
}.compact
|
|
377
|
+
conversation_id, conversation_id_source = resolve_conversation_id(call_args, call_kwargs)
|
|
378
|
+
metadata[:conversation_id_source] = conversation_id_source if conversation_id_source
|
|
379
|
+
|
|
380
|
+
{
|
|
381
|
+
'langfuse.trace.name' => "#{self.class.name}.forward",
|
|
382
|
+
'langfuse.trace.input' => input_json,
|
|
383
|
+
'langfuse.trace.metadata' => JSON.generate(metadata),
|
|
384
|
+
'langfuse.trace.output' => '{"status":"in_progress"}',
|
|
385
|
+
'conversation_id' => conversation_id,
|
|
386
|
+
'dspy.conversation_id' => conversation_id
|
|
387
|
+
}
|
|
388
|
+
rescue StandardError
|
|
389
|
+
{}
|
|
390
|
+
end
|
|
391
|
+
|
|
392
|
+
# Conversation ID precedence is deterministic:
|
|
393
|
+
# 1. top-level kwargs[:conversation_id]
|
|
394
|
+
# 2. first positional hash arg[:conversation_id]
|
|
395
|
+
# 3. kwargs[:input_context][:conversation_id]
|
|
396
|
+
# 4. DSPy::Context.current[:conversation_id]
|
|
397
|
+
def resolve_conversation_id(call_args, call_kwargs)
|
|
398
|
+
direct = fetch_hash_value(call_kwargs, :conversation_id)
|
|
399
|
+
return [direct.to_s, 'kwargs.conversation_id'] if present_value?(direct)
|
|
400
|
+
|
|
401
|
+
first_arg = call_args.first if call_args.is_a?(Array) && call_args.first.is_a?(Hash)
|
|
402
|
+
arg_value = fetch_hash_value(first_arg, :conversation_id)
|
|
403
|
+
return [arg_value.to_s, 'args[0].conversation_id'] if present_value?(arg_value)
|
|
404
|
+
|
|
405
|
+
input_context = fetch_hash_value(call_kwargs, :input_context)
|
|
406
|
+
nested = fetch_hash_value(input_context, :conversation_id)
|
|
407
|
+
return [nested.to_s, 'kwargs.input_context.conversation_id'] if present_value?(nested)
|
|
408
|
+
|
|
409
|
+
context_value = fetch_hash_value(DSPy::Context.current, :conversation_id)
|
|
410
|
+
return [context_value.to_s, 'context.conversation_id'] if present_value?(context_value)
|
|
411
|
+
|
|
412
|
+
[nil, nil]
|
|
413
|
+
end
|
|
414
|
+
|
|
415
|
+
def fetch_hash_value(hash, key)
|
|
416
|
+
return nil unless hash.is_a?(Hash)
|
|
417
|
+
|
|
418
|
+
hash[key] || hash[key.to_s]
|
|
419
|
+
end
|
|
420
|
+
|
|
421
|
+
def present_value?(value)
|
|
422
|
+
!value.nil? && !(value.respond_to?(:empty?) && value.empty?)
|
|
423
|
+
end
|
|
424
|
+
|
|
425
|
+
def infer_signature_kind(signature_name)
|
|
426
|
+
return 'custom' unless signature_name
|
|
427
|
+
return 'thought' if signature_name.match?(/thought/i)
|
|
428
|
+
return 'observation' if signature_name.match?(/observation/i)
|
|
429
|
+
|
|
430
|
+
'custom'
|
|
431
|
+
end
|
|
432
|
+
|
|
433
|
+
private :instrument_forward_call, :serialize_module_input, :serialize_module_output, :serialize_module_error_output, :root_trace_attributes, :resolve_conversation_id, :fetch_hash_value, :present_value?, :infer_signature_kind
|
|
269
434
|
|
|
270
435
|
sig { returns(String) }
|
|
271
436
|
def module_scope_id
|
|
@@ -294,8 +459,28 @@ module DSPy
|
|
|
294
459
|
@module_subscriptions_registered = false
|
|
295
460
|
end
|
|
296
461
|
|
|
462
|
+
sig { returns(T.self_type) }
|
|
463
|
+
def dup_for_thread
|
|
464
|
+
cloned = dup
|
|
465
|
+
cloned.instance_variable_set(:@module_subscription_ids, [])
|
|
466
|
+
cloned.instance_variable_set(:@module_subscriptions_registered, false)
|
|
467
|
+
cloned.instance_variable_set(:@module_scope_id, SecureRandom.uuid)
|
|
468
|
+
cloned.send(:reset_thread_state)
|
|
469
|
+
cloned
|
|
470
|
+
end
|
|
471
|
+
|
|
297
472
|
private
|
|
298
473
|
|
|
474
|
+
def reset_thread_state
|
|
475
|
+
instance_variables.each do |ivar|
|
|
476
|
+
value = instance_variable_get(ivar)
|
|
477
|
+
case value
|
|
478
|
+
when Array, Hash, Set
|
|
479
|
+
instance_variable_set(ivar, value.dup)
|
|
480
|
+
end
|
|
481
|
+
end
|
|
482
|
+
end
|
|
483
|
+
|
|
299
484
|
# Propagate LM configuration to child predictors recursively
|
|
300
485
|
# Skips children that already have an explicit LM configured
|
|
301
486
|
sig { params(lm: T.untyped).void }
|
|
@@ -322,30 +507,17 @@ module DSPy
|
|
|
322
507
|
|
|
323
508
|
@module_subscription_ids ||= []
|
|
324
509
|
specs.each do |spec|
|
|
325
|
-
|
|
510
|
+
weakref = WeakRef.new(self)
|
|
511
|
+
subscription_id_ref = { id: nil }
|
|
512
|
+
callback = self.class.send(:build_subscription_callback, weakref, subscription_id_ref, spec)
|
|
326
513
|
subscription_id = DSPy.events.subscribe(spec[:pattern], &callback)
|
|
514
|
+
subscription_id_ref[:id] = subscription_id
|
|
327
515
|
@module_subscription_ids << subscription_id
|
|
328
516
|
end
|
|
329
517
|
|
|
330
518
|
@module_subscriptions_registered = true
|
|
331
519
|
end
|
|
332
520
|
|
|
333
|
-
def build_subscription_callback(spec)
|
|
334
|
-
scope = spec[:scope] || DEFAULT_MODULE_SUBSCRIPTION_SCOPE
|
|
335
|
-
handler = spec[:handler]
|
|
336
|
-
block = spec[:block]
|
|
337
|
-
|
|
338
|
-
proc do |event_name, attributes|
|
|
339
|
-
next unless module_event_within_scope?(attributes, scope)
|
|
340
|
-
|
|
341
|
-
if handler
|
|
342
|
-
send(handler, event_name, attributes)
|
|
343
|
-
else
|
|
344
|
-
instance_exec(event_name, attributes, &block)
|
|
345
|
-
end
|
|
346
|
-
end
|
|
347
|
-
end
|
|
348
|
-
|
|
349
521
|
def module_event_within_scope?(attributes, scope)
|
|
350
522
|
metadata = extract_module_metadata(attributes)
|
|
351
523
|
return false unless metadata
|
data/lib/dspy/predict.rb
CHANGED
|
@@ -64,8 +64,7 @@ module DSPy
|
|
|
64
64
|
super()
|
|
65
65
|
@signature_class = signature_class
|
|
66
66
|
|
|
67
|
-
|
|
68
|
-
@prompt = Prompt.from_signature(signature_class)
|
|
67
|
+
@prompt = build_prompt_from_signature
|
|
69
68
|
@demos = nil
|
|
70
69
|
end
|
|
71
70
|
|
|
@@ -146,6 +145,13 @@ module DSPy
|
|
|
146
145
|
instance
|
|
147
146
|
end
|
|
148
147
|
|
|
148
|
+
sig { override.params(block: T.proc.params(config: T.untyped).void).returns(T.self_type) }
|
|
149
|
+
def configure(&block)
|
|
150
|
+
super(&block)
|
|
151
|
+
sync_prompt_formats_from_lm(config.lm) if config.lm
|
|
152
|
+
self
|
|
153
|
+
end
|
|
154
|
+
|
|
149
155
|
sig { override.returns(T::Array[[String, DSPy::Module]]) }
|
|
150
156
|
def named_predictors
|
|
151
157
|
[["self", self]]
|
|
@@ -166,9 +172,6 @@ module DSPy
|
|
|
166
172
|
input_props = @signature_class.input_struct_class.props
|
|
167
173
|
coerced_input_values = coerce_output_attributes(input_values, input_props)
|
|
168
174
|
|
|
169
|
-
# Store coerced input values for optimization
|
|
170
|
-
@last_input_values = coerced_input_values.clone
|
|
171
|
-
|
|
172
175
|
# Validate input with coerced values
|
|
173
176
|
validate_input_struct(coerced_input_values)
|
|
174
177
|
|
|
@@ -190,6 +193,30 @@ module DSPy
|
|
|
190
193
|
|
|
191
194
|
private
|
|
192
195
|
|
|
196
|
+
def reset_thread_state
|
|
197
|
+
super
|
|
198
|
+
end
|
|
199
|
+
|
|
200
|
+
def build_prompt_from_signature
|
|
201
|
+
lm_source = lm
|
|
202
|
+
schema_format = lm_source&.schema_format
|
|
203
|
+
data_format = lm_source&.respond_to?(:data_format) ? lm_source.data_format : nil
|
|
204
|
+
|
|
205
|
+
Prompt.from_signature(@signature_class, schema_format: schema_format, data_format: data_format)
|
|
206
|
+
end
|
|
207
|
+
|
|
208
|
+
def sync_prompt_formats_from_lm(lm_source)
|
|
209
|
+
return unless lm_source
|
|
210
|
+
|
|
211
|
+
schema_format = lm_source&.schema_format
|
|
212
|
+
data_format = lm_source&.respond_to?(:data_format) ? lm_source.data_format : nil
|
|
213
|
+
|
|
214
|
+
prompt = @prompt
|
|
215
|
+
prompt = prompt.with_schema_format(schema_format) if schema_format
|
|
216
|
+
prompt = prompt.with_data_format(data_format) if data_format
|
|
217
|
+
@prompt = prompt
|
|
218
|
+
end
|
|
219
|
+
|
|
193
220
|
# Validates input using signature struct (assumes input is already coerced)
|
|
194
221
|
sig { params(input_values: T::Hash[Symbol, T.untyped]).void }
|
|
195
222
|
def validate_input_struct(input_values)
|
|
@@ -277,7 +304,7 @@ module DSPy
|
|
|
277
304
|
next unless prop_type
|
|
278
305
|
|
|
279
306
|
# For nilable fields with nil values, ensure proper handling
|
|
280
|
-
if value.nil? &&
|
|
307
|
+
if value.nil? && nilable_type?(prop_type)
|
|
281
308
|
# For nilable fields, nil is valid - keep it as is
|
|
282
309
|
next
|
|
283
310
|
elsif value.nil? && prop_info[:fully_optional]
|
data/lib/dspy/prediction.rb
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# typed: strict
|
|
2
2
|
# frozen_string_literal: true
|
|
3
3
|
|
|
4
|
+
require_relative 'utils/serialization'
|
|
5
|
+
|
|
4
6
|
module DSPy
|
|
5
7
|
class Prediction
|
|
6
8
|
extend T::Sig
|
|
@@ -54,7 +56,14 @@ module DSPy
|
|
|
54
56
|
|
|
55
57
|
sig { returns(T::Hash[Symbol, T.untyped]) }
|
|
56
58
|
def to_h
|
|
57
|
-
@_struct.serialize
|
|
59
|
+
hash = DSPy::Utils::Serialization.deep_serialize(@_struct.serialize)
|
|
60
|
+
hash.delete('_prediction_marker')
|
|
61
|
+
hash
|
|
62
|
+
end
|
|
63
|
+
|
|
64
|
+
sig { params(args: T.untyped).returns(String) }
|
|
65
|
+
def to_json(*args)
|
|
66
|
+
to_h.to_json(*args)
|
|
58
67
|
end
|
|
59
68
|
|
|
60
69
|
private
|
|
@@ -122,9 +131,10 @@ module DSPy
|
|
|
122
131
|
converted[key] = nil
|
|
123
132
|
end
|
|
124
133
|
elsif is_enum_type?(prop_type) && value.is_a?(String)
|
|
125
|
-
# Convert string to enum
|
|
134
|
+
# Convert string to enum (case-insensitive for structured_outputs: false)
|
|
126
135
|
enum_class = extract_enum_class(prop_type)
|
|
127
|
-
|
|
136
|
+
result = DSPy::Mixins::TypeCoercion.deserialize_enum(enum_class, value)
|
|
137
|
+
converted[key] = result || value
|
|
128
138
|
elsif value.is_a?(Hash) && needs_struct_conversion?(prop_type)
|
|
129
139
|
# Regular struct field that needs conversion
|
|
130
140
|
converted[key] = convert_to_struct(value, prop_type)
|
|
@@ -188,60 +198,12 @@ module DSPy
|
|
|
188
198
|
|
|
189
199
|
sig { params(type: T.untyped).returns(T::Boolean) }
|
|
190
200
|
def is_enum_type?(type)
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
case type
|
|
194
|
-
when T::Types::Simple
|
|
195
|
-
# Handle regular enum types
|
|
196
|
-
begin
|
|
197
|
-
raw_type = type.raw_type
|
|
198
|
-
return false unless raw_type.is_a?(Class)
|
|
199
|
-
result = raw_type < T::Enum
|
|
200
|
-
return result == true # Force conversion to boolean
|
|
201
|
-
rescue StandardError
|
|
202
|
-
return false
|
|
203
|
-
end
|
|
204
|
-
when T::Private::Types::SimplePairUnion, T::Types::Union
|
|
205
|
-
# Handle T.nilable enum types
|
|
206
|
-
# Find the non-nil type and check if it's an enum
|
|
207
|
-
non_nil_types = if type.respond_to?(:types)
|
|
208
|
-
type.types.reject { |t| t.respond_to?(:raw_type) && t.raw_type == NilClass }
|
|
209
|
-
else
|
|
210
|
-
[]
|
|
211
|
-
end
|
|
212
|
-
|
|
213
|
-
# For nilable types, we expect exactly one non-nil type
|
|
214
|
-
return false unless non_nil_types.size == 1
|
|
215
|
-
|
|
216
|
-
non_nil_type = non_nil_types.first
|
|
217
|
-
return is_enum_type?(non_nil_type) # Recursively check
|
|
218
|
-
else
|
|
219
|
-
return false
|
|
220
|
-
end
|
|
201
|
+
DSPy::Mixins::TypeCoercion.enum_type?(type)
|
|
221
202
|
end
|
|
222
203
|
|
|
223
204
|
sig { params(type: T.untyped).returns(T.untyped) }
|
|
224
205
|
def extract_enum_class(type)
|
|
225
|
-
|
|
226
|
-
when T::Types::Simple
|
|
227
|
-
# Regular enum type
|
|
228
|
-
type.raw_type
|
|
229
|
-
when T::Private::Types::SimplePairUnion, T::Types::Union
|
|
230
|
-
# Nilable enum type - find the non-nil type
|
|
231
|
-
non_nil_types = if type.respond_to?(:types)
|
|
232
|
-
type.types.reject { |t| t.respond_to?(:raw_type) && t.raw_type == NilClass }
|
|
233
|
-
else
|
|
234
|
-
[]
|
|
235
|
-
end
|
|
236
|
-
|
|
237
|
-
if non_nil_types.size == 1
|
|
238
|
-
extract_enum_class(non_nil_types.first)
|
|
239
|
-
else
|
|
240
|
-
raise ArgumentError, "Unable to extract enum class from complex union type: #{type.inspect}"
|
|
241
|
-
end
|
|
242
|
-
else
|
|
243
|
-
raise ArgumentError, "Not an enum type: #{type.inspect}"
|
|
244
|
-
end
|
|
206
|
+
DSPy::Mixins::TypeCoercion.extract_enum_class(type)
|
|
245
207
|
end
|
|
246
208
|
|
|
247
209
|
sig { params(union_type: T::Types::Union, discriminator_type: T.untyped).returns(T::Hash[String, T.untyped]) }
|
|
@@ -387,8 +349,10 @@ module DSPy
|
|
|
387
349
|
if prop_info
|
|
388
350
|
prop_type = prop_info[:type_object] || prop_info[:type]
|
|
389
351
|
if v.is_a?(String) && is_enum_type?(prop_type)
|
|
390
|
-
# Convert string to enum
|
|
391
|
-
|
|
352
|
+
# Convert string to enum (case-insensitive for structured_outputs: false)
|
|
353
|
+
enum_class = extract_enum_class(prop_type)
|
|
354
|
+
result = DSPy::Mixins::TypeCoercion.deserialize_enum(enum_class, v)
|
|
355
|
+
converted_hash[k] = result || v
|
|
392
356
|
elsif v.is_a?(Hash) && needs_struct_conversion?(prop_type)
|
|
393
357
|
converted_hash[k] = convert_to_struct(v, prop_type)
|
|
394
358
|
elsif v.is_a?(Array) && needs_array_conversion?(prop_type)
|
|
@@ -488,8 +452,9 @@ module DSPy
|
|
|
488
452
|
convert_to_struct(element, element_type)
|
|
489
453
|
end
|
|
490
454
|
elsif element.is_a?(String) && is_enum_type?(element_type)
|
|
491
|
-
# Convert string to enum
|
|
492
|
-
element_type
|
|
455
|
+
# Convert string to enum (case-insensitive for structured_outputs: false)
|
|
456
|
+
enum_class = extract_enum_class(element_type)
|
|
457
|
+
DSPy::Mixins::TypeCoercion.deserialize_enum(enum_class, element) || element
|
|
493
458
|
else
|
|
494
459
|
element
|
|
495
460
|
end
|
|
@@ -539,7 +504,9 @@ module DSPy
|
|
|
539
504
|
if prop_info
|
|
540
505
|
prop_type = prop_info[:type_object] || prop_info[:type]
|
|
541
506
|
if v.is_a?(String) && is_enum_type?(prop_type)
|
|
542
|
-
|
|
507
|
+
enum_class = extract_enum_class(prop_type)
|
|
508
|
+
result = DSPy::Mixins::TypeCoercion.deserialize_enum(enum_class, v)
|
|
509
|
+
converted_hash[k] = result || v
|
|
543
510
|
elsif v.is_a?(Hash) && needs_struct_conversion?(prop_type)
|
|
544
511
|
converted_hash[k] = convert_to_struct(v, prop_type)
|
|
545
512
|
elsif v.is_a?(Array) && needs_array_conversion?(prop_type)
|