dspy 0.34.2 → 0.34.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +8 -16
- data/lib/dspy/chain_of_thought.rb +3 -2
- data/lib/dspy/context.rb +70 -21
- data/lib/dspy/evals/version.rb +1 -1
- data/lib/dspy/evals.rb +42 -31
- data/lib/dspy/events.rb +2 -3
- data/lib/dspy/example.rb +1 -1
- data/lib/dspy/lm/adapter.rb +39 -0
- data/lib/dspy/lm/json_strategy.rb +28 -67
- data/lib/dspy/lm/message.rb +1 -1
- data/lib/dspy/lm/response.rb +2 -2
- data/lib/dspy/lm/usage.rb +35 -10
- data/lib/dspy/lm.rb +22 -51
- data/lib/dspy/mixins/type_coercion.rb +256 -35
- data/lib/dspy/module.rb +203 -31
- data/lib/dspy/predict.rb +33 -6
- data/lib/dspy/prediction.rb +25 -58
- data/lib/dspy/prompt.rb +52 -76
- data/lib/dspy/propose/dataset_summary_generator.rb +1 -1
- data/lib/dspy/propose/grounded_proposer.rb +3 -3
- data/lib/dspy/re_act.rb +159 -196
- data/lib/dspy/registry/signature_registry.rb +3 -3
- data/lib/dspy/ruby_llm/lm/adapters/ruby_llm_adapter.rb +1 -27
- data/lib/dspy/schema/sorbet_json_schema.rb +7 -6
- data/lib/dspy/schema/version.rb +1 -1
- data/lib/dspy/schema_adapters.rb +1 -1
- data/lib/dspy/signature.rb +4 -5
- data/lib/dspy/storage/program_storage.rb +2 -2
- data/lib/dspy/structured_outputs_prompt.rb +4 -4
- data/lib/dspy/teleprompt/utils.rb +2 -2
- data/lib/dspy/tools/github_cli_toolset.rb +7 -7
- data/lib/dspy/tools/text_processing_toolset.rb +2 -2
- data/lib/dspy/tools/toolset.rb +1 -1
- data/lib/dspy/utils/serialization.rb +2 -6
- data/lib/dspy/version.rb +1 -1
- data/lib/dspy.rb +50 -5
- metadata +7 -26
- data/lib/dspy/events/subscriber_mixin.rb +0 -79
- data/lib/dspy/events/subscribers.rb +0 -43
- data/lib/dspy/memory/embedding_engine.rb +0 -68
- data/lib/dspy/memory/in_memory_store.rb +0 -216
- data/lib/dspy/memory/local_embedding_engine.rb +0 -244
- data/lib/dspy/memory/memory_compactor.rb +0 -298
- data/lib/dspy/memory/memory_manager.rb +0 -266
- data/lib/dspy/memory/memory_record.rb +0 -163
- data/lib/dspy/memory/memory_store.rb +0 -90
- data/lib/dspy/memory.rb +0 -30
- data/lib/dspy/tools/memory_toolset.rb +0 -117
data/lib/dspy/lm/usage.rb
CHANGED
|
@@ -45,11 +45,34 @@ module DSPy
|
|
|
45
45
|
end
|
|
46
46
|
end
|
|
47
47
|
|
|
48
|
+
# Token usage as reported by Anthropic, extended with the optional
# prompt-cache counters. The cache counters stay nil when the provider
# response did not include them.
class AnthropicUsage < T::Struct
  extend T::Sig

  const :input_tokens, Integer
  const :output_tokens, Integer
  const :total_tokens, Integer
  const :cache_creation_input_tokens, T.nilable(Integer), default: nil
  const :cache_read_input_tokens, T.nilable(Integer), default: nil

  # Hash view of the usage. The three core counters are always present;
  # each cache counter is included only when it is non-nil.
  sig { returns(Hash) }
  def to_h
    core = { input_tokens: input_tokens, output_tokens: output_tokens, total_tokens: total_tokens }
    cache = {
      cache_creation_input_tokens: cache_creation_input_tokens,
      cache_read_input_tokens: cache_read_input_tokens
    }.compact
    core.merge(cache)
  end
