ruby_llm 1.0.0 → 1.1.0rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. checksums.yaml +4 -4
  2. data/README.md +58 -19
  3. data/lib/ruby_llm/active_record/acts_as.rb +46 -7
  4. data/lib/ruby_llm/aliases.json +65 -0
  5. data/lib/ruby_llm/aliases.rb +56 -0
  6. data/lib/ruby_llm/chat.rb +11 -10
  7. data/lib/ruby_llm/configuration.rb +4 -0
  8. data/lib/ruby_llm/error.rb +15 -4
  9. data/lib/ruby_llm/models.json +1489 -283
  10. data/lib/ruby_llm/models.rb +57 -22
  11. data/lib/ruby_llm/provider.rb +44 -41
  12. data/lib/ruby_llm/providers/anthropic/capabilities.rb +8 -9
  13. data/lib/ruby_llm/providers/anthropic/chat.rb +31 -4
  14. data/lib/ruby_llm/providers/anthropic/streaming.rb +12 -6
  15. data/lib/ruby_llm/providers/anthropic.rb +4 -0
  16. data/lib/ruby_llm/providers/bedrock/capabilities.rb +168 -0
  17. data/lib/ruby_llm/providers/bedrock/chat.rb +108 -0
  18. data/lib/ruby_llm/providers/bedrock/models.rb +84 -0
  19. data/lib/ruby_llm/providers/bedrock/signing.rb +831 -0
  20. data/lib/ruby_llm/providers/bedrock/streaming/base.rb +46 -0
  21. data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +63 -0
  22. data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +79 -0
  23. data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +90 -0
  24. data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +91 -0
  25. data/lib/ruby_llm/providers/bedrock/streaming.rb +36 -0
  26. data/lib/ruby_llm/providers/bedrock.rb +83 -0
  27. data/lib/ruby_llm/providers/deepseek/chat.rb +17 -0
  28. data/lib/ruby_llm/providers/deepseek.rb +5 -0
  29. data/lib/ruby_llm/providers/gemini/capabilities.rb +50 -34
  30. data/lib/ruby_llm/providers/gemini/chat.rb +8 -15
  31. data/lib/ruby_llm/providers/gemini/images.rb +5 -10
  32. data/lib/ruby_llm/providers/gemini/models.rb +0 -8
  33. data/lib/ruby_llm/providers/gemini/streaming.rb +35 -76
  34. data/lib/ruby_llm/providers/gemini/tools.rb +12 -12
  35. data/lib/ruby_llm/providers/gemini.rb +4 -0
  36. data/lib/ruby_llm/providers/openai/capabilities.rb +154 -177
  37. data/lib/ruby_llm/providers/openai/streaming.rb +9 -13
  38. data/lib/ruby_llm/providers/openai.rb +4 -0
  39. data/lib/ruby_llm/streaming.rb +96 -0
  40. data/lib/ruby_llm/tool.rb +15 -7
  41. data/lib/ruby_llm/version.rb +1 -1
  42. data/lib/ruby_llm.rb +8 -3
  43. data/lib/tasks/browser_helper.rb +97 -0
  44. data/lib/tasks/capability_generator.rb +123 -0
  45. data/lib/tasks/capability_scraper.rb +224 -0
  46. data/lib/tasks/cli_helper.rb +22 -0
  47. data/lib/tasks/code_validator.rb +29 -0
  48. data/lib/tasks/model_updater.rb +66 -0
  49. data/lib/tasks/models.rake +28 -197
  50. data/lib/tasks/vcr.rake +97 -0
  51. metadata +42 -19
  52. data/.github/workflows/cicd.yml +0 -109
  53. data/.github/workflows/docs.yml +0 -53
  54. data/.gitignore +0 -58
  55. data/.overcommit.yml +0 -26
  56. data/.rspec +0 -3
  57. data/.rspec_status +0 -50
  58. data/.rubocop.yml +0 -10
  59. data/.yardopts +0 -12
  60. data/Gemfile +0 -32
  61. data/Rakefile +0 -9
  62. data/bin/console +0 -17
  63. data/bin/setup +0 -6
  64. data/ruby_llm.gemspec +0 -43
