dspy 0.6.3 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +6 -1
- data/lib/dspy/lm/adapter.rb +8 -2
- data/lib/dspy/lm/adapter_factory.rb +7 -2
- data/lib/dspy/lm/adapters/anthropic_adapter.rb +2 -1
- data/lib/dspy/lm/adapters/openai/schema_converter.rb +269 -0
- data/lib/dspy/lm/adapters/openai_adapter.rb +30 -5
- data/lib/dspy/lm/cache_manager.rb +151 -0
- data/lib/dspy/lm/errors.rb +13 -0
- data/lib/dspy/lm/retry_handler.rb +119 -0
- data/lib/dspy/lm/strategies/anthropic_extraction_strategy.rb +78 -0
- data/lib/dspy/lm/strategies/base_strategy.rb +53 -0
- data/lib/dspy/lm/strategies/enhanced_prompting_strategy.rb +147 -0
- data/lib/dspy/lm/strategies/openai_structured_output_strategy.rb +60 -0
- data/lib/dspy/lm/strategy_selector.rb +79 -0
- data/lib/dspy/lm.rb +56 -18
- data/lib/dspy/predict.rb +20 -0
- data/lib/dspy/signature.rb +13 -5
- data/lib/dspy/version.rb +1 -1
- data/lib/dspy.rb +13 -0
- metadata +12 -4
@@ -0,0 +1,119 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "sorbet-runtime"
|
4
|
+
|
5
|
+
module DSPy
  class LM
    # Runs an LM call against a chain of JSON extraction strategies.
    #
    # The requested strategy is tried first; every other strategy reported as
    # available by StrategySelector follows in descending priority order. Each
    # strategy is retried a bounded number of times with exponential backoff
    # before the handler moves on to the next one.
    class RetryHandler
      extend T::Sig

      # Default retry budget for strategies without a specific override.
      MAX_RETRIES = 3
      BACKOFF_BASE = 0.5 # seconds

      sig { params(adapter: DSPy::LM::Adapter, signature_class: T.class_of(DSPy::Signature)).void }
      def initialize(adapter, signature_class)
        @adapter = adapter
        @signature_class = signature_class
        # Total number of block invocations across all strategies; only used for logging.
        @attempt = 0
      end

      # Execute a block with retry logic and progressive fallback.
      #
      # The block is yielded the strategy to use and its return value is
      # returned as soon as one invocation succeeds. When every strategy has
      # failed, the last error seen is re-raised.
      sig do
        type_parameters(:T)
          .params(
            initial_strategy: Strategies::BaseStrategy,
            block: T.proc.params(strategy: Strategies::BaseStrategy).returns(T.type_parameter(:T))
          )
          .returns(T.type_parameter(:T))
      end
      def with_retry(initial_strategy, &block)
        strategies = build_fallback_chain(initial_strategy)
        last_error = nil

        strategies.each do |strategy|
          retry_count = 0
          max_retries = max_retries_for_strategy(strategy)

          begin
            @attempt += 1
            DSPy.logger.debug("Attempting with strategy: #{strategy.name} (attempt #{@attempt})")

            result = yield(strategy)

            # Success! Reset attempt counter for next time
            @attempt = 0
            return result
          rescue StandardError => e
            # JSON::ParserError is a StandardError subclass, so it is covered here
            # without this file having to reference the JSON constant.
            last_error = e

            # Let strategy handle the error first; a handled error means
            # "do not retry me, fall back to the next strategy".
            if strategy.handle_error(e)
              DSPy.logger.info("Strategy #{strategy.name} handled error, will try next strategy")
              next # Try next strategy
            end

            # Try retrying with the same strategy
            if retry_count < max_retries
              retry_count += 1
              backoff_time = calculate_backoff(retry_count)

              DSPy.logger.warn(
                "Retrying #{strategy.name} after error (attempt #{retry_count}/#{max_retries}): #{e.message}"
              )

              sleep(backoff_time) if backoff_time > 0
              retry
            else
              DSPy.logger.info("Max retries reached for #{strategy.name}, trying next strategy")
              next # Try next strategy
            end
          end
        end

        # All strategies exhausted. Reset the counter here as well so a reused
        # handler does not report inflated attempt numbers on its next call.
        total_attempts = @attempt
        @attempt = 0
        DSPy.logger.error("All strategies exhausted after #{total_attempts} total attempts")
        raise last_error || StandardError.new("All JSON extraction strategies failed")
      end

      private

      # Build a chain of strategies to try in order: the requested strategy
      # first, then every other available strategy from highest to lowest priority.
      sig { params(initial_strategy: Strategies::BaseStrategy).returns(T::Array[Strategies::BaseStrategy]) }
      def build_fallback_chain(initial_strategy)
        selector = StrategySelector.new(@adapter, @signature_class)
        all_strategies = selector.available_strategies.sort_by(&:priority).reverse

        # Start with the requested strategy, then try others
        chain = [initial_strategy]
        chain.concat(all_strategies.reject { |s| s.name == initial_strategy.name })

        chain
      end

      # Different strategies get different retry counts
      sig { params(strategy: Strategies::BaseStrategy).returns(Integer) }
      def max_retries_for_strategy(strategy)
        case strategy.name
        when "openai_structured_output"
          1 # Structured outputs rarely benefit from retries
        when "anthropic_extraction"
          2 # Anthropic can be a bit more variable
        else
          MAX_RETRIES # Enhanced prompting might need more attempts
        end
      end

      # Calculate exponential backoff with jitter.
      #
      # The delay doubles with each attempt starting from BACKOFF_BASE, gets up
      # to 10% random jitter added, and is capped at 10 seconds. Returns 0.0 when
      # DSPy.config.test_mode is set so tests never sleep.
      sig { params(attempt: Integer).returns(Float) }
      def calculate_backoff(attempt)
        return 0.0 if DSPy.config.test_mode # No sleep in tests

        base_delay = BACKOFF_BASE * (2 ** (attempt - 1))
        jitter = rand * 0.1 * base_delay

        [base_delay + jitter, 10.0].min # Cap at 10 seconds
      end
    end
  end