end
|
|
70
|
+
|
|
48
71
|
# Factory for creating appropriate usage objects
|
|
49
72
|
module UsageFactory
|
|
50
73
|
extend T::Sig
|
|
51
74
|
|
|
52
|
-
sig { params(provider: String, usage_data: T.untyped).returns(T.nilable(T.any(Usage, OpenAIUsage))) }
|
|
75
|
+
sig { params(provider: String, usage_data: T.untyped).returns(T.nilable(T.any(Usage, OpenAIUsage, AnthropicUsage))) }
|
|
53
76
|
def self.create(provider, usage_data)
|
|
54
77
|
return nil if usage_data.nil?
|
|
55
78
|
|
|
@@ -99,7 +122,7 @@ module DSPy
|
|
|
99
122
|
prompt_tokens_details: prompt_details,
|
|
100
123
|
completion_tokens_details: completion_details
|
|
101
124
|
)
|
|
102
|
-
rescue => e
|
|
125
|
+
rescue StandardError => e
|
|
103
126
|
DSPy.logger.debug("Failed to create OpenAI usage: #{e.message}")
|
|
104
127
|
nil
|
|
105
128
|
end
|
|
@@ -121,19 +144,21 @@ module DSPy
|
|
|
121
144
|
nil
|
|
122
145
|
end
|
|
123
146
|
|
|
124
|
-
sig { params(data: T::Hash[Symbol, T.untyped]).returns(T.nilable(
|
|
147
|
+
sig { params(data: T::Hash[Symbol, T.untyped]).returns(T.nilable(AnthropicUsage)) }
|
|
125
148
|
def self.create_anthropic_usage(data)
|
|
126
149
|
# Anthropic uses input_tokens/output_tokens
|
|
127
150
|
input_tokens = data[:input_tokens] || 0
|
|
128
151
|
output_tokens = data[:output_tokens] || 0
|
|
129
152
|
total_tokens = data[:total_tokens] || (input_tokens + output_tokens)
|
|
130
|
-
|
|
131
|
-
|
|
153
|
+
|
|
154
|
+
AnthropicUsage.new(
|
|
132
155
|
input_tokens: input_tokens,
|
|
133
156
|
output_tokens: output_tokens,
|
|
134
|
-
total_tokens: total_tokens
|
|
157
|
+
total_tokens: total_tokens,
|
|
158
|
+
cache_creation_input_tokens: data[:cache_creation_input_tokens],
|
|
159
|
+
cache_read_input_tokens: data[:cache_read_input_tokens]
|
|
135
160
|
)
|
|
136
|
-
rescue => e
|
|
161
|
+
rescue StandardError => e
|
|
137
162
|
DSPy.logger.debug("Failed to create Anthropic usage: #{e.message}")
|
|
138
163
|
nil
|
|
139
164
|
end
|
|
@@ -150,7 +175,7 @@ module DSPy
|
|
|
150
175
|
output_tokens: output_tokens,
|
|
151
176
|
total_tokens: total_tokens
|
|
152
177
|
)
|
|
153
|
-
rescue => e
|
|
178
|
+
rescue StandardError => e
|
|
154
179
|
DSPy.logger.debug("Failed to create Gemini usage: #{e.message}")
|
|
155
180
|
nil
|
|
156
181
|
end
|
|
@@ -167,10 +192,10 @@ module DSPy
|
|
|
167
192
|
output_tokens: output_tokens,
|
|
168
193
|
total_tokens: total_tokens
|
|
169
194
|
)
|
|
170
|
-
rescue => e
|
|
195
|
+
rescue StandardError => e
|
|
171
196
|
DSPy.logger.debug("Failed to create generic usage: #{e.message}")
|
|
172
197
|
nil
|
|
173
198
|
end
|
|
174
199
|
end
|
|
175
200
|
end
|
|
176
|
-
end
|
|
201
|
+
end
|
data/lib/dspy/lm.rb
CHANGED
|
@@ -146,7 +146,7 @@ module DSPy
|
|
|
146
146
|
|
|
147
147
|
# Determine if structured outputs will be used and wrap prompt if so
|
|
148
148
|
base_prompt = inference_module.prompt
|
|
149
|
-
prompt = if will_use_structured_outputs?(inference_module.signature_class)
|
|
149
|
+
prompt = if will_use_structured_outputs?(inference_module.signature_class, data_format: base_prompt.data_format)
|
|
150
150
|
StructuredOutputsPrompt.new(**base_prompt.to_h)
|
|
151
151
|
else
|
|
152
152
|
base_prompt
|
|
@@ -171,8 +171,9 @@ module DSPy
|
|
|
171
171
|
messages
|
|
172
172
|
end
|
|
173
173
|
|
|
174
|
-
def will_use_structured_outputs?(signature_class)
|
|
174
|
+
def will_use_structured_outputs?(signature_class, data_format: nil)
|
|
175
175
|
return false unless signature_class
|
|
176
|
+
return false if data_format == :toon
|
|
176
177
|
|
|
177
178
|
adapter_class_name = adapter.class.name
|
|
178
179
|
|
|
@@ -304,6 +305,12 @@ module DSPy
|
|
|
304
305
|
span.set_attribute('gen_ai.usage.prompt_tokens', usage.input_tokens) if usage.input_tokens
|
|
305
306
|
span.set_attribute('gen_ai.usage.completion_tokens', usage.output_tokens) if usage.output_tokens
|
|
306
307
|
span.set_attribute('gen_ai.usage.total_tokens', usage.total_tokens) if usage.total_tokens
|
|
308
|
+
if usage.respond_to?(:cache_creation_input_tokens) && !usage.cache_creation_input_tokens.nil?
|
|
309
|
+
span.set_attribute('gen_ai.usage.cache_creation_input_tokens', usage.cache_creation_input_tokens)
|
|
310
|
+
end
|
|
311
|
+
if usage.respond_to?(:cache_read_input_tokens) && !usage.cache_read_input_tokens.nil?
|
|
312
|
+
span.set_attribute('gen_ai.usage.cache_read_input_tokens', usage.cache_read_input_tokens)
|
|
313
|
+
end
|
|
307
314
|
end
|
|
308
315
|
end
|
|
309
316
|
|
|
@@ -327,8 +334,9 @@ module DSPy
|
|
|
327
334
|
})
|
|
328
335
|
|
|
329
336
|
# Add timing and request correlation if available
|
|
330
|
-
|
|
331
|
-
|
|
337
|
+
context = DSPy::Context.current
|
|
338
|
+
request_id = context[:request_id]
|
|
339
|
+
start_time = context[:request_start_time]
|
|
332
340
|
|
|
333
341
|
if request_id
|
|
334
342
|
event_attributes['request_id'] = request_id
|
|
@@ -354,11 +362,16 @@ module DSPy
|
|
|
354
362
|
|
|
355
363
|
# Handle Usage struct objects
|
|
356
364
|
if response.usage.respond_to?(:input_tokens)
|
|
357
|
-
|
|
365
|
+
result = {
|
|
358
366
|
input_tokens: response.usage.input_tokens,
|
|
359
367
|
output_tokens: response.usage.output_tokens,
|
|
360
368
|
total_tokens: response.usage.total_tokens
|
|
361
|
-
}
|
|
369
|
+
}
|
|
370
|
+
if response.usage.respond_to?(:cache_creation_input_tokens)
|
|
371
|
+
result[:cache_creation_input_tokens] = response.usage.cache_creation_input_tokens
|
|
372
|
+
result[:cache_read_input_tokens] = response.usage.cache_read_input_tokens
|
|
373
|
+
end
|
|
374
|
+
return result.compact
|
|
362
375
|
end
|
|
363
376
|
|
|
364
377
|
# Handle hash-based usage (for VCR compatibility)
|
|
@@ -384,63 +397,21 @@ module DSPy
|
|
|
384
397
|
end
|
|
385
398
|
end
|
|
386
399
|
|
|
387
|
-
public
|
|
388
|
-
|
|
389
|
-
def validate_messages!(messages)
|
|
390
|
-
unless messages.is_a?(Array)
|
|
391
|
-
raise ArgumentError, "messages must be an array"
|
|
392
|
-
end
|
|
393
|
-
|
|
394
|
-
messages.each_with_index do |message, index|
|
|
395
|
-
# Accept both Message objects and hash format for backward compatibility
|
|
396
|
-
if message.is_a?(Message)
|
|
397
|
-
# Already validated by type system
|
|
398
|
-
next
|
|
399
|
-
elsif message.is_a?(Hash) || message.respond_to?(:to_h)
|
|
400
|
-
data = message.is_a?(Hash) ? message : message.to_h
|
|
401
|
-
unless data.is_a?(Hash)
|
|
402
|
-
raise ArgumentError, "Message at index #{index} must be a Message object or hash with :role and :content"
|
|
403
|
-
end
|
|
404
|
-
|
|
405
|
-
normalized = data.transform_keys(&:to_sym)
|
|
406
|
-
unless normalized.key?(:role) && normalized.key?(:content)
|
|
407
|
-
raise ArgumentError, "Message at index #{index} must have :role and :content"
|
|
408
|
-
end
|
|
409
|
-
|
|
410
|
-
role = normalized[:role].to_s
|
|
411
|
-
valid_roles = %w[system user assistant]
|
|
412
|
-
unless valid_roles.include?(role)
|
|
413
|
-
raise ArgumentError, "Invalid role at index #{index}: #{normalized[:role]}. Must be one of: #{valid_roles.join(', ')}"
|
|
414
|
-
end
|
|
415
|
-
else
|
|
416
|
-
raise ArgumentError, "Message at index #{index} must be a Message object or hash with :role and :content"
|
|
417
|
-
end
|
|
418
|
-
end
|
|
419
|
-
end
|
|
420
|
-
|
|
421
400
|
def execute_raw_chat(messages, &streaming_block)
|
|
422
401
|
# Generate unique request ID for tracking
|
|
423
402
|
request_id = SecureRandom.hex(8)
|
|
424
403
|
start_time = Time.now
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
Thread.current[:dspy_request_id] = request_id
|
|
428
|
-
Thread.current[:dspy_request_start_time] = start_time
|
|
429
|
-
|
|
430
|
-
begin
|
|
404
|
+
|
|
405
|
+
DSPy::Context.with_request(request_id, start_time) do
|
|
431
406
|
response = instrument_lm_request(messages, 'RawPrompt') do
|
|
432
407
|
# Convert messages to hash format for adapter
|
|
433
408
|
hash_messages = messages_to_hash_array(messages)
|
|
434
409
|
# Direct adapter call, no strategies or JSON parsing
|
|
435
410
|
adapter.chat(messages: hash_messages, signature: nil, &streaming_block)
|
|
436
411
|
end
|
|
437
|
-
|
|
412
|
+
|
|
438
413
|
# Return raw response content, not parsed JSON
|
|
439
414
|
response.content
|
|
440
|
-
ensure
|
|
441
|
-
# Clean up thread-local storage
|
|
442
|
-
Thread.current[:dspy_request_id] = nil
|
|
443
|
-
Thread.current[:dspy_request_start_time] = nil
|
|
444
415
|
end
|
|
445
416
|
end
|
|
446
417
|
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
# frozen_string_literal: true
|
|
3
3
|
|
|
4
4
|
require 'sorbet-runtime'
|
|
5
|
+
require 'yaml'
|
|
5
6
|
|
|
6
7
|
module DSPy
|
|
7
8
|
module Mixins
|
|
@@ -9,6 +10,62 @@ module DSPy
|
|
|
9
10
|
module TypeCoercion
|
|
10
11
|
extend T::Sig
|
|
11
12
|
|
|
13
|
+
# Deserializes +value+ into an instance of +enum_class+, tolerating the
# casing variations LLMs tend to produce.
#
# Lookup order:
#   1. +value+ is already an instance of +enum_class+ -> returned unchanged.
#   2. Exact match through enum_class.try_deserialize(value.to_s).
#   3. Case-insensitive match against each value's serialized form, using a
#      lookup table built on first use and cached per enum class.
#
# Serialized forms are compared through to_s so that enums whose serialized
# values are not Strings (e.g. Integers) do not raise NoMethodError from
# #downcase while the fallback table is being built.
#
# Returns the enum instance on match, or nil if no match is found.
sig { params(enum_class: T.untyped, value: T.untyped).returns(T.nilable(T::Enum)) }
def self.deserialize_enum(enum_class, value)
  return value if value.is_a?(enum_class)

  str = value.to_s
  result = enum_class.try_deserialize(str)
  return result if result

  @ci_enum_cache ||= {}
  ci_map = (@ci_enum_cache[enum_class] ||=
    enum_class.values.each_with_object({}) { |v, h| h[v.serialize.to_s.downcase.freeze] = v }.freeze)

  ci_map[str.downcase]