@@ -12,22 +12,37 @@ module RubyLLM
12
12
  class Models
13
13
  include Enumerable
14
14
 
15
- def self.instance
16
- @instance ||= new
17
- end
15
+ # Delegate class methods to the singleton instance
16
+ class << self
17
+ def instance
18
+ @instance ||= new
19
+ end
18
20
 
19
- def self.provider_for(model)
20
- Provider.for(model)
21
- end
21
+ def provider_for(model)
22
+ Provider.for(model)
23
+ end
22
24
 
23
- # Class method to refresh model data
24
- def self.refresh!
25
- models = RubyLLM.providers.flat_map(&:list_models).sort_by(&:id)
26
- @instance = new(models)
27
- end
25
+ def models_file
26
+ File.expand_path('models.json', __dir__)
27
+ end
28
+
29
+ def refresh! # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metrics/PerceivedComplexity
30
+ configured = Provider.configured_providers
31
+
32
+ # Log provider status
33
+ skipped = Provider.providers.values - configured
34
+ RubyLLM.logger.info "Refreshing models from #{configured.map(&:slug).join(', ')}" if configured.any?
35
+ RubyLLM.logger.info "Skipping #{skipped.map(&:slug).join(', ')} - providers not configured" if skipped.any?
36
+
37
+ # Store current models except from configured providers
38
+ current = instance.load_models
39
+ preserved = current.reject { |m| configured.map(&:slug).include?(m.provider) }
40
+
41
+ all = (preserved + configured.flat_map(&:list_models)).sort_by(&:id)
42
+ @instance = new(all)
43
+ @instance
44
+ end
28
45
 
29
- # Delegate class methods to the singleton instance
30
- class << self
31
46
  def method_missing(method, ...)
32
47
  if instance.respond_to?(method)
33
48
  instance.send(method, ...)
@@ -48,10 +63,14 @@ module RubyLLM
48
63
 
49
64
  # Load models from the JSON file
50
65
  def load_models
51
- data = JSON.parse(File.read(File.expand_path('models.json', __dir__)))
52
- data.map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
53
- rescue Errno::ENOENT
54
- [] # Return empty array if file doesn't exist yet
66
+ data = File.exist?(self.class.models_file) ? File.read(self.class.models_file) : '[]'
67
+ JSON.parse(data).map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
68
+ rescue JSON::ParserError
69
+ []
70
+ end
71
+
72
+ def save_models
73
+ File.write(self.class.models_file, JSON.pretty_generate(all.map(&:to_h)))
55
74
  end
56
75
 
57
76
  # Return all models in the collection
@@ -65,9 +84,12 @@ module RubyLLM
65
84
  end
66
85
 
67
86
  # Find a specific model by ID
68
- def find(model_id)
69
- all.find { |m| m.id == model_id } or
70
- raise ModelNotFoundError, "Unknown model: #{model_id}"
87
+ def find(model_id, provider = nil)
88
+ if provider
89
+ find_with_provider(model_id, provider)
90
+ else
91
+ find_without_provider(model_id)
92
+ end
71
93
  end
72
94
 
73
95
  # Filter to only chat models
@@ -103,8 +125,21 @@ module RubyLLM
103
125
  # Instance method to refresh models
104
126
  def refresh!
105
127
  self.class.refresh!
106
- # Return self for method chaining
107
- self
128
+ end
129
+
130
+ private
131
+
132
+ def find_with_provider(model_id, provider)
133
+ resolved_id = Aliases.resolve(model_id, provider)
134
+ all.find { |m| m.id == model_id && m.provider == provider.to_s } ||
135
+ all.find { |m| m.id == resolved_id && m.provider == provider.to_s } ||
136
+ raise(ModelNotFoundError, "Unknown model: #{model_id} for provider: #{provider}")
137
+ end
138
+
139
+ def find_without_provider(model_id)
140
+ all.find { |m| m.id == model_id } ||
141
+ all.find { |m| m.id == Aliases.resolve(model_id) } ||
142
+ raise(ModelNotFoundError, "Unknown model: #{model_id}")
108
143
  end
