ruby_llm 1.10.0 → 1.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. checksums.yaml +4 -4
  2. data/README.md +14 -2
  3. data/lib/ruby_llm/active_record/acts_as_legacy.rb +41 -7
  4. data/lib/ruby_llm/active_record/chat_methods.rb +41 -7
  5. data/lib/ruby_llm/agent.rb +323 -0
  6. data/lib/ruby_llm/aliases.json +50 -32
  7. data/lib/ruby_llm/chat.rb +27 -3
  8. data/lib/ruby_llm/configuration.rb +4 -0
  9. data/lib/ruby_llm/models.json +19806 -5991
  10. data/lib/ruby_llm/models.rb +35 -6
  11. data/lib/ruby_llm/provider.rb +13 -1
  12. data/lib/ruby_llm/providers/anthropic/media.rb +2 -2
  13. data/lib/ruby_llm/providers/azure/chat.rb +29 -0
  14. data/lib/ruby_llm/providers/azure/embeddings.rb +24 -0
  15. data/lib/ruby_llm/providers/azure/media.rb +45 -0
  16. data/lib/ruby_llm/providers/azure/models.rb +14 -0
  17. data/lib/ruby_llm/providers/azure.rb +56 -0
  18. data/lib/ruby_llm/providers/bedrock/auth.rb +122 -0
  19. data/lib/ruby_llm/providers/bedrock/chat.rb +297 -56
  20. data/lib/ruby_llm/providers/bedrock/media.rb +62 -33
  21. data/lib/ruby_llm/providers/bedrock/models.rb +88 -65
  22. data/lib/ruby_llm/providers/bedrock/streaming.rb +305 -8
  23. data/lib/ruby_llm/providers/bedrock.rb +61 -52
  24. data/lib/ruby_llm/providers/openai/media.rb +1 -1
  25. data/lib/ruby_llm/providers/xai/chat.rb +15 -0
  26. data/lib/ruby_llm/providers/xai/models.rb +75 -0
  27. data/lib/ruby_llm/providers/xai.rb +28 -0
  28. data/lib/ruby_llm/version.rb +1 -1
  29. data/lib/ruby_llm.rb +14 -8
  30. data/lib/tasks/models.rake +10 -4
  31. data/lib/tasks/vcr.rake +32 -0
  32. metadata +16 -13
  33. data/lib/ruby_llm/providers/bedrock/capabilities.rb +0 -167
  34. data/lib/ruby_llm/providers/bedrock/signing.rb +0 -831
  35. data/lib/ruby_llm/providers/bedrock/streaming/base.rb +0 -51
  36. data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +0 -128
  37. data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +0 -67
  38. data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +0 -85
  39. data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +0 -78
@@ -3,100 +3,123 @@
3
3
  module RubyLLM
4
4
  module Providers
5
5
  class Bedrock
6
- # Models methods for the AWS Bedrock API implementation
6
+ # Models methods for AWS Bedrock.
7
7
  module Models
8
- def list_models
9
- mgmt_api_base = "https://bedrock.#{@config.bedrock_region}.amazonaws.com"
10
- full_models_url = "#{mgmt_api_base}/#{models_url}"
11
- signature = sign_request(full_models_url, method: :get)
12
- response = @connection.get(full_models_url) do |req|
13
- req.headers.merge! signature.headers
14
- end
8
+ module_function
15
9
 
16
- parse_list_models_response(response, slug, capabilities)
17
- end
10
+ REGION_PREFIXES = %w[us eu ap sa ca me af il].freeze
18
11
 
19
- module_function
12
+ def models_api_base
13
+ "https://bedrock.#{bedrock_region}.amazonaws.com"
14
+ end
20
15
 
21
16
  def models_url
22
- 'foundation-models'
17
+ '/foundation-models'
23
18
  end
24
19
 