end
|
|
32
|
+
|
|
33
|
+
# Module-level enum type detection; the instance-level #enum_type? delegates
# here. Recognizes:
#   - a T::Enum subclass,
#   - a T::Types::Simple whose raw type is a T::Enum subclass,
#   - a union whose only non-NilClass member is one of the above
#     (e.g. T.nilable(MyEnum)).
# Anything else, including a nil +type+, yields false; any StandardError
# raised while inspecting the type is also treated as "not an enum".
sig { params(type: T.untyped).returns(T::Boolean) }
def self.enum_type?(type)
  return false unless type

  case type
  when Class
    !!(type < T::Enum)
  when T::Types::Simple
    type.raw_type.is_a?(Class) && !!(type.raw_type < T::Enum)
  when T::Types::Union
    non_nil = type.types.reject { |t| t.is_a?(T::Types::Simple) && t.raw_type == NilClass }
    non_nil.size == 1 && enum_type?(non_nil.first)
  else
    false
  end
rescue StandardError
  false
end
|
|
52
|
+
|
|
53
|
+
# Module-level enum class extraction; the instance-level #extract_enum_class
# delegates here. Returns the T::Enum subclass carried by +prop_type+, which
# may be the class itself, a T::Types::Simple wrapping it, or a union with a
# single non-NilClass member (e.g. T.nilable(MyEnum)).
#
# Raises ArgumentError when +prop_type+ does not resolve to an enum class;
# callers are expected to check enum_type? first.
sig { params(prop_type: T.untyped).returns(T.class_of(T::Enum)) }
def self.extract_enum_class(prop_type)
  case prop_type
  when Class
    return prop_type if prop_type < T::Enum
  when T::Types::Simple
    return prop_type.raw_type if prop_type.raw_type.is_a?(Class) && prop_type.raw_type < T::Enum
  when T::Types::Union
    non_nil = prop_type.types.reject { |t| t.is_a?(T::Types::Simple) && t.raw_type == NilClass }
    return extract_enum_class(non_nil.first) if non_nil.size == 1
  end

  raise ArgumentError, "Not an enum type: #{prop_type.inspect}"