109
144
  end
110
145
  end
@@ -7,9 +7,21 @@ module RubyLLM
7
7
  module Provider
8
8
  # Common functionality for all LLM providers. Implements the core provider
9
9
  # interface so specific providers only need to implement a few key methods.
10
- module Methods # rubocop:disable Metrics/ModuleLength
11
- def complete(messages, tools:, temperature:, model:, &block)
12
- payload = render_payload messages, tools: tools, temperature: temperature, model: model, stream: block_given?
10
+ module Methods
11
+ extend Streaming
12
+
13
+ def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength
14
+ normalized_temperature = if capabilities.respond_to?(:normalize_temperature)
15
+ capabilities.normalize_temperature(temperature, model)
16
+ else
17
+ temperature
18
+ end
19
+
20
+ payload = render_payload(messages,
21
+ tools: tools,
22
+ temperature: normalized_temperature,
23
+ model: model,
24
+ stream: block_given?)
13
25
 
14
26
  if block_given?
15
27
  stream_response payload, &block
@@ -39,24 +51,35 @@ module RubyLLM
39
51
  parse_image_response(response)
40
52
  end
41
53
 
54
+ def configured?
55
+ missing_configs.empty?
56
+ end
57
+
42
58
  private
43
59
 
44
- def sync_response(payload)
45
- response = post completion_url, payload
46
- parse_completion_response response
60
+ def missing_configs
61
+ configuration_requirements.select do |key|
62
+ value = RubyLLM.config.send(key)
63
+ value.nil? || value.empty?
64
+ end
47
65
  end
48
66
 
49
- def stream_response(payload, &block)
50
- accumulator = StreamAccumulator.new
67
+ def ensure_configured!
68
+ return if configured?
51
69
 
52
- post stream_url, payload do |req|
53
- req.options.on_data = handle_stream do |chunk|
54
- accumulator.add chunk
55
- block.call chunk
70
+ config_block = <<~RUBY
71
+ RubyLLM.configure do |config|
72
+ #{missing_configs.map { |key| "config.#{key} = ENV['#{key.to_s.upcase}']" }.join("\n ")}
56
73
  end
57
- end
74
+ RUBY
58
75
 
59
- accumulator.to_message
76
+ raise ConfigurationError,
77
+ "#{slug} provider is not configured. Add this to your initialization:\n\n#{config_block}"
78
+ end
79
+
80
+ def sync_response(payload)
81
+ response = post completion_url, payload
82
+ parse_completion_response response
60
83
  end
61
84
 
62
85
  def post(url, payload)
@@ -67,6 +90,8 @@ module RubyLLM
67
90
  end
68
91
 
69
92
  def connection # rubocop:disable Metrics/MethodLength,Metrics/AbcSize
93
+ ensure_configured!
94
+
70
95
  @connection ||= Faraday.new(api_base) do |f| # rubocop:disable Metrics/BlockLength
71
96
  f.options.timeout = RubyLLM.config.request_timeout
72
97
 
@@ -105,33 +130,6 @@ module RubyLLM
105
130
  f.use :llm_errors, provider: self
106
131
  end
107
132
  end
108
-
109
- def to_json_stream(&block) # rubocop:disable Metrics/MethodLength
110
- buffer = String.new
111
- parser = EventStreamParser::Parser.new
112
-
113
- proc do |chunk, _bytes, env|
114
- if env && env.status != 200
115
- # Accumulate error chunks
116
- buffer << chunk
117
- begin
118
- error_data = JSON.parse(buffer)
119
- error_response = env.merge(body: error_data)
120
- ErrorMiddleware.parse_error(provider: self, response: error_response)
121
- rescue JSON::ParserError
122
- # Keep accumulating if we don't have complete JSON yet
123
- RubyLLM.logger.debug "Accumulating error chunk: #{chunk}"
124
- end
125
- else
126
- parser.feed(chunk) do |_type, data|
127
- unless data == '[DONE]'
128
- parsed_data = JSON.parse(data)
129
- block.call(parsed_data)
130
- end
131
- end
132
- end
133
- end
134
- end
135
133
  end