end
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "base_strategy"
|
4
|
+
|
5
|
+
module DSPy
  class LM
    module Strategies
      # Strategy for using Anthropic's enhanced JSON extraction patterns.
      #
      # Only available when the adapter is a DSPy::LM::AnthropicAdapter. It does
      # not alter the request; it only post-processes the response text.
      class AnthropicExtractionStrategy < BaseStrategy
        extend T::Sig

        sig { override.returns(T::Boolean) }
        def available?
          adapter.is_a?(DSPy::LM::AnthropicAdapter)
        end

        sig { override.returns(Integer) }
        def priority
          90 # High priority - Anthropic's extraction is very reliable
        end

        sig { override.returns(String) }
        def name
          "anthropic_extraction"
        end

        sig { override.params(messages: T::Array[T::Hash[Symbol, String]], request_params: T::Hash[Symbol, T.untyped]).void }
        def prepare_request(messages, request_params)
          # Anthropic adapter already handles JSON optimization in prepare_messages_for_json
          # No additional preparation needed here
        end

        # Returns the JSON text found in the response, or nil when none is found.
        sig { override.params(response: DSPy::LM::Response).returns(T.nilable(String)) }
        def extract_json(response)
          # Use Anthropic's specialized extraction method if available.
          # NOTE(review): respond_to? without include_all only sees public methods;
          # confirm the adapter exposes extract_json_from_response publicly,
          # otherwise the fallback below is always used.
          if adapter.respond_to?(:extract_json_from_response)
            adapter.extract_json_from_response(response.content)
          else
            # Fallback to basic extraction
            extract_json_fallback(response.content)
          end
        end

        private

        # Best-effort extraction that tries four patterns in order. A pattern
        # that matches but yields nothing (e.g. a bare or unterminated fence)
        # falls through to the next pattern instead of raising NoMethodError.
        sig { params(content: T.nilable(String)).returns(T.nilable(String)) }
        def extract_json_fallback(content)
          return nil if content.nil?

          # Pattern 1: ```json blocks (the last such block wins)
          if content.include?('```json')
            fenced = content.split('```json').last&.split('```')&.first
            return fenced.strip if fenced
          end

          # Pattern 2: ## Output values header followed by a fenced block
          if content.include?('## Output values')
            json_part = content.split('## Output values').last
            fenced = json_part&.split('```')&.at(1)
            return fenced.strip if fenced
          end

          # Pattern 3: Generic code blocks whose body looks like JSON
          if content.include?('```')
            code_block = content.split('```')[1]
            if code_block && (code_block.strip.start_with?('{') || code_block.strip.start_with?('['))
              return code_block.strip
            end
          end

          # Pattern 4: Already valid-looking JSON (starts with an object or array)
          stripped = content.strip
          stripped if stripped.start_with?('{') || stripped.start_with?('[')
        end
      end
    end
  end