end
|
|
68
|
+
|
|
12
69
|
private
|
|
13
70
|
|
|
14
71
|
# Coerces output attributes to match their expected types
|
|
@@ -32,6 +89,15 @@ module DSPy
|
|
|
32
89
|
case prop_type
|
|
33
90
|
when ->(type) { union_type?(type) }
|
|
34
91
|
coerce_union_value(value, prop_type)
|
|
92
|
+
when ->(type) { nilable_type?(type) }
|
|
93
|
+
# Unwrap T.nilable(X) to coerce as X (nil already handled above)
|
|
94
|
+
non_nil_types = prop_type.types.reject { |t| t == T::Utils.coerce(NilClass) }
|
|
95
|
+
if non_nil_types.size == 1
|
|
96
|
+
coerce_value_to_type(value, non_nil_types.first)
|
|
97
|
+
else
|
|
98
|
+
# T.any(A, B, NilClass) — rebuild as T.any(A, B) and coerce as union
|
|
99
|
+
coerce_union_value(value, T::Types::Union.new(non_nil_types))
|
|
100
|
+
end
|
|
35
101
|
when ->(type) { array_type?(type) }
|
|
36
102
|
coerce_array_value(value, prop_type)
|
|
37
103
|
when ->(type) { hash_type?(type) }
|
|
@@ -57,32 +123,16 @@ module DSPy
|
|
|
57
123
|
end
|
|
58
124
|
end
|
|
59
125
|
|
|
60
|
-
# Checks if a type is an enum type
|
|
126
|
+
# Checks if a type is an enum type (handles Class, Simple, and nilable unions)
|
|
61
127
|
sig { params(type: T.untyped).returns(T::Boolean) }
|
|
62
128
|
def enum_type?(type)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
if type.is_a?(Class)
|
|
66
|
-
!!(type < T::Enum)
|
|
67
|
-
elsif type.is_a?(T::Types::Simple)
|
|
68
|
-
!!(type.raw_type < T::Enum)
|
|
69
|
-
else
|
|
70
|
-
false
|
|
71
|
-
end
|
|
72
|
-
rescue StandardError
|
|
73
|
-
false
|
|
129
|
+
DSPy::Mixins::TypeCoercion.enum_type?(type)
|
|
74
130
|
end
|
|
75
131
|
|
|
76
|
-
# Extracts the enum class from a type
|
|
132
|
+
# Extracts the enum class from a type (handles Class, Simple, and nilable unions)
|
|
77
133
|
sig { params(prop_type: T.untyped).returns(T.class_of(T::Enum)) }
|
|
78
134
|
def extract_enum_class(prop_type)
|
|
79
|
-
|
|
80
|
-
prop_type
|
|
81
|
-
elsif prop_type.is_a?(T::Types::Simple) && prop_type.raw_type < T::Enum
|
|
82
|
-
prop_type.raw_type
|
|
83
|
-
else
|
|
84
|
-
T.cast(prop_type, T.class_of(T::Enum))
|
|
85
|
-
end
|
|
135
|
+
DSPy::Mixins::TypeCoercion.extract_enum_class(prop_type)
|
|
86
136
|
end
|
|
87
137
|
|
|
88
138
|
# Checks if a type matches a simple type (like Float, Integer)
|
|
@@ -121,15 +171,131 @@ module DSPy
|
|
|
121
171
|
# Checks if a type is a union type (T.any)
|
|
122
172
|
sig { params(type: T.untyped).returns(T::Boolean) }
|
|
123
173
|
def union_type?(type)
|
|
124
|
-
type.is_a?(T::Types::Union) && !
|
|
174
|
+
type.is_a?(T::Types::Union) && !nilable_type?(type)
|
|
125
175
|
end
|
|
126
176
|
|
|
127
177
|
# Checks if a type is nilable (contains NilClass)
|
|
128
178
|
sig { params(type: T.untyped).returns(T::Boolean) }
|
|
129
|
-
def
|
|
179
|
+
def nilable_type?(type)
|
|
130
180
|
type.is_a?(T::Types::Union) && type.types.any? { |t| t == T::Utils.coerce(NilClass) }
|
|
131
181
|
end
|
|
132
182
|
|
|
183
|
+
# True only for the shape T.nilable(SomeStruct): a union of exactly two
# members where one is NilClass and the other satisfies struct_type?.
# Multi-member unions (which need a _type discriminator) return false.
sig { params(union_type: T.untyped).returns(T::Boolean) }
def nilable_struct_union?(union_type)
  return false unless union_type.is_a?(T::Types::Union)

  members = union_type.types
  return false unless members.size == 2

  nil_type = T::Utils.coerce(NilClass)
  includes_nil = members.include?(nil_type)
  includes_struct = members.any? { |member| member != nil_type && struct_type?(member) }

  includes_nil && includes_struct
