ruby_llm 1.0.0 → 1.1.0rc1
This diff compares the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +58 -19
- data/lib/ruby_llm/active_record/acts_as.rb +46 -7
- data/lib/ruby_llm/aliases.json +65 -0
- data/lib/ruby_llm/aliases.rb +56 -0
- data/lib/ruby_llm/chat.rb +11 -10
- data/lib/ruby_llm/configuration.rb +4 -0
- data/lib/ruby_llm/error.rb +15 -4
- data/lib/ruby_llm/models.json +1489 -283
- data/lib/ruby_llm/models.rb +57 -22
- data/lib/ruby_llm/provider.rb +44 -41
- data/lib/ruby_llm/providers/anthropic/capabilities.rb +8 -9
- data/lib/ruby_llm/providers/anthropic/chat.rb +31 -4
- data/lib/ruby_llm/providers/anthropic/streaming.rb +12 -6
- data/lib/ruby_llm/providers/anthropic.rb +4 -0
- data/lib/ruby_llm/providers/bedrock/capabilities.rb +168 -0
- data/lib/ruby_llm/providers/bedrock/chat.rb +108 -0
- data/lib/ruby_llm/providers/bedrock/models.rb +84 -0
- data/lib/ruby_llm/providers/bedrock/signing.rb +831 -0
- data/lib/ruby_llm/providers/bedrock/streaming/base.rb +46 -0
- data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +63 -0
- data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +79 -0
- data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +90 -0
- data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +91 -0
- data/lib/ruby_llm/providers/bedrock/streaming.rb +36 -0
- data/lib/ruby_llm/providers/bedrock.rb +83 -0
- data/lib/ruby_llm/providers/deepseek/chat.rb +17 -0
- data/lib/ruby_llm/providers/deepseek.rb +5 -0
- data/lib/ruby_llm/providers/gemini/capabilities.rb +50 -34
- data/lib/ruby_llm/providers/gemini/chat.rb +8 -15
- data/lib/ruby_llm/providers/gemini/images.rb +5 -10
- data/lib/ruby_llm/providers/gemini/models.rb +0 -8
- data/lib/ruby_llm/providers/gemini/streaming.rb +35 -76
- data/lib/ruby_llm/providers/gemini/tools.rb +12 -12
- data/lib/ruby_llm/providers/gemini.rb +4 -0
- data/lib/ruby_llm/providers/openai/capabilities.rb +154 -177
- data/lib/ruby_llm/providers/openai/streaming.rb +9 -13
- data/lib/ruby_llm/providers/openai.rb +4 -0
- data/lib/ruby_llm/streaming.rb +96 -0
- data/lib/ruby_llm/tool.rb +15 -7
- data/lib/ruby_llm/version.rb +1 -1
- data/lib/ruby_llm.rb +8 -3
- data/lib/tasks/browser_helper.rb +97 -0
- data/lib/tasks/capability_generator.rb +123 -0
- data/lib/tasks/capability_scraper.rb +224 -0
- data/lib/tasks/cli_helper.rb +22 -0
- data/lib/tasks/code_validator.rb +29 -0
- data/lib/tasks/model_updater.rb +66 -0
- data/lib/tasks/models.rake +28 -197
- data/lib/tasks/vcr.rake +97 -0
- metadata +42 -19
- data/.github/workflows/cicd.yml +0 -109
- data/.github/workflows/docs.yml +0 -53
- data/.gitignore +0 -58
- data/.overcommit.yml +0 -26
- data/.rspec +0 -3
- data/.rspec_status +0 -50
- data/.rubocop.yml +0 -10
- data/.yardopts +0 -12
- data/Gemfile +0 -32
- data/Rakefile +0 -9
- data/bin/console +0 -17
- data/bin/setup +0 -6
- data/ruby_llm.gemspec +0 -43
data/lib/ruby_llm/models.rb
CHANGED
@@ -12,22 +12,37 @@ module RubyLLM
   class Models
     include Enumerable
 
-
-
-
+    # Delegate class methods to the singleton instance
+    class << self
+      def instance
+        @instance ||= new
+      end
 
-
-
-
+      def provider_for(model)
+        Provider.for(model)
+      end
 
-
-
-
-
-
+      def models_file
+        File.expand_path('models.json', __dir__)
+      end
+
+      def refresh! # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metrics/PerceivedComplexity
+        configured = Provider.configured_providers
+
+        # Log provider status
+        skipped = Provider.providers.values - configured
+        RubyLLM.logger.info "Refreshing models from #{configured.map(&:slug).join(', ')}" if configured.any?
+        RubyLLM.logger.info "Skipping #{skipped.map(&:slug).join(', ')} - providers not configured" if skipped.any?
+
+        # Store current models except from configured providers
+        current = instance.load_models
+        preserved = current.reject { |m| configured.map(&:slug).include?(m.provider) }
+
+        all = (preserved + configured.flat_map(&:list_models)).sort_by(&:id)
+        @instance = new(all)
+        @instance
+      end
 
-    # Delegate class methods to the singleton instance
-    class << self
       def method_missing(method, ...)
         if instance.respond_to?(method)
           instance.send(method, ...)
@@ -48,10 +63,14 @@ module RubyLLM
 
     # Load models from the JSON file
     def load_models
-      data =
-      data.map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
-    rescue
-      []
+      data = File.exist?(self.class.models_file) ? File.read(self.class.models_file) : '[]'
+      JSON.parse(data).map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
+    rescue JSON::ParserError
+      []
+    end
+
+    def save_models
+      File.write(self.class.models_file, JSON.pretty_generate(all.map(&:to_h)))
     end
 
     # Return all models in the collection
@@ -65,9 +84,12 @@ module RubyLLM
     end
 
     # Find a specific model by ID
-    def find(model_id)
-
-
+    def find(model_id, provider = nil)
+      if provider
+        find_with_provider(model_id, provider)
+      else
+        find_without_provider(model_id)
+      end
     end
 
     # Filter to only chat models
@@ -103,8 +125,21 @@ module RubyLLM
     # Instance method to refresh models
     def refresh!
       self.class.refresh!
-
-
+    end
+
+    private
+
+    def find_with_provider(model_id, provider)
+      resolved_id = Aliases.resolve(model_id, provider)
+      all.find { |m| m.id == model_id && m.provider == provider.to_s } ||
+        all.find { |m| m.id == resolved_id && m.provider == provider.to_s } ||
+        raise(ModelNotFoundError, "Unknown model: #{model_id} for provider: #{provider}")
+    end
+
+    def find_without_provider(model_id)
+      all.find { |m| m.id == model_id } ||
+        all.find { |m| m.id == Aliases.resolve(model_id) } ||
+        raise(ModelNotFoundError, "Unknown model: #{model_id}")
     end
   end
 end
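A sketch of what the new lookup API enables (illustrative, not part of the diff; it assumes RubyLLM is configured and uses alias and model names as stand-ins for entries in the new aliases.json):

# Exact-ID lookup, delegated to the singleton as before.
model = RubyLLM.models.find('claude-3-5-sonnet-20241022')

# New optional provider argument: the same alias can now resolve to a
# provider-specific model ID via Aliases.resolve.
model = RubyLLM.models.find('claude-3-5-sonnet', :anthropic)

# Misses raise ModelNotFoundError with the offending ID in the message.
begin
  RubyLLM.models.find('no-such-model')
rescue RubyLLM::ModelNotFoundError => e
  warn e.message # => "Unknown model: no-such-model"
end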
data/lib/ruby_llm/provider.rb
CHANGED
@@ -7,9 +7,21 @@ module RubyLLM
   module Provider
     # Common functionality for all LLM providers. Implements the core provider
     # interface so specific providers only need to implement a few key methods.
-    module Methods
-
-
+    module Methods
+      extend Streaming
+
+      def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength
+        normalized_temperature = if capabilities.respond_to?(:normalize_temperature)
+                                   capabilities.normalize_temperature(temperature, model)
+                                 else
+                                   temperature
+                                 end
+
+        payload = render_payload(messages,
+                                 tools: tools,
+                                 temperature: normalized_temperature,
+                                 model: model,
+                                 stream: block_given?)
 
         if block_given?
           stream_response payload, &block
@@ -39,24 +51,35 @@ module RubyLLM
         parse_image_response(response)
       end
 
+      def configured?
+        missing_configs.empty?
+      end
+
       private
 
-      def
-
-
+      def missing_configs
+        configuration_requirements.select do |key|
+          value = RubyLLM.config.send(key)
+          value.nil? || value.empty?
+        end
      end
 
-      def
-
+      def ensure_configured!
+        return if configured?
 
-
-
-
-        block.call chunk
+        config_block = <<~RUBY
+          RubyLLM.configure do |config|
+            #{missing_configs.map { |key| "config.#{key} = ENV['#{key.to_s.upcase}']" }.join("\n    ")}
          end
-
+        RUBY
 
-
+        raise ConfigurationError,
+              "#{slug} provider is not configured. Add this to your initialization:\n\n#{config_block}"
+      end
+
+      def sync_response(payload)
+        response = post completion_url, payload
+        parse_completion_response response
      end
 
      def post(url, payload)
@@ -67,6 +90,8 @@ module RubyLLM
      end
 
      def connection # rubocop:disable Metrics/MethodLength,Metrics/AbcSize
+        ensure_configured!
+
        @connection ||= Faraday.new(api_base) do |f| # rubocop:disable Metrics/BlockLength
          f.options.timeout = RubyLLM.config.request_timeout
 
@@ -105,33 +130,6 @@ module RubyLLM
          f.use :llm_errors, provider: self
        end
      end
-
-      def to_json_stream(&block) # rubocop:disable Metrics/MethodLength
-        buffer = String.new
-        parser = EventStreamParser::Parser.new
-
-        proc do |chunk, _bytes, env|
-          if env && env.status != 200
-            # Accumulate error chunks
-            buffer << chunk
-            begin
-              error_data = JSON.parse(buffer)
-              error_response = env.merge(body: error_data)
-              ErrorMiddleware.parse_error(provider: self, response: error_response)
-            rescue JSON::ParserError
-              # Keep accumulating if we don't have complete JSON yet
-              RubyLLM.logger.debug "Accumulating error chunk: #{chunk}"
-            end
-          else
-            parser.feed(chunk) do |_type, data|
-              unless data == '[DONE]'
-                parsed_data = JSON.parse(data)
-                block.call(parsed_data)
-              end
-            end
-          end
-        end
-      end
    end
 
    def try_parse_json(maybe_json)
@@ -171,6 +169,7 @@ module RubyLLM
    class << self
      def extended(base)
        base.extend(Methods)
+        base.extend(Streaming)
      end
 
      def register(name, provider_module)
@@ -185,6 +184,10 @@ module RubyLLM
      def providers
        @providers ||= {}
      end
+
+      def configured_providers
+        providers.select { |_name, provider| provider.configured? }.values
+      end
    end
  end
end
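A minimal sketch of the new configuration guard (assuming only an OpenAI key is set; openai_api_key is one of the gem's existing configuration attributes):

RubyLLM.configure do |config|
  config.openai_api_key = ENV['OPENAI_API_KEY']
end

# Only providers whose configuration_requirements are all present count
# as configured, so refresh! can skip the rest instead of failing.
RubyLLM::Provider.configured_providers.map(&:slug)
# => ["openai"]

# Touching an unconfigured provider's connection raises ConfigurationError;
# ensure_configured! embeds a ready-to-paste RubyLLM.configure block naming
# each missing key in the error message.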
data/lib/ruby_llm/providers/anthropic/capabilities.rb
CHANGED
@@ -20,8 +20,8 @@ module RubyLLM
       # @return [Integer] the maximum output tokens
       def determine_max_tokens(model_id)
         case model_id
-        when /claude-3-
-        else 4_096
+        when /claude-3-7-sonnet/, /claude-3-5/ then 8_192
+        else 4_096
         end
       end
 
@@ -92,13 +92,12 @@ module RubyLLM
 
       # Pricing information for Anthropic models (per million tokens)
       PRICES = {
-        claude37_sonnet: { input: 3.0, output: 15.0 },
-        claude35_sonnet: { input: 3.0, output: 15.0 },
-        claude35_haiku: { input: 0.80, output: 4.0 },
-        claude3_opus: { input: 15.0, output: 75.0 },
-
-
-        claude2: { input: 3.0, output: 15.0 } # Default pricing for Claude 2.x models
+        claude37_sonnet: { input: 3.0, output: 15.0 },
+        claude35_sonnet: { input: 3.0, output: 15.0 },
+        claude35_haiku: { input: 0.80, output: 4.0 },
+        claude3_opus: { input: 15.0, output: 75.0 },
+        claude3_haiku: { input: 0.25, output: 1.25 },
+        claude2: { input: 3.0, output: 15.0 }
       }.freeze
 
       # Default input price if model not found in PRICES
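The practical effect of the determine_max_tokens change, sketched with illustrative model IDs (the Capabilities module uses module_function, so the call form below is assumed to work):

caps = RubyLLM::Providers::Anthropic::Capabilities

caps.determine_max_tokens('claude-3-7-sonnet-20250219') # => 8_192 (newly matched)
caps.determine_max_tokens('claude-3-5-haiku-20241022')  # => 8_192 (/claude-3-5/ now covers the whole 3.5 family)
caps.determine_max_tokens('claude-3-opus-20240229')     # => 4_096 (default branch, unchanged)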
data/lib/ruby_llm/providers/anthropic/chat.rb
CHANGED
@@ -12,15 +12,42 @@ module RubyLLM
        end
 
        def render_payload(messages, tools:, temperature:, model:, stream: false)
+          system_messages, chat_messages = separate_messages(messages)
+          system_content = build_system_content(system_messages)
+
+          build_base_payload(chat_messages, temperature, model, stream).tap do |payload|
+            add_optional_fields(payload, system_content:, tools:)
+          end
+        end
+
+        def separate_messages(messages)
+          messages.partition { |msg| msg.role == :system }
+        end
+
+        def build_system_content(system_messages)
+          if system_messages.length > 1
+            RubyLLM.logger.warn(
+              "Anthropic's Claude implementation only supports a single system message. " \
+              'Multiple system messages will be combined into one.'
+            )
+          end
+
+          system_messages.map { |msg| format_message(msg)[:content] }.join("\n\n")
+        end
+
+        def build_base_payload(chat_messages, temperature, model, stream)
          {
            model: model,
-            messages:
+            messages: chat_messages.map { |msg| format_message(msg) },
            temperature: temperature,
            stream: stream,
            max_tokens: RubyLLM.models.find(model).max_tokens
-          }
-
-
+          }
+        end
+
+        def add_optional_fields(payload, system_content:, tools:)
+          payload[:tools] = tools.values.map { |t| function_for(t) } if tools.any?
+          payload[:system] = system_content unless system_content.empty?
        end
 
        def parse_completion_response(response)
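A sketch of how the extracted helpers treat multiple system messages, mirroring the diff's own logic (Message is constructed here with the same keywords the provider code uses):

messages = [
  RubyLLM::Message.new(role: :system, content: 'You are terse.'),
  RubyLLM::Message.new(role: :system, content: 'Answer in French.'),
  RubyLLM::Message.new(role: :user, content: 'Bonjour!')
]

# separate_messages partitions on role; build_system_content logs the
# warning above, then joins the system parts with a blank line:
system_messages, chat_messages = messages.partition { |msg| msg.role == :system }
system_messages.map(&:content).join("\n\n")
# => "You are terse.\n\nAnswer in French."
# That combined string becomes the payload's single :system field.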
data/lib/ruby_llm/providers/anthropic/streaming.rb
CHANGED
@@ -11,12 +11,6 @@ module RubyLLM
          completion_url
        end
 
-        def handle_stream(&block)
-          to_json_stream do |data|
-            block.call(build_chunk(data))
-          end
-        end
-
        def build_chunk(data)
          Chunk.new(
            role: :assistant,
@@ -31,6 +25,18 @@ module RubyLLM
        def json_delta?(data)
          data['type'] == 'content_block_delta' && data.dig('delta', 'type') == 'input_json_delta'
        end
+
+        def parse_streaming_error(data)
+          error_data = JSON.parse(data)
+          return unless error_data['type'] == 'error'
+
+          case error_data.dig('error', 'type')
+          when 'overloaded_error'
+            [529, error_data['error']['message']]
+          else
+            [500, error_data['error']['message']]
+          end
+        end
      end
    end
  end
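How the new hook maps Anthropic's streamed error events, sketched (parse_streaming_error is internal; the inputs below are raw JSON bodies of SSE error events):

overloaded = '{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}'
generic    = '{"type":"error","error":{"type":"api_error","message":"Internal error"}}'

parse_streaming_error(overloaded) # => [529, "Overloaded"]
parse_streaming_error(generic)    # => [500, "Internal error"]
# Returning a [status, message] pair lets the shared streaming code raise
# the same error classes for streamed responses as for regular ones.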
data/lib/ruby_llm/providers/bedrock/capabilities.rb
ADDED
@@ -0,0 +1,168 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    module Bedrock
+      # Determines capabilities and pricing for AWS Bedrock models
+      module Capabilities
+        module_function
+
+        # Returns the context window size for the given model ID
+        # @param model_id [String] the model identifier
+        # @return [Integer] the context window size in tokens
+        def context_window_for(model_id)
+          case model_id
+          when /anthropic\.claude-2/ then 100_000
+          else 200_000
+          end
+        end
+
+        # Returns the maximum output tokens for the given model ID
+        # @param model_id [String] the model identifier
+        # @return [Integer] the maximum output tokens
+        def max_tokens_for(_model_id)
+          4_096
+        end
+
+        # Returns the input price per million tokens for the given model ID
+        # @param model_id [String] the model identifier
+        # @return [Float] the price per million tokens for input
+        def input_price_for(model_id)
+          PRICES.dig(model_family(model_id), :input) || default_input_price
+        end
+
+        # Returns the output price per million tokens for the given model ID
+        # @param model_id [String] the model identifier
+        # @return [Float] the price per million tokens for output
+        def output_price_for(model_id)
+          PRICES.dig(model_family(model_id), :output) || default_output_price
+        end
+
+        # Determines if the model supports chat capabilities
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports chat
+        def supports_chat?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Determines if the model supports streaming capabilities
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports streaming
+        def supports_streaming?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Determines if the model supports image input/output
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports images
+        def supports_images?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Determines if the model supports vision capabilities
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports vision
+        def supports_vision?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Determines if the model supports function calling
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports functions
+        def supports_functions?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Determines if the model supports audio input/output
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports audio
+        def supports_audio?(_model_id)
+          false
+        end
+
+        # Determines if the model supports JSON mode
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports JSON mode
+        def supports_json_mode?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Formats the model ID into a human-readable display name
+        # @param model_id [String] the model identifier
+        # @return [String] the formatted display name
+        def format_display_name(model_id)
+          model_id.then { |id| humanize(id) }
+        end
+
+        # Determines the type of model
+        # @param model_id [String] the model identifier
+        # @return [String] the model type (chat, embedding, image, audio)
+        def model_type(_model_id)
+          'chat'
+        end
+
+        # Determines if the model supports structured output
+        # @param model_id [String] the model identifier
+        # @return [Boolean] true if the model supports structured output
+        def supports_structured_output?(model_id)
+          model_id.match?(/anthropic\.claude/)
+        end
+
+        # Model family patterns for capability lookup
+        MODEL_FAMILIES = {
+          /anthropic\.claude-3-opus/ => :claude3_opus,
+          /anthropic\.claude-3-sonnet/ => :claude3_sonnet,
+          /anthropic\.claude-3-5-sonnet/ => :claude3_sonnet,
+          /anthropic\.claude-3-7-sonnet/ => :claude3_sonnet,
+          /anthropic\.claude-3-haiku/ => :claude3_haiku,
+          /anthropic\.claude-3-5-haiku/ => :claude3_5_haiku,
+          /anthropic\.claude-v2/ => :claude2,
+          /anthropic\.claude-instant/ => :claude_instant
+        }.freeze
+
+        # Determines the model family for pricing and capability lookup
+        # @param model_id [String] the model identifier
+        # @return [Symbol] the model family identifier
+        def model_family(model_id)
+          MODEL_FAMILIES.find { |pattern, _family| model_id.match?(pattern) }&.last || :other
+        end
+
+        # Pricing information for Bedrock models (per million tokens)
+        PRICES = {
+          claude3_opus: { input: 15.0, output: 75.0 },
+          claude3_sonnet: { input: 3.0, output: 15.0 },
+          claude3_haiku: { input: 0.25, output: 1.25 },
+          claude3_5_haiku: { input: 0.8, output: 4.0 },
+          claude2: { input: 8.0, output: 24.0 },
+          claude_instant: { input: 0.8, output: 2.4 }
+        }.freeze
+
+        # Default input price when model-specific pricing is not available
+        # @return [Float] the default price per million tokens
+        def default_input_price
+          0.1
+        end
+
+        # Default output price when model-specific pricing is not available
+        # @return [Float] the default price per million tokens
+        def default_output_price
+          0.2
+        end
+
+        private
+
+        # Converts a model ID to a human-readable format
+        # @param id [String] the model identifier
+        # @return [String] the humanized model name
+        def humanize(id)
+          id.tr('-', ' ')
+            .split('.')
+            .last
+            .split
+            .map(&:capitalize)
+            .join(' ')
+        end
+      end
+    end
+  end
+end
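A sketch of the family and pricing lookups (module_function makes these directly callable; the Bedrock model IDs are illustrative):

caps = RubyLLM::Providers::Bedrock::Capabilities

id = 'anthropic.claude-3-5-sonnet-20240620-v1:0'
caps.model_family(id)     # => :claude3_sonnet (first matching pattern wins)
caps.input_price_for(id)  # => 3.0 ($ per million input tokens)
caps.output_price_for(id) # => 15.0

# Anything outside the Claude family falls back to :other and the defaults.
caps.model_family('amazon.titan-text-express-v1')     # => :other
caps.input_price_for('amazon.titan-text-express-v1')  # => 0.1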
data/lib/ruby_llm/providers/bedrock/chat.rb
ADDED
@@ -0,0 +1,108 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    module Bedrock
+      # Chat methods for the AWS Bedrock API implementation
+      module Chat
+        private
+
+        def completion_url
+          "model/#{@model_id}/invoke"
+        end
+
+        def render_payload(messages, tools:, temperature:, model:, stream: false) # rubocop:disable Lint/UnusedMethodArgument
+          # Hold model_id in instance variable for use in completion_url and stream_url
+          @model_id = model
+
+          system_messages, chat_messages = separate_messages(messages)
+          system_content = build_system_content(system_messages)
+
+          build_base_payload(chat_messages, temperature, model).tap do |payload|
+            add_optional_fields(payload, system_content:, tools:)
+          end
+        end
+
+        def separate_messages(messages)
+          messages.partition { |msg| msg.role == :system }
+        end
+
+        def build_system_content(system_messages)
+          if system_messages.length > 1
+            RubyLLM.logger.warn(
+              "Amazon Bedrock's Claude implementation only supports a single system message. " \
+              'Multiple system messages will be combined into one.'
+            )
+          end
+
+          system_messages.map { |msg| format_message(msg)[:content] }.join("\n\n")
+        end
+
+        def build_base_payload(chat_messages, temperature, model)
+          {
+            anthropic_version: 'bedrock-2023-05-31',
+            messages: chat_messages.map { |msg| format_message(msg) },
+            temperature: temperature,
+            max_tokens: RubyLLM.models.find(model).max_tokens
+          }
+        end
+
+        def add_optional_fields(payload, system_content:, tools:)
+          payload[:tools] = tools.values.map { |t| function_for(t) } if tools.any?
+          payload[:system] = system_content unless system_content.empty?
+        end
+
+        def format_message(msg)
+          if msg.tool_call?
+            format_tool_call(msg)
+          elsif msg.tool_result?
+            format_tool_result(msg)
+          else
+            format_basic_message(msg)
+          end
+        end
+
+        def format_basic_message(msg)
+          {
+            role: convert_role(msg.role),
+            content: Anthropic::Media.format_content(msg.content)
+          }
+        end
+
+        def convert_role(role)
+          case role
+          when :tool, :user then 'user'
+          when :system then 'system'
+          else 'assistant'
+          end
+        end
+
+        def parse_completion_response(response)
+          data = response.body
+          content_blocks = data['content'] || []
+
+          text_content = extract_text_content(content_blocks)
+          tool_use = find_tool_use(content_blocks)
+
+          build_message(data, text_content, tool_use)
+        end
+
+        def extract_text_content(blocks)
+          text_blocks = blocks.select { |c| c['type'] == 'text' }
+          text_blocks.map { |c| c['text'] }.join
+        end
+
+        def build_message(data, content, tool_use)
+          Message.new(
+            role: :assistant,
+            content: content,
+            tool_calls: parse_tool_calls(tool_use),
+            input_tokens: data.dig('usage', 'input_tokens'),
+            output_tokens: data.dig('usage', 'output_tokens'),
+            model_id: data['model']
+          )
+        end
+      end
+    end
+  end
+end
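Putting the new provider together, as a hedged end-to-end sketch: the four bedrock_* keys line up with the four attributes added to configuration.rb and the provider argument with the new Models#find signature, but treat the exact names below as assumptions until verified against the release:

RubyLLM.configure do |config|
  config.bedrock_api_key = ENV['AWS_ACCESS_KEY_ID']          # assumed key name
  config.bedrock_secret_key = ENV['AWS_SECRET_ACCESS_KEY']   # assumed key name
  config.bedrock_region = ENV['AWS_REGION']                  # assumed key name
  config.bedrock_session_token = ENV['AWS_SESSION_TOKEN']    # assumed; optional
end

chat = RubyLLM.chat(model: 'claude-3-5-sonnet', provider: :bedrock)
chat.ask('Hello from Bedrock!') do |chunk|
  print chunk.content # streamed through the new Bedrock event-stream decoding
end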