end
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "sorbet-runtime"
|
4
|
+
|
5
|
+
module DSPy
  class LM
    module Strategies
      # Abstract parent of every JSON extraction strategy.
      #
      # A strategy is bound to one adapter and one signature class. Subclasses
      # say whether they apply, how they rank, how they shape the request and
      # how they pull JSON out of the response.
      class BaseStrategy
        extend T::Sig
        extend T::Helpers
        abstract!

        sig { params(adapter: DSPy::LM::Adapter, signature_class: T.class_of(DSPy::Signature)).void }
        def initialize(adapter, signature_class)
          @adapter = adapter
          @signature_class = signature_class
        end

        # Whether this strategy can be used with the bound adapter/model.
        sig { abstract.returns(T::Boolean) }
        def available?
        end

        # Ranking among strategies; a larger number is preferred.
        sig { abstract.returns(Integer) }
        def priority
        end

        # Identifier used in logs and for lookup by name.
        sig { abstract.returns(String) }
        def name
        end

        # Hook to adjust the outgoing messages and/or request parameters.
        sig { abstract.params(messages: T::Array[T::Hash[Symbol, String]], request_params: T::Hash[Symbol, T.untyped]).void }
        def prepare_request(messages, request_params)
        end

        # Pull the JSON text out of a response; nil when nothing was found.
        sig { abstract.params(response: DSPy::LM::Response).returns(T.nilable(String)) }
        def extract_json(response)
        end

        # Strategy-specific error hook. The base implementation declines to
        # handle anything, so the error is left to the caller.
        sig { params(error: StandardError).returns(T::Boolean) }
        def handle_error(error)
          false
        end

        protected

        attr_reader :adapter, :signature_class
      end
    end
  end