136
134
 
137
135
  def try_parse_json(maybe_json)
@@ -171,6 +169,7 @@ module RubyLLM
171
169
  class << self
172
170
  def extended(base)
173
171
  base.extend(Methods)
172
+ base.extend(Streaming)
174
173
  end
175
174
 
176
175
  def register(name, provider_module)
@@ -185,6 +184,10 @@ module RubyLLM
185
184
  def providers
186
185
  @providers ||= {}
187
186
  end
187
+
188
+ def configured_providers
189
+ providers.select { |_name, provider| provider.configured? }.values
190
+ end
188
191
  end
189
192
  end
190
193
  end
@@ -20,8 +20,8 @@ module RubyLLM
20
20
  # @return [Integer] the maximum output tokens
21
21
  def determine_max_tokens(model_id)
22
22
  case model_id
23
- when /claude-3-(7-sonnet|5)/ then 8_192 # Can be increased to 64K with extended thinking
24
- else 4_096 # Claude 3 Opus and Haiku
23
+ when /claude-3-7-sonnet/, /claude-3-5/ then 8_192
24
+ else 4_096
25
25
  end
26
26
  end
27
27
 
@@ -92,13 +92,12 @@ module RubyLLM
92
92
 
93
93
  # Pricing information for Anthropic models (per million tokens)
94
94
  PRICES = {
95
- claude37_sonnet: { input: 3.0, output: 15.0 }, # $3.00/$15.00 per million tokens
96
- claude35_sonnet: { input: 3.0, output: 15.0 }, # $3.00/$15.00 per million tokens
97
- claude35_haiku: { input: 0.80, output: 4.0 }, # $0.80/$4.00 per million tokens
98
- claude3_opus: { input: 15.0, output: 75.0 }, # $15.00/$75.00 per million tokens
99
- claude3_sonnet: { input: 3.0, output: 15.0 }, # $3.00/$15.00 per million tokens
100
- claude3_haiku: { input: 0.25, output: 1.25 }, # $0.25/$1.25 per million tokens
101
- claude2: { input: 3.0, output: 15.0 } # Default pricing for Claude 2.x models
95
+ claude37_sonnet: { input: 3.0, output: 15.0 },
96
+ claude35_sonnet: { input: 3.0, output: 15.0 },
97
+ claude35_haiku: { input: 0.80, output: 4.0 },
98
+ claude3_opus: { input: 15.0, output: 75.0 },
99
+ claude3_haiku: { input: 0.25, output: 1.25 },
100
+ claude2: { input: 3.0, output: 15.0 }
102
101
  }.freeze
103
102
 
104
103
  # Default input price if model not found in PRICES
@@ -12,15 +12,42 @@ module RubyLLM
12
12
  end
13
13
 
14
14
  def render_payload(messages, tools:, temperature:, model:, stream: false)
15
+ system_messages, chat_messages = separate_messages(messages)
16
+ system_content = build_system_content(system_messages)
17
+
18
+ build_base_payload(chat_messages, temperature, model, stream).tap do |payload|
19
+ add_optional_fields(payload, system_content:, tools:)
20
+ end
21
+ end
22
+
23
+ def separate_messages(messages)
24
+ messages.partition { |msg| msg.role == :system }
25
+ end
26
+
27
+ def build_system_content(system_messages)
28
+ if system_messages.length > 1
29
+ RubyLLM.logger.warn(
30
+ "Anthropic's Claude implementation only supports a single system message. " \
31
+ 'Multiple system messages will be combined into one.'
32
+ )
33
+ end
34
+
35
+ system_messages.map { |msg| format_message(msg)[:content] }.join("\n\n")
36
+ end
37
+
38
+ def build_base_payload(chat_messages, temperature, model, stream)
15
39
  {
16
40
  model: model,
17
- messages: messages.map { |msg| format_message(msg) },
41
+ messages: chat_messages.map { |msg| format_message(msg) },
18
42
  temperature: temperature,
19
43
  stream: stream,
20
44
  max_tokens: RubyLLM.models.find(model).max_tokens
21
- }.tap do |payload|
22
- payload[:tools] = tools.values.map { |t| function_for(t) } if tools.any?
23
- end
45
+ }
46
+ end
47
+
48
+ def add_optional_fields(payload, system_content:, tools:)
49
+ payload[:tools] = tools.values.map { |t| function_for(t) } if tools.any?
50
+ payload[:system] = system_content unless system_content.empty?
24
51
  end