25
- def parse_list_models_response(response, slug, capabilities)
26
- models = Array(response.body['modelSummaries'])
27
-
28
- models.select { |m| m['modelId'].include?('claude') }.map do |model_data|
29
- model_id = model_data['modelId']
30
-
31
- Model::Info.new(
32
- id: model_id_with_region(model_id, model_data),
33
- name: model_data['modelName'] || capabilities.format_display_name(model_id),
34
- provider: slug,
35
- family: capabilities.model_family(model_id),
36
- created_at: nil,
37
- context_window: capabilities.context_window_for(model_id),
38
- max_output_tokens: capabilities.max_tokens_for(model_id),
39
- modalities: capabilities.modalities_for(model_id),
40
- capabilities: capabilities.capabilities_for(model_id),
41
- pricing: capabilities.pricing_for(model_id),
42
- metadata: {
43
- provider_name: model_data['providerName'],
44
- inference_types: model_data['inferenceTypesSupported'] || [],
45
- streaming_supported: model_data['responseStreamingSupported'] || false,
46
- input_modalities: model_data['inputModalities'] || [],
47
- output_modalities: model_data['outputModalities'] || []
48
- }
49
- )
20
+ def parse_list_models_response(response, slug, _capabilities)
21
+ Array(response.body['modelSummaries']).map do |model_data|
22
+ create_model_info(model_data, slug)
50
23
  end
51
24
  end
52
25
 
53
- def create_model_info(model_data, slug, _capabilities)
54
- model_id = model_data['modelId']
26
+ def create_model_info(model_data, slug, _capabilities = nil)
27
+ model_id = model_id_with_region(model_data['modelId'], model_data)
28
+ converse_data = model_data['converse'] || {}
55
29
 
56
30
  Model::Info.new(
57
- id: model_id_with_region(model_id, model_data),
58
- name: model_data['modelName'] || model_id,
31
+ id: model_id,
32
+ name: model_data['modelName'],
59
33
  provider: slug,
60
- family: 'claude',
34
+ family: model_data['modelFamily'] || model_data['providerName']&.downcase,
61
35
  created_at: nil,
62
- context_window: 200_000,
63
- max_output_tokens: 4096,
64
- modalities: { input: ['text'], output: ['text'] },
65
- capabilities: [],
36
+ context_window: parse_context_window(model_data),
37
+ max_output_tokens: converse_data['maxTokensDefault'] || converse_data['maxTokensMaximum'],
38
+ modalities: {
39
+ input: normalize_modalities(model_data['inputModalities']),
40
+ output: normalize_modalities(model_data['outputModalities'])
41
+ },
42
+ capabilities: parse_capabilities(model_data),
66
43
  pricing: {},
67
- metadata: {}
44
+ metadata: {
45
+ provider_name: model_data['providerName'],
46
+ model_arn: model_data['modelArn'],
47
+ inference_types: model_data['inferenceTypesSupported'],
48
+ converse: converse_data
49
+ }
68
50
  )
69
51
  end
70
52
 
71
53
  def model_id_with_region(model_id, model_data)
72
- normalize_inference_profile_id(
73
- model_id,
74
- model_data['inferenceTypesSupported'],
75
- @config.bedrock_region
76
- )
54
+ inference_types = Array(model_data['inferenceTypesSupported'])
55
+ normalize_inference_profile_id(model_id, inference_types, @config.bedrock_region)
77
56
  end
78
57
 
79
- def region_prefix(region)
80
- region = region.to_s
81
- return 'us' if region.empty?
58
+ def normalize_inference_profile_id(model_id, inference_types, region)
59
+ return model_id unless inference_types.include?('INFERENCE_PROFILE')
60
+ return model_id if inference_types.include?('ON_DEMAND')
82
61
 
83
- region[0, 2]
62
+ with_region_prefix(model_id, region)
84
63
  end
85
64
 
86
65
  def with_region_prefix(model_id, region)
87
- desired_prefix = region_prefix(region)
88
- return model_id if model_id.start_with?("#{desired_prefix}.")
66
+ prefix = region_prefix(region)
89
67
 