end
|
@@ -0,0 +1,147 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "base_strategy"
|
4
|
+
|
5
|
+
module DSPy
  class LM
    module Strategies
      # Enhanced prompting strategy that works with any LLM.
      # Adds explicit JSON formatting instructions to improve reliability and
      # extracts JSON from free-form responses using several patterns.
      class EnhancedPromptingStrategy < BaseStrategy
        extend T::Sig

        sig { override.returns(T::Boolean) }
        def available?
          # This strategy is always available as a fallback
          true
        end

        sig { override.returns(Integer) }
        def priority
          50 # Medium priority - use when native methods aren't available
        end

        sig { override.returns(String) }
        def name
          "enhanced_prompting"
        end

        # Rewrites the last user message in place to include JSON formatting
        # instructions and prepends a system message when none is present.
        sig { override.params(messages: T::Array[T::Hash[Symbol, String]], request_params: T::Hash[Symbol, T.untyped]).void }
        def prepare_request(messages, request_params)
          # Enhance the user message with explicit JSON instructions
          return if messages.empty?

          # Get the output schema
          output_schema = signature_class.output_json_schema

          # Find the last user message
          last_user_idx = messages.rindex { |msg| msg[:role] == "user" }
          return unless last_user_idx

          # Add JSON formatting instructions
          original_content = messages[last_user_idx][:content]
          enhanced_content = enhance_prompt_with_json_instructions(original_content, output_schema)
          messages[last_user_idx][:content] = enhanced_content

          # Add system instructions if no system message exists
          if messages.none? { |msg| msg[:role] == "system" }
            messages.unshift({
              role: "system",
              content: "You are a helpful assistant that always responds with valid JSON when requested."
            })
          end
        end

        # Returns the first candidate that parses as JSON, or nil when none does.
        sig { override.params(response: DSPy::LM::Response).returns(T.nilable(String)) }
        def extract_json(response)
          return nil if response.content.nil?

          content = response.content.strip

          # Try multiple extraction patterns
          # 1. Check for markdown code blocks
          if content.include?('```json')
            # Safe navigation: a bare or empty ```json fence yields no segment,
            # in which case we fall through to the patterns below instead of raising.
            json_content = content.split('```json').last&.split('```')&.first&.strip
            return json_content if json_content && valid_json?(json_content)
          elsif content.include?('```')
            code_block = content.split('```')[1]
            if code_block
              json_content = code_block.strip
              return json_content if valid_json?(json_content)
            end
          end

          # 2. Check if the entire response is JSON
          return content if valid_json?(content)

          # 3. Look for JSON-like structures in the content (greedy: first
          #    opening brace/bracket through the last closing one)
          json_match = content.match(/\{[\s\S]*\}|\[[\s\S]*\]/)
          if json_match
            json_content = json_match[0]
            return json_content if valid_json?(json_content)
          end

          nil
        end

        private

        # Appends a JSON example built from the schema plus formatting rules to the prompt.
        sig { params(prompt: String, schema: T::Hash[Symbol, T.untyped]).returns(String) }
        def enhance_prompt_with_json_instructions(prompt, schema)
          json_example = generate_example_from_schema(schema)

          <<~ENHANCED
            #{prompt}

            IMPORTANT: You must respond with valid JSON that matches this structure:
            ```json
            #{JSON.pretty_generate(json_example)}
            ```

            Required fields: #{schema[:required]&.join(', ') || 'none'}

            Ensure your response:
            1. Is valid JSON (properly quoted strings, no trailing commas)
            2. Includes all required fields
            3. Uses the correct data types for each field
            4. Is wrapped in ```json``` markdown code blocks
          ENHANCED
        end

        # Builds a placeholder value per schema property, keyed by field name as a String.
        sig { params(schema: T::Hash[Symbol, T.untyped]).returns(T::Hash[String, T.untyped]) }
        def generate_example_from_schema(schema)
          return {} unless schema[:properties]

          example = {}
          schema[:properties].each do |field_name, field_schema|
            example[field_name.to_s] = case field_schema[:type]
                                       when "string"
                                         field_schema[:description] || "example string"
                                       when "integer"
                                         42
                                       when "number"
                                         3.14
                                       when "boolean"
                                         true
                                       when "array"
                                         ["example item"]
                                       when "object"
                                         { "nested" => "object" }
                                       else
                                         "example value"
                                       end
          end
          example
        end

        # True when the content parses as JSON.
        sig { params(content: String).returns(T::Boolean) }
        def valid_json?(content)
          JSON.parse(content)
          true
        rescue JSON::ParserError
          false
        end
      end
    end
  end