end
|
|
198
|
+
|
|
199
|
+
# Whether +type_object+ is a primitive that needs no special serialization.
# A simple type qualifies when its raw type is (a subclass of) String,
# Integer, Float, Numeric, TrueClass or FalseClass; a union qualifies only
# when every member does (so T::Boolean is scalar, but a union containing
# NilClass is not).
sig { params(type_object: T.untyped).returns(T::Boolean) }
def scalar_type?(type_object)
  if type_object.is_a?(T::Types::Simple)
    raw = type_object.raw_type
    [String, Integer, Float, Numeric, TrueClass, FalseClass].any? do |primitive|
      raw == primitive || raw <= primitive
    end
  elsif type_object.is_a?(T::Types::Union)
    type_object.types.all? { |member| scalar_type?(member) }
  else
    false
  end
end
|
|
213
|
+
|
|
214
|
+
# Whether +type_object+ needs structure-preserving handling: typed arrays,
# typed hashes, T::Struct subclasses, or a union (e.g. T.nilable(...)) in
# which at least one non-NilClass member is itself structured.
sig { params(type_object: T.untyped).returns(T::Boolean) }
def structured_type?(type_object)
  case type_object
  when T::Types::TypedArray, T::Types::TypedHash
    true
  when T::Types::Simple
    raw = type_object.raw_type
    !!(raw.respond_to?(:<=) && raw <= T::Struct)
  when T::Types::Union
    type_object.types
               .reject { |member| member.is_a?(T::Types::Simple) && member.raw_type == NilClass }
               .any? { |member| structured_type?(member) }
  else
    false
  end