90
- clean_model_id = model_id.sub(/^[a-z]{2}\./, '')
91
- "#{desired_prefix}.#{clean_model_id}"
68
+ if region_prefixed?(model_id)
69
+ model_id.sub(/\A(?:#{REGION_PREFIXES.join('|')})\./, "#{prefix}.")
70
+ else
71
+ "#{prefix}.#{model_id}"
72
+ end
92
73
  end
93
74
 
94
- def normalize_inference_profile_id(model_id, inference_types, region)
95
- types = Array(inference_types)
96
- return model_id unless types.include?('INFERENCE_PROFILE')
97
- return model_id if types.include?('ON_DEMAND')
75
+ def region_prefix(region)
76
+ prefix = region.to_s.split('-').first
77
+ prefix = '' if prefix.nil?
78
+ prefix.empty? ? 'us' : prefix
79
+ end
98
80
 
99
- with_region_prefix(model_id, region)
81
+ def region_prefixed?(model_id)
82
+ model_id.match?(/\A(?:#{REGION_PREFIXES.join('|')})\./)
83
+ end
84
+
85
+ def normalize_modalities(modalities)
86
+ Array(modalities).map do |modality|
87
+ normalized = modality.to_s.downcase
88
+ case normalized
89
+ when 'embedding' then 'embeddings'
90
+ when 'speech' then 'audio'
91
+ else normalized
92
+ end
93
+ end
94
+ end
95
+
96
+ def parse_capabilities(model_data)
97
+ capabilities = []
98
+ capabilities << 'streaming' if model_data['responseStreamingSupported']
99
+
100
+ converse = model_data['converse'] || {}
101
+ capabilities << 'function_calling' if converse.is_a?(Hash)
102
+ capabilities << 'reasoning' if converse.dig('reasoningSupported', 'embedded')
103
+
104
+ capabilities
105
+ end
106
+
107
+ def reasoning_embedded?(model)
108
+ metadata = RubyLLM::Utils.deep_symbolize_keys(model.metadata || {})
109
+ converse = metadata[:converse] || {}
110
+ reasoning_supported = converse[:reasoningSupported] || {}
111
+ reasoning_supported[:embedded] || false
112
+ end
113
+
114
+ def parse_context_window(model_data)
115
+ value = model_data.dig('description', 'maxContextWindow')
116
+ return unless value.is_a?(String)
117
+
118
+ if value.match?(/\A\d+[kK]\z/)
119
+ value.to_i * 1000
120
+ elsif value.match?(/\A\d+\z/)
121
+ value.to_i
122
+ end
100
123
  end
101
124
  end
102
125
  end
@@ -1,17 +1,314 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require_relative 'streaming/base'
4
- require_relative 'streaming/content_extraction'
5
- require_relative 'streaming/message_processing'
6
- require_relative 'streaming/payload_processing'
7
- require_relative 'streaming/prelude_handling'
8
-
9
3
  module RubyLLM
10
4
  module Providers
11
5
  class Bedrock
12
- # Streaming implementation for the AWS Bedrock API.
6
+ # Streaming implementation for Bedrock ConverseStream (AWS Event Stream).
13
7
  module Streaming
14
- include Base
8
+ private
9
+
10
+ def stream_url
11
+ "/model/#{@model.id}/converse-stream"
12
+ end
13
+
14
+ def stream_response(connection, payload, additional_headers = {}, &block)
15
+ accumulator = StreamAccumulator.new
16
+ decoder = event_stream_decoder
17
+ request_payload = api_payload(payload)
18
+ body = JSON.generate(request_payload)
19
+
20
+ response = connection.post(stream_url, request_payload) do |req|
21
+ req.headers.merge!(sign_headers('POST', stream_url, body))
22
+ req.headers.merge!(additional_headers) unless additional_headers.empty?
23
+ req.headers['Accept'] = 'application/vnd.amazon.eventstream'
24
+
25
+ if Faraday::VERSION.start_with?('1')
26
+ req.options[:on_data] = proc do |chunk, _size|
27
+ parse_stream_chunk(decoder, chunk, accumulator, &block)
28
+ end
29
+ else
30
+ req.options.on_data = proc do |chunk, _bytes, env|
31
+ if env&.status == 200
32
+ parse_stream_chunk(decoder, chunk, accumulator, &block)
33
+ else
34
+ handle_failed_stream(chunk, env)
35
+ end
36
+ end
37
+ end
38
+ end
39
+
40
+ message = accumulator.to_message(response)
41
+ RubyLLM.logger.debug "Stream completed: #{message.content}"
42
+ message
43
+ end
44
+
45
+ def event_stream_decoder
46
+ require 'aws-eventstream'
47
+ Aws::EventStream::Decoder.new
48
+ rescue LoadError
49
+ raise Error,
50
+ 'The aws-eventstream gem is required for Bedrock streaming. ' \
51
+ 'Please add it to your Gemfile: gem "aws-eventstream"'
52
+ end
53
+
54
+ def handle_failed_stream(chunk, env)
55
+ data = JSON.parse(chunk)
56
+ error_response = env.merge(body: data)
57
+ ErrorMiddleware.parse_error(provider: self, response: error_response)
58
+ rescue JSON::ParserError
59
+ RubyLLM.logger.debug "Failed Bedrock stream error chunk: #{chunk}"
60
+ end
61
+
62
+ def parse_stream_chunk(decoder, raw_chunk, accumulator)
63
+ handle_non_eventstream_error_chunk(raw_chunk)
64
+
65
+ decode_events(decoder, raw_chunk).each do |event|
66
+ chunk = build_chunk(event)
67
+ next unless chunk
68
+
69
+ accumulator.add(chunk)
70
+ yield chunk
71
+ end
72
+ end
73
+
74
+ def handle_non_eventstream_error_chunk(raw_chunk)
75
+ text = raw_chunk.to_s
76
+
77
+ if text.start_with?('event: error')
78
+ payload = text.lines.find { |line| line.start_with?('data:') }&.delete_prefix('data:')&.strip
79
+ raise_streaming_chunk_error(payload) if payload
80
+ return
81
+ end
82
+
83
+ return unless text.lstrip.start_with?('{') && text.include?('"error"')
84
+
85
+ raise_streaming_chunk_error(text)
86
+ end
87
+
88
+ def raise_streaming_chunk_error(payload)
89
+ parsed = JSON.parse(payload)
90
+ message = parsed.dig('error', 'message') || parsed['message'] || 'Bedrock streaming error'
91
+ response = Struct.new(:body, :status).new({ 'message' => message }, 500)
92
+ ErrorMiddleware.parse_error(provider: self, response: response)
93
+ rescue JSON::ParserError
94
+ nil
95
+ end
96
+
97
+ def decode_events(decoder, raw_chunk)
98
+ events = []
99
+ message, eof = decoder.decode_chunk(raw_chunk)
100
+
101
+ while message
102
+ event = decode_event_payload(message.payload.read)
103
+ RubyLLM.logger.debug("Bedrock stream event keys: #{event.keys}") if event && RubyLLM.config.log_stream_debug
104
+ events << event if event
105
+ break if eof
106
+
107
+ message, eof = decoder.decode_chunk
108
+ end
109
+
110
+ events
111
+ end
112
+
113
+ def decode_event_payload(payload)
114
+ outer = JSON.parse(payload)
115
+
116
+ if outer['bytes'].is_a?(String)
117
+ JSON.parse(Base64.decode64(outer['bytes']))
118
+ else
119
+ outer
120
+ end
121
+ rescue JSON::ParserError => e
122
+ RubyLLM.logger.debug "Failed to decode Bedrock stream event payload: #{e.message}"
123
+ nil
124
+ end
125
+
126
+ def build_chunk(event)
127
+ raise_stream_error(event) if stream_error_event?(event)
128
+
129
+ metadata_usage, usage, message_usage = event_usage(event)
130
+
131
+ Chunk.new(
132
+ role: :assistant,
133
+ model_id: event['modelId'] || event.dig('message', 'model') || @model&.id,
134
+ content: extract_content_delta(event),
135
+ thinking: Thinking.build(
136
+ text: extract_thinking_delta(event),
137
+ signature: extract_thinking_signature(event)
138
+ ),
139
+ tool_calls: extract_tool_calls(event),
140
+ input_tokens: extract_input_tokens(metadata_usage, usage, message_usage),
141
+ output_tokens: extract_output_tokens(metadata_usage, usage),
142
+ cached_tokens: extract_cached_tokens(metadata_usage, usage),
143
+ cache_creation_tokens: extract_cache_creation_tokens(metadata_usage, usage),
144
+ thinking_tokens: extract_reasoning_tokens(metadata_usage, usage)
145
+ )
146
+ end
147
+
148
+ def event_usage(event)
149
+ [
150
+ event.dig('metadata', 'usage') || {},
151
+ event['usage'] || {},
152
+ event.dig('message', 'usage') || {}
153
+ ]
154
+ end
155
+
156
+ def extract_input_tokens(metadata_usage, usage, message_usage)
157
+ metadata_usage['inputTokens'] || usage['inputTokens'] || message_usage['input_tokens']
158
+ end
159
+
160
+ def extract_output_tokens(metadata_usage, usage)
161
+ metadata_usage['outputTokens'] || usage['outputTokens'] || usage['output_tokens']
162
+ end
163
+
164
+ def extract_cached_tokens(metadata_usage, usage)
165
+ metadata_usage['cacheReadInputTokens'] || usage['cacheReadInputTokens'] || usage['cache_read_input_tokens']
166
+ end
167
+
168
+ def extract_cache_creation_tokens(metadata_usage, usage)
169
+ metadata_usage['cacheWriteInputTokens'] || usage['cacheWriteInputTokens'] ||
170
+ usage['cache_creation_input_tokens']
171
+ end
172
+
173
+ def extract_reasoning_tokens(metadata_usage, usage)
174
+ metadata_usage['reasoningTokens'] || usage['reasoningTokens'] ||
175
+ usage.dig('output_tokens_details', 'thinking_tokens')
176
+ end
177
+
178
+ def stream_error_event?(event)
179
+ event.keys.any? { |key| key.end_with?('Exception') } || event['type'] == 'error'
180
+ end
181
+
182
+ def raise_stream_error(event)
183
+ if event['type'] == 'error'
184
+ message = event.dig('error', 'message') || 'Bedrock streaming error'
185
+ response = Struct.new(:body, :status).new({ 'message' => message }, 500)
186
+ ErrorMiddleware.parse_error(provider: self, response: response)
187
+ return
188
+ end
189
+
190
+ key = event.keys.find { |candidate| candidate.end_with?('Exception') }
191
+ payload = event[key]
192
+ message = payload['message'] || key
193
+ status = case key
194
+ when 'throttlingException' then 429
195
+ when 'validationException' then 400
196
+ when 'accessDeniedException', 'unrecognizedClientException' then 401
197
+ when 'serviceUnavailableException' then 503
198
+ else 500
199
+ end
200
+
201
+ response = Struct.new(:body, :status).new({ 'message' => message }, status)
202
+ ErrorMiddleware.parse_error(provider: self, response: response)
203
+ end
204
+
205
+ def extract_content_delta(event)
206
+ delta = normalized_delta(event)
207
+ return delta['text'] if delta['text']
208
+
209
+ return event.dig('delta', 'text') if event.dig('delta', 'type') == 'text_delta'
210
+
211
+ nil
212
+ end
213
+
214
+ def extract_thinking_delta(event)
215
+ delta = normalized_delta(event)
216
+ reasoning_content = delta['reasoningContent'] || {}
217
+
218
+ reasoning_text = reasoning_content['reasoningText'] || {}
219
+ return reasoning_text['text'] if reasoning_text['text']
220
+ return event.dig('delta', 'thinking') if event.dig('delta', 'type') == 'thinking_delta'
221
+
222
+ nil
223
+ end
224
+
225
+ def extract_thinking_signature(event)
226
+ signature = extract_signature_from_delta(event)
227
+ return signature if signature
228
+
229
+ signature = extract_signature_from_start(event)
230
+ return signature if signature
231
+
232
+ nil
233
+ end
234
+
235
+ def extract_signature_from_delta(event)
236
+ delta = normalized_delta(event)
237
+ reasoning_content = delta['reasoningContent'] || {}
238
+ reasoning_text = reasoning_content['reasoningText'] || {}
239
+ return reasoning_text['signature'] if reasoning_text['signature']
240
+ return event.dig('delta', 'signature') if event.dig('delta', 'type') == 'signature_delta'
241
+
242
+ nil
243
+ end
244
+
245
+ def extract_signature_from_start(event)
246
+ start = event.dig('contentBlockStart', 'start', 'reasoningContent')
247
+ return nil unless start
248
+
249
+ reasoning_text = start['reasoningText'] || {}
250
+ return reasoning_text['signature'] if reasoning_text['signature']
251
+ return start['redactedContent'] if start['redactedContent']
252
+
253
+ nil
254
+ end
255
+
256
+ def extract_tool_calls(event)
257
+ return extract_tool_call_start(event) if tool_call_start_event?(event)
258
+ return extract_tool_call_delta(event) if tool_call_delta_event?(event)
259
+
260
+ nil
261
+ end
262
+
263
+ def tool_call_start_event?(event)
264
+ event['contentBlockStart'] || event['start'] || event.dig('content_block', 'tool_use')
265
+ end
266
+
267
+ def tool_call_delta_event?(event)
268
+ event['contentBlockDelta'] || event.dig('delta', 'toolUse') || event.dig('delta', 'tool_use') ||
269
+ event.dig('delta', 'partial_json')
270
+ end
271
+
272
+ def extract_tool_call_start(event)
273
+ tool_use = event.dig('contentBlockStart', 'start', 'toolUse')
274
+ tool_use ||= event.dig('start', 'toolUse')
275
+ tool_use ||= event.dig('content_block', 'tool_use') if event['type'] == 'content_block_start'
276
+ return nil unless tool_use
277
+
278
+ tool_use_id = tool_use['toolUseId'] || tool_use['id']
279
+ tool_name = tool_use['name']
280
+ tool_input = tool_use['input'] || {}
281
+
282
+ {
283
+ tool_use_id => ToolCall.new(
284
+ id: tool_use_id,
285
+ name: tool_name,
286
+ arguments: tool_input
287
+ )
288
+ }
289
+ end
290
+
291
+ def extract_tool_call_delta(event)
292
+ input = normalized_delta(event).dig('toolUse', 'input')
293
+ input ||= normalized_delta(event).dig('tool_use', 'input')
294
+ input ||= event.dig('delta', 'partial_json') if event.dig('delta', 'type') == 'input_json_delta'
295
+ return nil unless input
296
+
297
+ { nil => ToolCall.new(id: nil, name: nil, arguments: input) }
298
+ end
299
+
300
+ def normalized_delta(event)
301
+ delta = event.dig('contentBlockDelta', 'delta') || event['delta'] || {}
302
+ return delta if delta.is_a?(Hash)
303
+
304
+ if delta.is_a?(String) && !delta.empty?
305
+ JSON.parse(delta)
306
+ else
307
+ {}
308
+ end
309
+ rescue JSON::ParserError
310
+ {}
311
+ end
15
312
  end
16
313
  end
17
314
  end
@@ -1,81 +1,90 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'openssl'
4
- require 'time'
5
-
6
3
  module RubyLLM
7
4
  module Providers
8
- # AWS Bedrock API integration.
5
+ # AWS Bedrock Converse API integration.
9
6
  class Bedrock < Provider
7
+ include Bedrock::Auth
10
8
  include Bedrock::Chat
11
- include Bedrock::Streaming
12
- include Bedrock::Models
13
- include Bedrock::Signing
14
9
  include Bedrock::Media
15
- include Anthropic::Tools
10
+ include Bedrock::Models
11
+ include Bedrock::Streaming
16
12
 
17
13
  def api_base
18
- "https://bedrock-runtime.#{@config.bedrock_region}.amazonaws.com"
14
+ "https://bedrock-runtime.#{bedrock_region}.amazonaws.com"
15
+ end
16
+
17
+ def headers
18
+ {}
19
+ end
20
+
21
+ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, thinking: nil, &) # rubocop:disable Metrics/ParameterLists
22
+ normalized_params = normalize_params(params, model:)
23
+
24
+ super(
25
+ messages,
26
+ tools: tools,
27
+ temperature: temperature,
28
+ model: model,
29
+ params: normalized_params,
30
+ headers: headers,
31
+ schema: schema,
32
+ thinking: thinking,
33
+ &
34
+ )
19
35
  end
20
36
 
21
37
  def parse_error(response)
22
- return if response.body.empty?
38
+ return if response.body.nil? || response.body.empty?
23
39
 
24
40
  body = try_parse_json(response.body)
25
- case body
26
- when Hash
27
- body['message']
28
- when Array
29
- body.map do |part|
30
- part['message']
31
- end.join('. ')
32
- else
33
- body
34
- end
35
- end
41
+ return body if body.is_a?(String)
36
42
 
37
- def sign_request(url, method: :post, payload: nil)
38
- signer = create_signer
39
- request = build_request(url, method:, payload:)
40
- signer.sign_request(request)
43
+ body['message'] || body['Message'] || body['error'] || body['__type'] || super
41
44
  end
42
45
 
43
- def create_signer
44
- Signing::Signer.new({
45
- access_key_id: @config.bedrock_api_key,
46
- secret_access_key: @config.bedrock_secret_key,
47
- session_token: @config.bedrock_session_token,
48
- region: @config.bedrock_region,
49
- service: 'bedrock'
50
- })
46
+ def list_models
47
+ response = signed_get(models_api_base, models_url)
48
+ parse_list_models_response(response, slug, capabilities)
51
49
  end
52
50
 
53
- def build_request(url, method: :post, payload: nil)
54
- {
55
- connection: @connection,
56
- http_method: method,
57
- url: url || completion_url,
58
- body: payload ? JSON.generate(payload, ascii_only: false) : nil
59
- }
51
+ class << self
52
+ def configuration_requirements
53
+ %i[bedrock_api_key bedrock_secret_key bedrock_region]
54
+ end
60
55
  end
61
56
 
62
- def build_headers(signature_headers, streaming: false)
63
- accept_header = streaming ? 'application/vnd.amazon.eventstream' : 'application/json'
57
+ private
64
58
 
65
- signature_headers.merge(
66
- 'Content-Type' => 'application/json',
67
- 'Accept' => accept_header
68
- )
59
+ def bedrock_region
60
+ @config.bedrock_region
69
61
  end
70
62
 
71
- class << self
72
- def capabilities
73
- Bedrock::Capabilities
74
- end
63
+ def sync_response(connection, payload, additional_headers = {})
64
+ signed_post(connection, completion_url, payload, additional_headers)
65
+ end
75
66
 
76
- def configuration_requirements
77
- %i[bedrock_api_key bedrock_secret_key bedrock_region]
67
+ def normalize_params(params, model:)
68
+ normalized = RubyLLM::Utils.deep_symbolize_keys(params || {})
69
+ additional_fields = normalized[:additionalModelRequestFields] || {}
70
+
71
+ top_k = normalized.delete(:top_k)
72
+ if !top_k.nil? && model_supports_top_k?(model)
73
+ additional_fields = RubyLLM::Utils.deep_merge(additional_fields, { top_k: top_k })
78
74
  end
75
+
76
+ normalized[:additionalModelRequestFields] = additional_fields unless additional_fields.empty?
77
+ normalized
78
+ end
79
+
80
+ def model_supports_top_k?(model)
81
+ Bedrock::Models.reasoning_embedded?(model)
82
+ end
83
+
84
+ def api_payload(payload)
85
+ cleaned = RubyLLM::Utils.deep_symbolize_keys(RubyLLM::Utils.deep_dup(payload))
86
+ cleaned.delete(:tools)
87
+ cleaned
79
88
  end
80
89
  end
81
90
  end
@@ -37,7 +37,7 @@ module RubyLLM
37
37
  {
38
38
  type: 'image_url',
39
39
  image_url: {
40
- url: image.url? ? image.source : image.for_llm
40
+ url: image.url? ? image.source.to_s : image.for_llm
41
41
  }
42
42
  }
43
43
  end
@@ -0,0 +1,15 @@
1
+ # frozen_string_literal: true
2
+
3
+ module RubyLLM
4
+ module Providers
5
+ class XAI
6
+ # Chat implementation for xAI
7
+ # https://docs.x.ai/docs/api-reference#chat-completions
8
+ module Chat
9
+ def format_role(role)
10
+ role.to_s
11
+ end
12
+ end
13
+ end
14
+ end
15
+ end