end
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "base_strategy"
|
4
|
+
|
5
|
+
module DSPy
  class LM
    module Strategies
      # Uses OpenAI's native structured output feature: the signature's schema
      # is sent as `response_format` and the response content is taken verbatim.
      class OpenAIStructuredOutputStrategy < BaseStrategy
        extend T::Sig

        # Requires an OpenAI adapter with structured outputs switched on and a
        # model that the schema converter reports as supporting them.
        sig { override.returns(T::Boolean) }
        def available?
          if adapter.is_a?(DSPy::LM::OpenAIAdapter) && adapter.instance_variable_get(:@structured_outputs_enabled)
            DSPy::LM::Adapters::OpenAI::SchemaConverter.supports_structured_outputs?(adapter.model)
          else
            false
          end
        end

        sig { override.returns(Integer) }
        def priority
          # Top of the ranking: native structured outputs are the most reliable.
          100
        end

        sig { override.returns(String) }
        def name
          "openai_structured_output"
        end

        # Attaches the structured output format for the signature to the request.
        sig { override.params(messages: T::Array[T::Hash[Symbol, String]], request_params: T::Hash[Symbol, T.untyped]).void }
        def prepare_request(messages, request_params)
          request_params[:response_format] =
            DSPy::LM::Adapters::OpenAI::SchemaConverter.to_openai_format(signature_class)
        end

        # The content is expected to be JSON already, so it is passed through untouched.
        sig { override.params(response: DSPy::LM::Response).returns(T.nilable(String)) }
        def extract_json(response)
          response.content
        end

        # Claims errors whose message mentions "response_format" or
        # "Invalid schema" (logging a warning) so that another strategy can
        # take over; declines everything else.
        sig { override.params(error: StandardError).returns(T::Boolean) }
        def handle_error(error)
          schema_problem = error.message.include?("response_format") || error.message.include?("Invalid schema")
          return false unless schema_problem

          DSPy.logger.warn("OpenAI structured output failed: #{error.message}")
          true
        end
      end
    end
  end
end
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "sorbet-runtime"
|
4
|
+
require_relative "strategies/base_strategy"
|
5
|
+
require_relative "strategies/openai_structured_output_strategy"
|
6
|
+
require_relative "strategies/anthropic_extraction_strategy"
|
7
|
+
require_relative "strategies/enhanced_prompting_strategy"
|
8
|
+
|
9
|
+
module DSPy
  class LM
    # Picks the JSON extraction strategy to use for an adapter/signature pair.
    class StrategySelector
      extend T::Sig

      # Strategy classes, listed in registration order.
      STRATEGIES = [
        Strategies::OpenAIStructuredOutputStrategy,
        Strategies::AnthropicExtractionStrategy,
        Strategies::EnhancedPromptingStrategy
      ].freeze

      sig { params(adapter: DSPy::LM::Adapter, signature_class: T.class_of(DSPy::Signature)).void }
      def initialize(adapter, signature_class)
        @adapter = adapter
        @signature_class = signature_class
        @strategies = build_strategies
      end

      # Returns the configured strategy when one is set and available;
      # otherwise the available strategy with the highest priority.
      # Raises when no strategy is available at all.
      sig { returns(Strategies::BaseStrategy) }
      def select
        # A strategy named in the configuration wins, provided it is usable.
        if DSPy.config.structured_outputs.strategy
          configured = find_strategy_by_name(DSPy.config.structured_outputs.strategy)
          return configured if configured&.available?

          DSPy.logger.warn("Requested strategy '#{DSPy.config.structured_outputs.strategy}' is not available")
        end

        usable = @strategies.select(&:available?)
        raise "No JSON extraction strategies available for #{@adapter.class}" if usable.empty?

        best = usable.max_by(&:priority)
        DSPy.logger.debug("Selected JSON extraction strategy: #{best.name}")
        best
      end

      # Every strategy that reports itself as available, in registration order.
      sig { returns(T::Array[Strategies::BaseStrategy]) }
      def available_strategies
        @strategies.select { |strategy| strategy.available? }
      end

      # Whether the strategy with the given name exists and is available.
      sig { params(strategy_name: String).returns(T::Boolean) }
      def strategy_available?(strategy_name)
        found = find_strategy_by_name(strategy_name)
        return false unless found

        found.available? || false
      end

      private

      # One instance of each registered strategy, bound to this adapter and signature.
      sig { returns(T::Array[Strategies::BaseStrategy]) }
      def build_strategies
        STRATEGIES.map do |strategy_class|
          strategy_class.new(@adapter, @signature_class)
        end
      end

      # First strategy instance whose name matches, or nil.
      sig { params(name: String).returns(T.nilable(Strategies::BaseStrategy)) }
      def find_strategy_by_name(name)
        @strategies.find { |candidate| candidate.name == name }
      end
    end
  end