25
52
 
26
53
  def parse_completion_response(response)
@@ -11,12 +11,6 @@ module RubyLLM
11
11
  completion_url
12
12
  end
13
13
 
14
- def handle_stream(&block)
15
- to_json_stream do |data|
16
- block.call(build_chunk(data))
17
- end
18
- end
19
-
20
14
  def build_chunk(data)
21
15
  Chunk.new(
22
16
  role: :assistant,
@@ -31,6 +25,18 @@ module RubyLLM
31
25
  def json_delta?(data)
32
26
  data['type'] == 'content_block_delta' && data.dig('delta', 'type') == 'input_json_delta'
33
27
  end
28
+
29
+ def parse_streaming_error(data)
30
+ error_data = JSON.parse(data)
31
+ return unless error_data['type'] == 'error'
32
+
33
+ case error_data.dig('error', 'type')
34
+ when 'overloaded_error'
35
+ [529, error_data['error']['message']]
36
+ else
37
+ [500, error_data['error']['message']]
38
+ end
39
+ end
34
40
  end
35
41
  end
36
42
  end
@@ -33,6 +33,10 @@ module RubyLLM
33
33
  def slug
34
34
  'anthropic'
35
35
  end
36
+
37
+ def configuration_requirements
38
+ %i[anthropic_api_key]
39
+ end
36
40
  end
37
41
  end
38
42
  end
@@ -0,0 +1,168 @@
1
+ # frozen_string_literal: true
2
+
3
+ module RubyLLM
4
+ module Providers
5
+ module Bedrock
6
+ # Determines capabilities and pricing for AWS Bedrock models
7
+ module Capabilities
8
+ module_function
9
+
10
+ # Returns the context window size for the given model ID
11
+ # @param model_id [String] the model identifier
12
+ # @return [Integer] the context window size in tokens
13
+ def context_window_for(model_id)
14
+ case model_id
15
+ when /anthropic\.claude-2/ then 100_000
16
+ else 200_000
17
+ end
18
+ end
19
+
20
+ # Returns the maximum output tokens for the given model ID
21
+ # @param model_id [String] the model identifier
22
+ # @return [Integer] the maximum output tokens
23
+ def max_tokens_for(_model_id)
24
+ 4_096
25
+ end
26
+
27
+ # Returns the input price per million tokens for the given model ID
28
+ # @param model_id [String] the model identifier
29
+ # @return [Float] the price per million tokens for input
30
+ def input_price_for(model_id)
31
+ PRICES.dig(model_family(model_id), :input) || default_input_price
32
+ end
33
+
34
+ # Returns the output price per million tokens for the given model ID
35
+ # @param model_id [String] the model identifier
36
+ # @return [Float] the price per million tokens for output
37
+ def output_price_for(model_id)
38
+ PRICES.dig(model_family(model_id), :output) || default_output_price
39
+ end
40
+
41
+ # Determines if the model supports chat capabilities
42
+ # @param model_id [String] the model identifier
43
+ # @return [Boolean] true if the model supports chat
44
+ def supports_chat?(model_id)
45
+ model_id.match?(/anthropic\.claude/)
46
+ end
47
+
48
+ # Determines if the model supports streaming capabilities
49
+ # @param model_id [String] the model identifier
50
+ # @return [Boolean] true if the model supports streaming
51
+ def supports_streaming?(model_id)
52
+ model_id.match?(/anthropic\.claude/)
53
+ end
54
+
55
+ # Determines if the model supports image input/output
56
+ # @param model_id [String] the model identifier
57
+ # @return [Boolean] true if the model supports images
58
+ def supports_images?(model_id)
59
+ model_id.match?(/anthropic\.claude/)
60
+ end
61
+
62
+ # Determines if the model supports vision capabilities
63
+ # @param model_id [String] the model identifier
64
+ # @return [Boolean] true if the model supports vision
65
+ def supports_vision?(model_id)
66
+ model_id.match?(/anthropic\.claude/)
67
+ end
68
+
69
+ # Determines if the model supports function calling
70
+ # @param model_id [String] the model identifier
71
+ # @return [Boolean] true if the model supports functions
72
+ def supports_functions?(model_id)
73
+ model_id.match?(/anthropic\.claude/)
74
+ end
75
+
76
+ # Determines if the model supports audio input/output
77
+ # @param model_id [String] the model identifier
78
+ # @return [Boolean] true if the model supports audio
79
+ def supports_audio?(_model_id)
80
+ false
81
+ end
82
+
83
+ # Determines if the model supports JSON mode
84
+ # @param model_id [String] the model identifier
85
+ # @return [Boolean] true if the model supports JSON mode
86
+ def supports_json_mode?(model_id)
87
+ model_id.match?(/anthropic\.claude/)
88
+ end
89
+
90
+ # Formats the model ID into a human-readable display name
91
+ # @param model_id [String] the model identifier
92
+ # @return [String] the formatted display name
93
+ def format_display_name(model_id)
94
+ model_id.then { |id| humanize(id) }
95
+ end
96
+
97
+ # Determines the type of model
98
+ # @param model_id [String] the model identifier
99
+ # @return [String] the model type (chat, embedding, image, audio)
100
+ def model_type(_model_id)
101
+ 'chat'
102
+ end
103
+
104
+ # Determines if the model supports structured output
105
+ # @param model_id [String] the model identifier
106
+ # @return [Boolean] true if the model supports structured output
107
+ def supports_structured_output?(model_id)
108
+ model_id.match?(/anthropic\.claude/)
109
+ end
110
+
111
+ # Model family patterns for capability lookup
112
+ MODEL_FAMILIES = {
113
+ /anthropic\.claude-3-opus/ => :claude3_opus,
114
+ /anthropic\.claude-3-sonnet/ => :claude3_sonnet,
115
+ /anthropic\.claude-3-5-sonnet/ => :claude3_sonnet,
116
+ /anthropic\.claude-3-7-sonnet/ => :claude3_sonnet,
117
+ /anthropic\.claude-3-haiku/ => :claude3_haiku,
118
+ /anthropic\.claude-3-5-haiku/ => :claude3_5_haiku,
119
+ /anthropic\.claude-v2/ => :claude2,
120
+ /anthropic\.claude-instant/ => :claude_instant
121
+ }.freeze
122
+
123
+ # Determines the model family for pricing and capability lookup
124
+ # @param model_id [String] the model identifier
125
+ # @return [Symbol] the model family identifier
126
+ def model_family(model_id)
127
+ MODEL_FAMILIES.find { |pattern, _family| model_id.match?(pattern) }&.last || :other
128
+ end
129
+
130
+ # Pricing information for Bedrock models (per million tokens)
131
+ PRICES = {
132
+ claude3_opus: { input: 15.0, output: 75.0 },
133
+ claude3_sonnet: { input: 3.0, output: 15.0 },
134
+ claude3_haiku: { input: 0.25, output: 1.25 },
135
+ claude3_5_haiku: { input: 0.8, output: 4.0 },
136
+ claude2: { input: 8.0, output: 24.0 },
137
+ claude_instant: { input: 0.8, output: 2.4 }
138
+ }.freeze
139
+
140
+ # Default input price when model-specific pricing is not available
141
+ # @return [Float] the default price per million tokens
142
+ def default_input_price
143
+ 0.1
144
+ end
145
+
146
+ # Default output price when model-specific pricing is not available
147
+ # @return [Float] the default price per million tokens
148
+ def default_output_price
149
+ 0.2
150
+ end
151
+
152
+ private
153
+
154
+ # Converts a model ID to a human-readable format
155
+ # @param id [String] the model identifier
156
+ # @return [String] the humanized model name
157
+ def humanize(id)
158
+ id.tr('-', ' ')
159
+ .split('.')
160
+ .last
161
+ .split
162
+ .map(&:capitalize)
163
+ .join(' ')
164
+ end
165
+ end
166
+ end
167
+ end
168
+ end
@@ -0,0 +1,108 @@
1
+ # frozen_string_literal: true
2
+
3
+ module RubyLLM
4
+ module Providers
5
+ module Bedrock
6
+ # Chat methods for the AWS Bedrock API implementation
7
+ module Chat
8
+ private
9
+
10
+ def completion_url
11
+ "model/#{@model_id}/invoke"
12
+ end
13
+
14
+ def render_payload(messages, tools:, temperature:, model:, stream: false) # rubocop:disable Lint/UnusedMethodArgument
15
+ # Hold model_id in instance variable for use in completion_url and stream_url
16
+ @model_id = model
17
+
18
+ system_messages, chat_messages = separate_messages(messages)
19
+ system_content = build_system_content(system_messages)
20
+
21
+ build_base_payload(chat_messages, temperature, model).tap do |payload|
22
+ add_optional_fields(payload, system_content:, tools:)
23
+ end
24
+ end
25
+
26
+ def separate_messages(messages)
27
+ messages.partition { |msg| msg.role == :system }
28
+ end
29
+
30
+ def build_system_content(system_messages)
31
+ if system_messages.length > 1
32
+ RubyLLM.logger.warn(
33
+ "Amazon Bedrock's Claude implementation only supports a single system message. " \
34
+ 'Multiple system messages will be combined into one.'
35
+ )
36
+ end
37
+
38
+ system_messages.map { |msg| format_message(msg)[:content] }.join("\n\n")
39
+ end
40
+
41
+ def build_base_payload(chat_messages, temperature, model)
42
+ {
43
+ anthropic_version: 'bedrock-2023-05-31',
44
+ messages: chat_messages.map { |msg| format_message(msg) },
45
+ temperature: temperature,
46
+ max_tokens: RubyLLM.models.find(model).max_tokens
47
+ }
48
+ end
49
+
50
+ def add_optional_fields(payload, system_content:, tools:)
51
+ payload[:tools] = tools.values.map { |t| function_for(t) } if tools.any?
52
+ payload[:system] = system_content unless system_content.empty?
53
+ end
54
+
55
+ def format_message(msg)
56
+ if msg.tool_call?
57
+ format_tool_call(msg)
58
+ elsif msg.tool_result?
59
+ format_tool_result(msg)
60
+ else
61
+ format_basic_message(msg)
62
+ end
63
+ end
64
+
65
+ def format_basic_message(msg)
66
+ {
67
+ role: convert_role(msg.role),
68
+ content: Anthropic::Media.format_content(msg.content)
69
+ }
70
+ end
71
+
72
+ def convert_role(role)
73
+ case role
74
+ when :tool, :user then 'user'
75
+ when :system then 'system'
76
+ else 'assistant'
77
+ end
78
+ end
79
+
80
+ def parse_completion_response(response)
81
+ data = response.body
82
+ content_blocks = data['content'] || []
83
+
84
+ text_content = extract_text_content(content_blocks)
85
+ tool_use = find_tool_use(content_blocks)
86
+
87
+ build_message(data, text_content, tool_use)
88
+ end
89
+
90
+ def extract_text_content(blocks)
91
+ text_blocks = blocks.select { |c| c['type'] == 'text' }
92
+ text_blocks.map { |c| c['text'] }.join
93
+ end
94
+
95
+ def build_message(data, content, tool_use)
96
+ Message.new(
97
+ role: :assistant,
98
+ content: content,
99
+ tool_calls: parse_tool_calls(tool_use),
100
+ input_tokens: data.dig('usage', 'input_tokens'),
101
+ output_tokens: data.dig('usage', 'output_tokens'),
102
+ model_id: data['model']
103
+ )
104
+ end
105
+ end
106
+ end
107
+ end
108
+ end