end
|
|
233
|
+
|
|
234
|
+
# Whether +type_object+ is String or a union that admits String directly
# (e.g. T.nilable(String) or T.any(String, Integer)). Only members that are
# T::Types::Simple with raw type exactly String count.
sig { params(type_object: T.untyped).returns(T::Boolean) }
def string_type?(type_object)
  plain_string = ->(candidate) { candidate.is_a?(T::Types::Simple) && candidate.raw_type == String }

  if type_object.is_a?(T::Types::Union)
    type_object.types.any?(&plain_string)
  else
    plain_string.call(type_object)
  end
end
|
|
247
|
+
|
|
248
|
+
# Renders a Sorbet type object as a readable string, e.g. "T::Array[String]".
# Typed arrays recurse into their element type; typed hashes render as the
# bare "T::Hash" (key/value types omitted); unions render as "T.any(...)"
# over their members; anything unrecognized falls back to #to_s.
sig { params(type_object: T.untyped).returns(String) }
def type_name(type_object)
  if type_object.is_a?(T::Types::TypedArray)
    element = type_name(type_object.type)
    return "T::Array[#{element}]"
  end
  return "T::Hash" if type_object.is_a?(T::Types::TypedHash)
  return type_object.raw_type.to_s if type_object.is_a?(T::Types::Simple)

  if type_object.is_a?(T::Types::Union)
    member_names = type_object.types.map { |member| type_name(member) }
    return "T.any(#{member_names.join(', ')})"
  end

  type_object.to_s
