ruby_llm 1.10.0 → 1.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +14 -2
- data/lib/ruby_llm/active_record/acts_as_legacy.rb +41 -7
- data/lib/ruby_llm/active_record/chat_methods.rb +41 -7
- data/lib/ruby_llm/agent.rb +323 -0
- data/lib/ruby_llm/aliases.json +50 -32
- data/lib/ruby_llm/chat.rb +27 -3
- data/lib/ruby_llm/configuration.rb +4 -0
- data/lib/ruby_llm/models.json +19806 -5991
- data/lib/ruby_llm/models.rb +35 -6
- data/lib/ruby_llm/provider.rb +13 -1
- data/lib/ruby_llm/providers/anthropic/media.rb +2 -2
- data/lib/ruby_llm/providers/azure/chat.rb +29 -0
- data/lib/ruby_llm/providers/azure/embeddings.rb +24 -0
- data/lib/ruby_llm/providers/azure/media.rb +45 -0
- data/lib/ruby_llm/providers/azure/models.rb +14 -0
- data/lib/ruby_llm/providers/azure.rb +56 -0
- data/lib/ruby_llm/providers/bedrock/auth.rb +122 -0
- data/lib/ruby_llm/providers/bedrock/chat.rb +297 -56
- data/lib/ruby_llm/providers/bedrock/media.rb +62 -33
- data/lib/ruby_llm/providers/bedrock/models.rb +88 -65
- data/lib/ruby_llm/providers/bedrock/streaming.rb +305 -8
- data/lib/ruby_llm/providers/bedrock.rb +61 -52
- data/lib/ruby_llm/providers/openai/media.rb +1 -1
- data/lib/ruby_llm/providers/xai/chat.rb +15 -0
- data/lib/ruby_llm/providers/xai/models.rb +75 -0
- data/lib/ruby_llm/providers/xai.rb +28 -0
- data/lib/ruby_llm/version.rb +1 -1
- data/lib/ruby_llm.rb +14 -8
- data/lib/tasks/models.rake +10 -4
- data/lib/tasks/vcr.rake +32 -0
- metadata +16 -13
- data/lib/ruby_llm/providers/bedrock/capabilities.rb +0 -167
- data/lib/ruby_llm/providers/bedrock/signing.rb +0 -831
- data/lib/ruby_llm/providers/bedrock/streaming/base.rb +0 -51
- data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +0 -128
- data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +0 -67
- data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +0 -85
- data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +0 -78
|
@@ -3,100 +3,123 @@
|
|
|
3
3
|
module RubyLLM
  module Providers
    class Bedrock
      # Models methods for AWS Bedrock.
      #
      # Converts ListFoundationModels summaries into Model::Info records and,
      # for models that can only be invoked through a cross-region inference
      # profile, rewrites the model ID to carry the geography prefix of the
      # configured region (e.g. "us.", "eu.", "apac.").
      module Models
        module_function

        # Geography prefixes used by Bedrock cross-region inference profile IDs.
        # "ap" stays in the list so IDs that already carry it are still
        # recognised (and rewritten) by #with_region_prefix.
        REGION_PREFIXES = %w[us-gov us eu apac ap jp au sa ca me af il].freeze

        # Matches a leading geography prefix followed by a dot, e.g. "us." or "apac.".
        REGION_PREFIX_PATTERN = /\A(?:#{REGION_PREFIXES.map { |prefix| Regexp.escape(prefix) }.join('|')})\./

        # Base URL of the Bedrock control-plane API (used for model listing).
        def models_api_base
          "https://bedrock.#{bedrock_region}.amazonaws.com"
        end

        # Path of the ListFoundationModels endpoint, relative to #models_api_base.
        def models_url
          '/foundation-models'
        end

        # Builds Model::Info records from a ListFoundationModels response.
        #
        # @param response [#body] response whose body holds 'modelSummaries'
        # @param slug [String] provider slug stored on every record
        # @param _capabilities unused; kept for interface compatibility
        # @return [Array<Model::Info>]
        def parse_list_models_response(response, slug, _capabilities)
          Array(response.body['modelSummaries']).map do |model_data|
            create_model_info(model_data, slug)
          end
        end

        # Maps one model summary Hash to a Model::Info.
        #
        # @param model_data [Hash] a single entry of 'modelSummaries'
        # @param slug [String] provider slug
        # @param _capabilities unused; kept for interface compatibility
        # @return [Model::Info]
        def create_model_info(model_data, slug, _capabilities = nil)
          model_id = model_id_with_region(model_data['modelId'], model_data)
          converse_data = model_data['converse'] || {}

          Model::Info.new(
            id: model_id,
            name: model_data['modelName'],
            provider: slug,
            family: model_data['modelFamily'] || model_data['providerName']&.downcase,
            created_at: nil,
            context_window: parse_context_window(model_data),
            # The default token limit is preferred over the hard maximum when both exist.
            max_output_tokens: converse_data['maxTokensDefault'] || converse_data['maxTokensMaximum'],
            modalities: {
              input: normalize_modalities(model_data['inputModalities']),
              output: normalize_modalities(model_data['outputModalities'])
            },
            capabilities: parse_capabilities(model_data),
            pricing: {},
            metadata: {
              provider_name: model_data['providerName'],
              model_arn: model_data['modelArn'],
              inference_types: model_data['inferenceTypesSupported'],
              converse: converse_data
            }
          )
        end

        # Returns the ID to invoke the model with, region-prefixed when the model
        # is only reachable through an inference profile.
        def model_id_with_region(model_id, model_data)
          inference_types = Array(model_data['inferenceTypesSupported'])
          normalize_inference_profile_id(model_id, inference_types, @config.bedrock_region)
        end

        # Leaves the ID untouched unless the model supports INFERENCE_PROFILE but
        # not ON_DEMAND; only then is the region's geography prefix applied.
        def normalize_inference_profile_id(model_id, inference_types, region)
          return model_id unless inference_types.include?('INFERENCE_PROFILE')
          return model_id if inference_types.include?('ON_DEMAND')

          with_region_prefix(model_id, region)
        end

        # Applies the geography prefix for +region+, replacing any existing prefix.
        def with_region_prefix(model_id, region)
          prefix = region_prefix(region)

          if region_prefixed?(model_id)
            model_id.sub(REGION_PREFIX_PATTERN, "#{prefix}.")
          else
            "#{prefix}.#{model_id}"
          end
        end

        # Maps an AWS region name to its inference-profile geography prefix.
        #
        # Asia-Pacific regions ("ap-*") use "apac" and GovCloud regions
        # ("us-gov-*") use "us-gov"; other regions use the first dash-separated
        # segment. A blank/nil region falls back to "us".
        def region_prefix(region)
          name = region.to_s
          return 'us-gov' if name.start_with?('us-gov-')

          family = name.split('-').first.to_s
          return 'us' if family.empty?

          family == 'ap' ? 'apac' : family
        end

        # True when +model_id+ already starts with a known geography prefix.
        def region_prefixed?(model_id)
          model_id.match?(REGION_PREFIX_PATTERN)
        end

        # Lower-cases modality names and maps Bedrock spellings to RubyLLM ones
        # ("embedding" -> "embeddings", "speech" -> "audio").
        def normalize_modalities(modalities)
          Array(modalities).map do |modality|
            normalized = modality.to_s.downcase
            case normalized
            when 'embedding' then 'embeddings'
            when 'speech' then 'audio'
            else normalized
            end
          end
        end

        # Derives capability names from a model summary.
        def parse_capabilities(model_data)
          capabilities = []
          capabilities << 'streaming' if model_data['responseStreamingSupported']

          converse = model_data['converse'] || {}
          # NOTE(review): `converse` defaults to {}, so every model is marked as
          # supporting function calling unless the field holds a non-Hash value —
          # confirm this default is intended.
          capabilities << 'function_calling' if converse.is_a?(Hash)

          reasoning = converse['reasoningSupported'] if converse.is_a?(Hash)
          capabilities << 'reasoning' if reasoning.is_a?(Hash) && reasoning['embedded']

          capabilities
        end

        # True when the model's stored converse metadata reports embedded reasoning.
        def reasoning_embedded?(model)
          metadata = RubyLLM::Utils.deep_symbolize_keys(model.metadata || {})
          converse = metadata[:converse]
          reasoning = converse[:reasoningSupported] if converse.is_a?(Hash)
          return false unless reasoning.is_a?(Hash)

          reasoning[:embedded] || false
        end

        # Parses description.maxContextWindow into a token count.
        #
        # Accepts Integers and Strings such as "8192", "200K" or "1M"
        # (case-insensitive, surrounding whitespace ignored). Returns nil for
        # anything else, including a missing or non-Hash description.
        def parse_context_window(model_data)
          description = model_data['description']
          return unless description.is_a?(Hash)

          value = description['maxContextWindow']
          return value if value.is_a?(Integer)
          return unless value.is_a?(String)

          match = value.strip.match(/\A(\d+)\s*([kKmM])?\z/)
          return unless match

          digits, unit = match.captures
          scale = case unit&.downcase
                  when 'k' then 1_000
                  when 'm' then 1_000_000
                  else 1
                  end
          digits.to_i * scale
        end
      end
    end
  end
end
|
|
@@ -1,17 +1,314 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
-
require_relative 'streaming/base'
|
|
4
|
-
require_relative 'streaming/content_extraction'
|
|
5
|
-
require_relative 'streaming/message_processing'
|
|
6
|
-
require_relative 'streaming/payload_processing'
|
|
7
|
-
require_relative 'streaming/prelude_handling'
|
|
8
|
-
|
|
9
3
|
module RubyLLM
  module Providers
    class Bedrock
      # Streaming implementation for Bedrock ConverseStream (AWS Event Stream).
      #
      # The response body is a sequence of binary event-stream frames. Each frame
      # is decoded with aws-eventstream, turned into a Chunk, yielded to the
      # caller and added to a StreamAccumulator that builds the final Message.
      module Streaming
        # Minimal response object handed to ErrorMiddleware.parse_error for errors
        # that arrive inside the stream rather than as an HTTP status.
        StreamErrorResponse = Struct.new(:body, :status)

        # HTTP-equivalent status for in-stream Bedrock exceptions; anything not
        # listed is reported as 500.
        STREAM_EXCEPTION_STATUSES = {
          'throttlingException' => 429,
          'validationException' => 400,
          'accessDeniedException' => 401,
          'unrecognizedClientException' => 401,
          'serviceUnavailableException' => 503
        }.freeze

        private

        # Path of the ConverseStream endpoint for the current model.
        def stream_url
          "/model/#{@model.id}/converse-stream"
        end

        # Performs the signed streaming POST and returns the accumulated message.
        #
        # @param connection [#post] connection used for the request
        # @param payload [Hash] provider payload, cleaned by #api_payload before sending
        # @param additional_headers [Hash] extra request headers
        # @yield [Chunk] each decoded chunk as it arrives
        # @return [Message] built by StreamAccumulator#to_message
        def stream_response(connection, payload, additional_headers = {}, &block)
          accumulator = StreamAccumulator.new
          decoder = event_stream_decoder
          body = JSON.generate(api_payload(payload))

          response = connection.post(stream_url, body) do |req|
            # Post the exact string that was signed. Handing the Hash to the JSON
            # request middleware would re-encode it, and any byte difference would
            # invalidate the signature computed over `body`.
            req.headers.merge!(sign_headers('POST', stream_url, body))
            req.headers.merge!(additional_headers) unless additional_headers.empty?
            req.headers['Accept'] = 'application/vnd.amazon.eventstream'
            attach_stream_handler(req, decoder, accumulator, &block)
          end

          message = accumulator.to_message(response)
          RubyLLM.logger.debug "Stream completed: #{message.content}"
          message
        end

        # Installs the on_data callback (its signature differs between Faraday 1 and 2).
        # Only the Faraday 2 callback receives env, so only it can route non-200
        # responses to #handle_failed_stream.
        def attach_stream_handler(req, decoder, accumulator, &block)
          if Faraday::VERSION.start_with?('1')
            req.options[:on_data] = proc do |chunk, _size|
              parse_stream_chunk(decoder, chunk, accumulator, &block)
            end
          else
            req.options.on_data = proc do |chunk, _bytes, env|
              if env&.status == 200
                parse_stream_chunk(decoder, chunk, accumulator, &block)
              else
                handle_failed_stream(chunk, env)
              end
            end
          end
        end

        # Builds an aws-eventstream decoder, loading the gem lazily.
        # @raise [Error] when the aws-eventstream gem is not installed
        def event_stream_decoder
          require 'aws-eventstream'
          Aws::EventStream::Decoder.new
        rescue LoadError
          raise Error,
                'The aws-eventstream gem is required for Bedrock streaming. ' \
                'Please add it to your Gemfile: gem "aws-eventstream"'
        end

        # Handles a chunk received with a non-200 status. A JSON chunk is passed to
        # ErrorMiddleware; anything else is only logged.
        def handle_failed_stream(chunk, env)
          data = JSON.parse(chunk)
          error_response = env.merge(body: data)
          ErrorMiddleware.parse_error(provider: self, response: error_response)
        rescue JSON::ParserError
          RubyLLM.logger.debug "Failed Bedrock stream error chunk: #{chunk}"
        end

        # Decodes one raw chunk and yields/accumulates every resulting Chunk.
        def parse_stream_chunk(decoder, raw_chunk, accumulator)
          handle_non_eventstream_error_chunk(raw_chunk)

          decode_events(decoder, raw_chunk).each do |event|
            chunk = build_chunk(event)
            next unless chunk

            accumulator.add(chunk)
            yield chunk
          end
        end

        # Detects plain-text errors that are not event-stream frames: an SSE-style
        # "event: error" block or a bare JSON object mentioning "error".
        def handle_non_eventstream_error_chunk(raw_chunk)
          text = raw_chunk.to_s

          if text.start_with?('event: error')
            payload = text.lines.find { |line| line.start_with?('data:') }&.delete_prefix('data:')&.strip
            raise_streaming_chunk_error(payload) if payload
            return
          end

          return unless text.lstrip.start_with?('{') && text.include?('"error"')

          raise_streaming_chunk_error(text)
        end

        # Reports a JSON error payload as a 500; unparseable payloads are ignored.
        def raise_streaming_chunk_error(payload)
          parsed = JSON.parse(payload)
          parsed = {} unless parsed.is_a?(Hash)
          raise_error_response(stream_error_message(parsed), 500)
        rescue JSON::ParserError
          nil
        end

        # Picks the most specific message from an error Hash. `error` may be either a
        # Hash (`{"error": {"message": ...}}`) or a plain String.
        def stream_error_message(parsed)
          error = parsed['error']
          nested = error['message'] if error.is_a?(Hash)
          nested || parsed['message'] || (error if error.is_a?(String)) || 'Bedrock streaming error'
        end

        # Builds a synthetic response and hands it to ErrorMiddleware.parse_error,
        # the same handler used for failed HTTP responses in #handle_failed_stream.
        def raise_error_response(message, status)
          response = StreamErrorResponse.new({ 'message' => message }, status)
          ErrorMiddleware.parse_error(provider: self, response: response)
        end

        # Drains every complete frame currently buffered in +decoder+.
        def decode_events(decoder, raw_chunk)
          events = []
          message, eof = decoder.decode_chunk(raw_chunk)

          while message
            event = decode_stream_message(message)
            RubyLLM.logger.debug("Bedrock stream event keys: #{event.keys}") if event && RubyLLM.config.log_stream_debug
            events << event if event
            break if eof

            message, eof = decoder.decode_chunk
          end

          events
        end

        # Turns one event-stream frame into an event Hash.
        #
        # Regular events are returned as their JSON payload. Exception frames carry
        # the exception name only in the `:exception-type` header (the payload is
        # just `{"message": ...}`), so they are wrapped as `{name => payload}` so
        # that #stream_error_event? recognises them. `error` frames are mapped to
        # `{'type' => 'error', ...}` for the same reason.
        def decode_stream_message(message)
          case header_value(message, ':message-type')
          when 'exception'
            name = header_value(message, ':exception-type').to_s
            name = 'unknownException' unless name.end_with?('Exception')
            { name => decode_event_payload(message.payload.read) || {} }
          when 'error'
            error_text = header_value(message, ':error-message') || header_value(message, ':error-code')
            { 'type' => 'error', 'error' => { 'message' => error_text } }
          else
            decode_event_payload(message.payload.read)
          end
        end

        # Reads a frame header, unwrapping aws-eventstream HeaderValue objects.
        def header_value(message, name)
          return unless message.respond_to?(:headers) && message.headers

          header = message.headers[name]
          header.respond_to?(:value) ? header.value : header
        end

        # Parses a frame payload. A base64 `bytes` envelope is unwrapped; only
        # Hash results are returned, everything else becomes nil.
        def decode_event_payload(payload)
          outer = JSON.parse(payload)
          return unless outer.is_a?(Hash)

          bytes = outer['bytes']
          event = bytes.is_a?(String) ? JSON.parse(bytes.unpack1('m')) : outer
          event.is_a?(Hash) ? event : nil
        rescue JSON::ParserError => e
          RubyLLM.logger.debug "Failed to decode Bedrock stream event payload: #{e.message}"
          nil
        end

        # Converts a decoded event into a Chunk; error events are raised instead.
        def build_chunk(event)
          raise_stream_error(event) if stream_error_event?(event)

          metadata_usage, usage, message_usage = event_usage(event)

          Chunk.new(
            role: :assistant,
            model_id: event['modelId'] || event.dig('message', 'model') || @model&.id,
            content: extract_content_delta(event),
            thinking: Thinking.build(
              text: extract_thinking_delta(event),
              signature: extract_thinking_signature(event)
            ),
            tool_calls: extract_tool_calls(event),
            input_tokens: extract_input_tokens(metadata_usage, usage, message_usage),
            output_tokens: extract_output_tokens(metadata_usage, usage),
            cached_tokens: extract_cached_tokens(metadata_usage, usage),
            cache_creation_tokens: extract_cache_creation_tokens(metadata_usage, usage),
            thinking_tokens: extract_reasoning_tokens(metadata_usage, usage)
          )
        end

        # Usage hashes from the three places they can appear: metadata.usage,
        # a top-level usage, and message.usage.
        def event_usage(event)
          [
            event.dig('metadata', 'usage') || {},
            event['usage'] || {},
            event.dig('message', 'usage') || {}
          ]
        end

        def extract_input_tokens(metadata_usage, usage, message_usage)
          metadata_usage['inputTokens'] || usage['inputTokens'] || message_usage['input_tokens']
        end

        def extract_output_tokens(metadata_usage, usage)
          metadata_usage['outputTokens'] || usage['outputTokens'] || usage['output_tokens']
        end

        def extract_cached_tokens(metadata_usage, usage)
          metadata_usage['cacheReadInputTokens'] || usage['cacheReadInputTokens'] || usage['cache_read_input_tokens']
        end

        def extract_cache_creation_tokens(metadata_usage, usage)
          metadata_usage['cacheWriteInputTokens'] || usage['cacheWriteInputTokens'] ||
            usage['cache_creation_input_tokens']
        end

        def extract_reasoning_tokens(metadata_usage, usage)
          metadata_usage['reasoningTokens'] || usage['reasoningTokens'] ||
            usage.dig('output_tokens_details', 'thinking_tokens')
        end

        # True for `{"type": "error"}` events or events keyed by an *Exception name.
        def stream_error_event?(event)
          event.keys.any? { |key| key.end_with?('Exception') } || event['type'] == 'error'
        end

        # Reports an in-stream error event through ErrorMiddleware with a status
        # derived from the exception name.
        def raise_stream_error(event)
          if event['type'] == 'error'
            raise_error_response(stream_error_message(event), 500)
            return
          end

          key = event.keys.find { |candidate| candidate.end_with?('Exception') }
          payload = event[key]
          message = (payload['message'] if payload.is_a?(Hash)) || key
          raise_error_response(message, STREAM_EXCEPTION_STATUSES.fetch(key, 500))
        end

        # Text delta: Converse `delta.text`, or an Anthropic-style `text_delta`.
        def extract_content_delta(event)
          delta = normalized_delta(event)
          return delta['text'] if delta['text']

          raw = raw_delta(event)
          raw['text'] if raw['type'] == 'text_delta'
        end

        # Thinking delta: Converse reasoningContent text, or an Anthropic-style `thinking_delta`.
        def extract_thinking_delta(event)
          text = reasoning_delta_field(event, 'text')
          return text if text

          raw = raw_delta(event)
          raw['thinking'] if raw['type'] == 'thinking_delta'
        end

        def extract_thinking_signature(event)
          extract_signature_from_delta(event) || extract_signature_from_start(event)
        end

        def extract_signature_from_delta(event)
          signature = reasoning_delta_field(event, 'signature')
          return signature if signature

          raw = raw_delta(event)
          raw['signature'] if raw['type'] == 'signature_delta'
        end

        def extract_signature_from_start(event)
          start = event.dig('contentBlockStart', 'start', 'reasoningContent')
          return nil unless start

          reasoning_text = start['reasoningText'] || {}
          return reasoning_text['signature'] if reasoning_text['signature']
          return start['redactedContent'] if start['redactedContent']

          nil
        end

        # Reads +field+ from a reasoningContent delta. ConverseStream sends it flat
        # (`reasoningContent.text` / `.signature`); the nested `reasoningText` shape
        # is checked first for compatibility.
        def reasoning_delta_field(event, field)
          reasoning = normalized_delta(event)['reasoningContent']
          return unless reasoning.is_a?(Hash)

          nested = reasoning['reasoningText']
          value = nested[field] if nested.is_a?(Hash)
          value || reasoning[field]
        end

        def extract_tool_calls(event)
          return extract_tool_call_start(event) if tool_call_start_event?(event)
          return extract_tool_call_delta(event) if tool_call_delta_event?(event)

          nil
        end

        def tool_call_start_event?(event)
          event['contentBlockStart'] || event['start'] || event.dig('content_block', 'tool_use')
        end

        def tool_call_delta_event?(event)
          delta = normalized_delta(event)
          event['contentBlockDelta'] || delta['toolUse'] || delta['tool_use'] || raw_delta(event)['partial_json']
        end

        # Starts a new tool call keyed by its ID.
        def extract_tool_call_start(event)
          tool_use = event.dig('contentBlockStart', 'start', 'toolUse')
          tool_use ||= event.dig('start', 'toolUse')
          tool_use ||= event.dig('content_block', 'tool_use') if event['type'] == 'content_block_start'
          return nil unless tool_use

          tool_use_id = tool_use['toolUseId'] || tool_use['id']
          tool_name = tool_use['name']
          tool_input = tool_use['input'] || {}

          {
            tool_use_id => ToolCall.new(
              id: tool_use_id,
              name: tool_name,
              arguments: tool_input
            )
          }
        end

        # Partial tool-call arguments, keyed by nil so the accumulator appends them
        # to the current call.
        def extract_tool_call_delta(event)
          input = normalized_delta(event).dig('toolUse', 'input')
          input ||= normalized_delta(event).dig('tool_use', 'input')
          raw = raw_delta(event)
          input ||= raw['partial_json'] if raw['type'] == 'input_json_delta'
          return nil unless input

          { nil => ToolCall.new(id: nil, name: nil, arguments: input) }
        end

        # The event's delta as a Hash: `contentBlockDelta.delta` or `delta`, with a
        # JSON-string delta parsed. Anything else yields {}.
        def normalized_delta(event)
          delta = event.dig('contentBlockDelta', 'delta') || event['delta'] || {}
          return delta if delta.is_a?(Hash)
          return {} unless delta.is_a?(String) && !delta.empty?

          parsed = JSON.parse(delta)
          parsed.is_a?(Hash) ? parsed : {}
        rescue JSON::ParserError
          {}
        end

        # The top-level `delta` when it is a Hash, else {}; avoids TypeError from
        # Hash#dig when `delta` is a String.
        def raw_delta(event)
          delta = event['delta']
          delta.is_a?(Hash) ? delta : {}
        end
      end
    end
  end
end
|
|
@@ -1,81 +1,90 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
-
require 'openssl'
|
|
4
|
-
require 'time'
|
|
5
|
-
|
|
6
3
|
module RubyLLM
  module Providers
    # AWS Bedrock Converse API integration.
    class Bedrock < Provider
      include Bedrock::Auth
      include Bedrock::Chat
      include Bedrock::Media
      include Bedrock::Models
      include Bedrock::Streaming

      # Base URL of the Bedrock runtime endpoint for the configured region.
      def api_base
        "https://bedrock-runtime.#{bedrock_region}.amazonaws.com"
      end

      # No static headers; requests are signed individually (see #sync_response).
      def headers
        {}
      end

      # Normalizes provider params (see #normalize_params) before delegating to
      # the generic Provider#complete.
      def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, thinking: nil, &) # rubocop:disable Metrics/ParameterLists
        normalized_params = normalize_params(params, model:)

        super(
          messages,
          tools: tools,
          temperature: temperature,
          model: model,
          params: normalized_params,
          headers: headers,
          schema: schema,
          thinking: thinking,
          &
        )
      end

      # Extracts a human-readable message from a Bedrock error response.
      #
      # Handles plain-text bodies, JSON objects (message under `message`,
      # `Message`, `error` or `__type`) and arrays of such objects; anything
      # unrecognised is delegated to Provider#parse_error.
      #
      # @return [String, nil]
      def parse_error(response)
        return if response.body.nil? || response.body.empty?

        body = try_parse_json(response.body)
        case body
        when String then body
        when Hash then hash_error_message(body) || super
        when Array
          messages = body.filter_map { |part| part['message'] if part.is_a?(Hash) }
          messages.empty? ? super : messages.join('. ')
        else super
        end
      end

      # Lists foundation models through the control-plane API.
      def list_models
        response = signed_get(models_api_base, models_url)
        parse_list_models_response(response, slug, capabilities)
      end

      class << self
        # Configuration keys this provider declares as required.
        def configuration_requirements
          %i[bedrock_api_key bedrock_secret_key bedrock_region]
        end
      end

      private

      def bedrock_region
        @config.bedrock_region
      end

      # Non-streaming request: a signed POST to the completion URL.
      def sync_response(connection, payload, additional_headers = {})
        signed_post(connection, completion_url, payload, additional_headers)
      end

      # First message-like field of an error Hash; a nested `error.message` is
      # unwrapped instead of returning the Hash itself.
      def hash_error_message(body)
        error = body['error']
        error = error['message'] if error.is_a?(Hash)
        body['message'] || body['Message'] || error || body['__type']
      end

      # Symbolizes params and moves `top_k` into additionalModelRequestFields for
      # models that support it; otherwise `top_k` is dropped.
      def normalize_params(params, model:)
        normalized = RubyLLM::Utils.deep_symbolize_keys(params || {})
        additional_fields = normalized[:additionalModelRequestFields] || {}

        top_k = normalized.delete(:top_k)
        if !top_k.nil? && model_supports_top_k?(model)
          additional_fields = RubyLLM::Utils.deep_merge(additional_fields, { top_k: top_k })
        end

        normalized[:additionalModelRequestFields] = additional_fields unless additional_fields.empty?
        normalized
      end

      # Models whose converse metadata reports embedded reasoning accept top_k.
      def model_supports_top_k?(model)
        Bedrock::Models.reasoning_embedded?(model)
      end

      # Deep copy of the payload with symbol keys and without the :tools entry.
      def api_payload(payload)
        cleaned = RubyLLM::Utils.deep_symbolize_keys(RubyLLM::Utils.deep_dup(payload))
        cleaned.delete(:tools)
        cleaned
      end
    end
  end
end
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Providers
    class XAI
      # Chat implementation for xAI
      # https://docs.x.ai/docs/api-reference#chat-completions
      module Chat
        # Returns the message role as a String (e.g. :system -> "system"),
        # without renaming it.
        def format_role(role) = role.to_s
      end
    end
  end
end
|