end
|
data/lib/dspy/lm.rb
CHANGED
@@ -14,19 +14,23 @@ require_relative 'instrumentation/token_tracker'
|
|
14
14
|
require_relative 'lm/adapters/openai_adapter'
|
15
15
|
require_relative 'lm/adapters/anthropic_adapter'
|
16
16
|
|
17
|
+
# Load strategy system
|
18
|
+
require_relative 'lm/strategy_selector'
|
19
|
+
require_relative 'lm/retry_handler'
|
20
|
+
|
17
21
|
module DSPy
|
18
22
|
class LM
|
19
23
|
attr_reader :model_id, :api_key, :model, :provider, :adapter
|
20
24
|
|
21
|
-
def initialize(model_id, api_key: nil)
|
25
|
+
def initialize(model_id, api_key: nil, **options)
|
22
26
|
@model_id = model_id
|
23
27
|
@api_key = api_key
|
24
28
|
|
25
29
|
# Parse provider and model from model_id
|
26
30
|
@provider, @model = parse_model_id(model_id)
|
27
31
|
|
28
|
-
# Create appropriate adapter
|
29
|
-
@adapter = AdapterFactory.create(model_id, api_key: api_key)
|
32
|
+
# Create appropriate adapter with options
|
33
|
+
@adapter = AdapterFactory.create(model_id, api_key: api_key, **options)
|
30
34
|
end
|
31
35
|
|
32
36
|
def chat(inference_module, input_values, &block)
|
@@ -54,7 +58,7 @@ module DSPy
|
|
54
58
|
adapter_class: adapter.class.name,
|
55
59
|
input_size: input_size
|
56
60
|
}) do
|
57
|
-
|
61
|
+
chat_with_strategy(messages, signature_class, &block)
|
58
62
|
end
|
59
63
|
|
60
64
|
# Extract actual token usage from response (more accurate than estimation)
|
@@ -79,7 +83,7 @@ module DSPy
|
|
79
83
|
end
|
80
84
|
else
|
81
85
|
# Consolidated mode: execute without nested instrumentation
|
82
|
-
response =
|
86
|
+
response = chat_with_strategy(messages, signature_class, &block)
|
83
87
|
token_usage = Instrumentation::TokenTracker.extract_token_usage(response, provider)
|
84
88
|
parsed_result = parse_response(response, input_values, signature_class)
|
85
89
|
end
|
@@ -89,6 +93,53 @@ module DSPy
|
|
89
93
|
|
90
94
|
private
|
91
95
|
|
96
|
+
# Picks the extraction strategy for this call and runs the chat, either once
# or under the retry handler when retries are enabled and a signature is given.
def chat_with_strategy(messages, signature_class, &block)
  chosen_strategy = StrategySelector.new(adapter, signature_class).select
  retrying = DSPy.config.structured_outputs.retry_enabled && signature_class

  # Single attempt when retries are off (or there is no signature).
  unless retrying
    return execute_chat_with_strategy(messages, signature_class, chosen_strategy, &block)
  end

  # The retry handler decides which strategy each attempt runs with.
  RetryHandler.new(adapter, signature_class).with_retry(chosen_strategy) do |strategy|
    execute_chat_with_strategy(messages, signature_class, strategy, &block)
  end