end
|
|
266
|
+
|
|
267
|
+
# Returns a placeholder value for a given Sorbet type.
# This is used when max iterations is reached without a successful
# completion, so each output field still receives a value of its type.
#
#   T::Array[...]              -> []
#   T::Hash[...]               -> {}
#   String / Integer / Float   -> '' / 0 / 0.0
#   TrueClass / FalseClass     -> false
#   union containing NilClass  -> nil
#   other unions               -> default of the first member that has a
#                                 non-nil default (e.g. T::Boolean -> false),
#                                 or nil if none does
#   anything else (T::Struct, enums, unknown types) -> nil
sig { params(type_object: T.untyped).returns(T.untyped) }
def default_value_for_type(type_object)
  case type_object
  when T::Types::TypedArray
    []
  when T::Types::TypedHash
    {}
  when T::Types::Simple
    case type_object.raw_type.to_s
    when 'String' then ''
    when 'Integer' then 0
    when 'Float' then 0.0
    when 'TrueClass', 'FalseClass' then false
    end
  when T::Types::Union
    members = type_object.types
    # A nilable union can always take nil.
    return nil if members.any? { |t| t.is_a?(T::Types::Simple) && t.raw_type == NilClass }

    # For a non-nilable union nil would violate the type, so fall back to the
    # first member that has a usable default (false counts as usable).
    members.each do |member|
      candidate = default_value_for_type(member)
      return candidate unless candidate.nil?
    end
    nil
  end