end
|
113
|
+
|
114
|
+
# Runs one chat request using the given strategy and returns the response.
#
# The strategy may adjust the messages and add request parameters before the
# call. Afterwards, when a signature is given and the response has content, the
# strategy extracts JSON from it; if that differs from the raw content, a new
# Response carrying the extracted JSON (same usage and metadata) is returned.
def execute_chat_with_strategy(messages, signature_class, strategy, &block)
  # Work on a per-message copy: the strategy may edit message hashes in place
  # and add messages. Copying keeps the caller's messages untouched (so retries
  # start from the original prompt instead of stacking instructions), and
  # sending the prepared copy makes sure array-level changes reach the adapter.
  prepared_messages = messages.map(&:dup)

  # Prepare request with strategy-specific modifications
  request_params = {}
  strategy.prepare_request(prepared_messages, request_params)

  # Make the request; an empty request_params splat adds no keyword arguments
  response = adapter.chat(messages: prepared_messages, signature: signature_class, **request_params, &block)

  # Let strategy handle JSON extraction if needed
  if signature_class && response.content
    extracted_json = strategy.extract_json(response)
    if extracted_json && extracted_json != response.content
      # Create a new response with extracted JSON
      response = Response.new(
        content: extracted_json,
        usage: response.usage,
        metadata: response.metadata
      )
    end
  end

  response
end
|
142
|
+
|
92
143
|
# Determines if LM-level events should be emitted using smart consolidation
|
93
144
|
def should_emit_lm_events?
|
94
145
|
# Emit LM events only if we're not in a nested context (smart consolidation)
|
@@ -139,18 +190,6 @@ module DSPy
|
|
139
190
|
# Try to parse the response as JSON
|
140
191
|
content = response.content
|
141
192
|
|
142
|
-
# Let adapters handle their own extraction logic if available
|
143
|
-
if adapter && adapter.respond_to?(:extract_json_from_response, true)
|
144
|
-
content = adapter.send(:extract_json_from_response, content)
|
145
|
-
else
|
146
|
-
# Fallback: Extract JSON if it's in a code block (legacy behavior)
|
147
|
-
if content.include?('```json')
|
148
|
-
content = content.split('```json').last.split('```').first.strip
|
149
|
-
elsif content.include?('```')
|
150
|
-
content = content.split('```').last.split('```').first.strip
|
151
|
-
end
|
152
|
-
end
|
153
|
-
|
154
193
|
begin
|
155
194
|
json_payload = JSON.parse(content)
|
156
195
|
|
@@ -161,7 +200,6 @@ module DSPy
|
|
161
200
|
# Enhanced error message with debugging information
|
162
201
|
error_details = {
|
163
202
|
original_content: response.content,
|
164
|
-
extracted_content: content,
|
165
203
|
provider: provider,
|
166
204
|
model: model
|
167
205
|
}
|
data/lib/dspy/predict.rb
CHANGED
@@ -115,6 +115,9 @@ module DSPy
|
|
115
115
|
output_attributes = output_attributes.transform_keys(&:to_sym)
|
116
116
|
output_props = @signature_class.output_struct_class.props
|
117
117
|
|
118
|
+
# Apply defaults for missing fields
|
119
|
+
output_attributes = apply_defaults_to_output(output_attributes)
|
120
|
+
|
118
121
|
coerce_output_attributes(output_attributes, output_props)
|
119
122
|
end
|
120
123
|
|
@@ -143,5 +146,22 @@ module DSPy
|
|
143
146
|
output: output_props
|
144
147
|
})
|
145
148
|
end
|
149
|
+
|
150
|
+
# Applies default values to missing output fields
|
151
|
+
sig { params(output_attributes: T::Hash[Symbol, T.untyped]).returns(T::Hash[Symbol, T.untyped]) }
|
152
|
+
# Fills in declared default values for output fields the LM left out.
# The given hash is updated in place and returned; it is returned unchanged
# when the signature class exposes no output field descriptors.
def apply_defaults_to_output(output_attributes)
  return output_attributes unless @signature_class.respond_to?(:output_field_descriptors)

  @signature_class.output_field_descriptors.each_pair do |field, field_descriptor|
    # Values that are already present always win over defaults.
    next if output_attributes.key?(field)
    next unless field_descriptor.has_default

    output_attributes[field] = field_descriptor.default_value
  end

  output_attributes
end
|
146
166
|
end
|
147
167
|
end
|