end
|
|
298
|
+
|
|
133
299
|
# Coerces an array value, converting each element as needed
|
|
134
300
|
sig { params(value: T.untyped, prop_type: T.untyped).returns(T.untyped) }
|
|
135
301
|
def coerce_array_value(value, prop_type)
|
|
@@ -143,16 +309,18 @@ module DSPy
|
|
|
143
309
|
# Coerces a hash value, converting keys and values as needed
|
|
144
310
|
sig { params(value: T.untyped, prop_type: T.untyped).returns(T.untyped) }
|
|
145
311
|
def coerce_hash_value(value, prop_type)
|
|
146
|
-
return value unless value.is_a?(Hash)
|
|
147
312
|
return value unless prop_type.is_a?(T::Types::TypedHash)
|
|
148
|
-
|
|
313
|
+
|
|
314
|
+
value = try_parse_string_to_hash(value)
|
|
315
|
+
return value unless value.is_a?(Hash)
|
|
316
|
+
|
|
149
317
|
key_type = prop_type.keys
|
|
150
318
|
value_type = prop_type.values
|
|
151
319
|
|
|
152
320
|
# Convert string keys to enum instances if key_type is an enum
|
|
153
321
|
result = if enum_type?(key_type)
|
|
154
322
|
enum_class = extract_enum_class(key_type)
|
|
155
|
-
value.transform_keys { |k|
|
|
323
|
+
value.transform_keys { |k| DSPy::Mixins::TypeCoercion.deserialize_enum(enum_class, k) || k }
|
|
156
324
|
else
|
|
157
325
|
# For non-enum keys, coerce them to the expected type
|
|
158
326
|
value.transform_keys { |k| coerce_value_to_type(k, key_type) }
|
|
@@ -162,9 +330,41 @@ module DSPy
|
|
|
162
330
|
result.transform_values { |v| coerce_value_to_type(v, value_type) }
|
|
163
331
|
end
|
|
164
332
|
|
|
333
|
+
# Attempts to parse a string into a Hash, trying JSON first and then YAML
# (LLMs sometimes emit "key: value" text for hash-typed fields).
# Returns the parsed Hash on success, or the original value otherwise:
# non-Strings, strings that parse to a non-Hash, and strings that fail to
# parse are all returned unchanged.
#
# YAML goes through YAML.safe_load. Permitted classes are given by name;
# Psych compares them via to_s, and naming 'Date' as a String avoids a
# NameError when the 'date' library has not been required yet.
# Every Psych::Exception is treated as "not a hash" instead of propagating.
# That includes syntax errors, disallowed tags such as "!ruby/object", and
# aliases.
sig { params(value: T.untyped).returns(T.untyped) }
def try_parse_string_to_hash(value)
  return value unless value.is_a?(String)

  parsed = begin
    JSON.parse(value)
  rescue JSON::ParserError
    YAML.safe_load(value, permitted_classes: %w[Symbol Date Time])
  end

  parsed.is_a?(Hash) ? parsed : value
rescue Psych::Exception
  # Psych::SyntaxError, Psych::DisallowedClass and Psych::BadAlias all
  # inherit from Psych::Exception.
  value
end
|
|
349
|
+
|
|
350
|
+
# Decodes +value+ as JSON when it is a String whose JSON content is an
# object. Returns the resulting Hash in that case; in every other case
# (non-String input, invalid JSON, or JSON that is not an object) the
# original value is handed back untouched.
sig { params(value: T.untyped).returns(T.untyped) }
def try_parse_json_to_hash(value)
  return value unless value.is_a?(String)

  begin
    decoded = JSON.parse(value)
  rescue JSON::ParserError
    return value
  end

  decoded.is_a?(Hash) ? decoded : value
end
|
|
361
|
+
|
|
165
362
|
# Coerces a struct value from a hash
|
|
166
363
|
sig { params(value: T.untyped, prop_type: T.untyped).returns(T.untyped) }
|
|
167
364
|
def coerce_struct_value(value, prop_type)
|
|
365
|
+
# Anthropic tool use may return struct fields as JSON strings
|
|
366
|
+
value = try_parse_json_to_hash(value)
|
|
367
|
+
|
|
168
368
|
return value unless value.is_a?(Hash)
|
|
169
369
|
|
|
170
370
|
struct_class = if prop_type.is_a?(Class)
|
|
@@ -197,7 +397,19 @@ module DSPy
|
|
|
197
397
|
[key, val]
|
|
198
398
|
end
|
|
199
399
|
end.to_h
|
|
200
|
-
|
|
400
|
+
|
|
401
|
+
# Strip nil values for non-nilable fields that have defaults.
|
|
402
|
+
# LLMs in advisory mode may return null for unused fields.
|
|
403
|
+
# Removing the key lets Sorbet use the field's default value.
|
|
404
|
+
coerced_hash.reject! do |key, val|
|
|
405
|
+
next false unless val.nil?
|
|
406
|
+
prop_info = struct_props[key]
|
|
407
|
+
next false unless prop_info
|
|
408
|
+
prop_type = prop_info[:type_object] || prop_info[:type]
|
|
409
|
+
has_default = prop_info.key?(:default) || prop_info[:fully_optional]
|
|
410
|
+
!nilable_type?(prop_type) && has_default
|
|
411
|
+
end
|
|
412
|
+
|
|
201
413
|
# Create the struct instance
|
|
202
414
|
struct_class.new(**coerced_hash)
|
|
203
415
|
rescue ArgumentError => e
|
|
@@ -209,9 +421,22 @@ module DSPy
|
|
|
209
421
|
# Coerces a union value by using _type discriminator
|
|
210
422
|
sig { params(value: T.untyped, union_type: T.untyped).returns(T.untyped) }
|
|
211
423
|
def coerce_union_value(value, union_type)
|
|
424
|
+
# Anthropic tool use may return complex oneOf union fields as JSON strings
|
|
425
|
+
# instead of nested objects. Parse them back into Hashes for coercion.
|
|
426
|
+
value = try_parse_json_to_hash(value)
|
|
427
|
+
|
|
212
428
|
return value unless value.is_a?(Hash)
|
|
213
429
|
|
|
214
|
-
#
|
|
430
|
+
# Handle nilable struct unions (T.nilable(SomeStruct)) without _type discriminator
|
|
431
|
+
# LLMs don't provide _type for simple nilable structs, so we can directly coerce
|
|
432
|
+
if nilable_struct_union?(union_type)
|
|
433
|
+
struct_type = union_type.types.find { |t|
|
|
434
|
+
t != T::Utils.coerce(NilClass) && struct_type?(t)
|
|
435
|
+
}
|
|
436
|
+
return coerce_struct_value(value, struct_type) if struct_type
|
|
437
|
+
end
|
|
438
|
+
|
|
439
|
+
# Check for _type discriminator field (required for true multi-type unions)
|
|
215
440
|
type_name = value[:_type] || value["_type"]
|
|
216
441
|
return value unless type_name
|
|
217
442
|
|
|
@@ -311,14 +536,10 @@ module DSPy
|
|
|
311
536
|
sig { params(value: T.untyped, prop_type: T.untyped).returns(T.untyped) }
|
|
312
537
|
def coerce_enum_value(value, prop_type)
|
|
313
538
|
enum_class = extract_enum_class(prop_type)
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
# Otherwise, try to deserialize from string
|
|
319
|
-
enum_class.deserialize(value.to_s)
|
|
320
|
-
rescue ArgumentError, KeyError => e
|
|
321
|
-
DSPy.logger.debug("Failed to coerce to enum #{enum_class}: #{e.message}")
|
|
539
|
+
result = DSPy::Mixins::TypeCoercion.deserialize_enum(enum_class, value)
|
|
540
|
+
return result if result
|
|
541
|
+
|
|
542
|
+
DSPy.logger.debug("Failed to coerce to enum #{enum_class}: no match for #{value.inspect}")
|
|
322
543
|
value
|
|
323
544
|
end
|
|
324
545
|
end
|