ruby_llm 1.0.1 → 1.1.0rc1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +28 -12
- data/lib/ruby_llm/active_record/acts_as.rb +46 -7
- data/lib/ruby_llm/aliases.json +65 -0
- data/lib/ruby_llm/aliases.rb +56 -0
- data/lib/ruby_llm/chat.rb +10 -9
- data/lib/ruby_llm/configuration.rb +4 -0
- data/lib/ruby_llm/error.rb +15 -4
- data/lib/ruby_llm/models.json +1163 -303
- data/lib/ruby_llm/models.rb +40 -11
- data/lib/ruby_llm/provider.rb +32 -39
- data/lib/ruby_llm/providers/anthropic/capabilities.rb +8 -9
- data/lib/ruby_llm/providers/anthropic/chat.rb +31 -4
- data/lib/ruby_llm/providers/anthropic/streaming.rb +12 -6
- data/lib/ruby_llm/providers/anthropic.rb +4 -0
- data/lib/ruby_llm/providers/bedrock/capabilities.rb +168 -0
- data/lib/ruby_llm/providers/bedrock/chat.rb +108 -0
- data/lib/ruby_llm/providers/bedrock/models.rb +84 -0
- data/lib/ruby_llm/providers/bedrock/signing.rb +831 -0
- data/lib/ruby_llm/providers/bedrock/streaming/base.rb +46 -0
- data/lib/ruby_llm/providers/bedrock/streaming/content_extraction.rb +63 -0
- data/lib/ruby_llm/providers/bedrock/streaming/message_processing.rb +79 -0
- data/lib/ruby_llm/providers/bedrock/streaming/payload_processing.rb +90 -0
- data/lib/ruby_llm/providers/bedrock/streaming/prelude_handling.rb +91 -0
- data/lib/ruby_llm/providers/bedrock/streaming.rb +36 -0
- data/lib/ruby_llm/providers/bedrock.rb +83 -0
- data/lib/ruby_llm/providers/deepseek/chat.rb +17 -0
- data/lib/ruby_llm/providers/deepseek.rb +5 -0
- data/lib/ruby_llm/providers/gemini/capabilities.rb +50 -34
- data/lib/ruby_llm/providers/gemini/chat.rb +8 -15
- data/lib/ruby_llm/providers/gemini/images.rb +5 -10
- data/lib/ruby_llm/providers/gemini/streaming.rb +35 -76
- data/lib/ruby_llm/providers/gemini/tools.rb +12 -12
- data/lib/ruby_llm/providers/gemini.rb +4 -0
- data/lib/ruby_llm/providers/openai/capabilities.rb +146 -206
- data/lib/ruby_llm/providers/openai/streaming.rb +9 -13
- data/lib/ruby_llm/providers/openai.rb +4 -0
- data/lib/ruby_llm/streaming.rb +96 -0
- data/lib/ruby_llm/version.rb +1 -1
- data/lib/ruby_llm.rb +6 -3
- data/lib/tasks/browser_helper.rb +97 -0
- data/lib/tasks/capability_generator.rb +123 -0
- data/lib/tasks/capability_scraper.rb +224 -0
- data/lib/tasks/cli_helper.rb +22 -0
- data/lib/tasks/code_validator.rb +29 -0
- data/lib/tasks/model_updater.rb +66 -0
- data/lib/tasks/models.rake +28 -193
- data/lib/tasks/vcr.rake +13 -30
- metadata +27 -19
- data/.github/workflows/cicd.yml +0 -158
- data/.github/workflows/docs.yml +0 -53
- data/.gitignore +0 -59
- data/.overcommit.yml +0 -26
- data/.rspec +0 -3
- data/.rubocop.yml +0 -10
- data/.yardopts +0 -12
- data/CONTRIBUTING.md +0 -207
- data/Gemfile +0 -33
- data/Rakefile +0 -9
- data/bin/console +0 -17
- data/bin/setup +0 -6
- data/ruby_llm.gemspec +0 -44
data/lib/ruby_llm/models.rb
CHANGED
@@ -26,10 +26,21 @@ module RubyLLM
|
|
26
26
|
File.expand_path('models.json', __dir__)
|
27
27
|
end
|
28
28
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
29
|
+
# Re-fetches model listings from every configured provider and rebuilds the registry.
#
# Models that belong to providers which are not configured are carried over from the
# currently loaded registry, so a partial refresh does not drop them.
#
# @return [Object] the newly built registry instance (also stored in @instance)
def refresh! # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metrics/PerceivedComplexity
  configured = Provider.configured_providers
  # Computed once rather than once per preserved model.
  configured_slugs = configured.map(&:slug)

  # Log provider status
  skipped = Provider.providers.values - configured
  RubyLLM.logger.info "Refreshing models from #{configured_slugs.join(', ')}" if configured.any?
  RubyLLM.logger.info "Skipping #{skipped.map(&:slug).join(', ')} - providers not configured" if skipped.any?

  # Keep current models only for providers we are not refreshing
  preserved = instance.load_models.reject { |m| configured_slugs.include?(m.provider) }

  all = (preserved + configured.flat_map(&:list_models)).sort_by(&:id)
  @instance = new(all)
end
|
34
45
|
|
35
46
|
def method_missing(method, ...)
|
@@ -52,10 +63,10 @@ module RubyLLM
|
|
52
63
|
|
53
64
|
# Load models from the JSON file
|
54
65
|
def load_models
|
55
|
-
data =
|
56
|
-
data.map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
|
57
|
-
rescue
|
58
|
-
[]
|
66
|
+
data = File.exist?(self.class.models_file) ? File.read(self.class.models_file) : '[]'
|
67
|
+
JSON.parse(data).map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
|
68
|
+
rescue JSON::ParserError
|
69
|
+
[]
|
59
70
|
end
|
60
71
|
|
61
72
|
def save_models
|
@@ -73,9 +84,12 @@ module RubyLLM
|
|
73
84
|
end
|
74
85
|
|
75
86
|
# Find a specific model by ID
|
76
|
-
def find(model_id)
|
77
|
-
|
78
|
-
|
87
|
+
# Looks up a model by ID or alias, optionally restricted to one provider.
#
# @param model_id [String] model ID or alias
# @param provider [String, Symbol, nil] provider slug; nil searches every provider
# @return the matching model
# @raise [ModelNotFoundError] when nothing matches
def find(model_id, provider = nil)
  return find_without_provider(model_id) unless provider

  find_with_provider(model_id, provider)
end
|
80
94
|
|
81
95
|
# Filter to only chat models
|
@@ -112,5 +126,20 @@ module RubyLLM
|
|
112
126
|
# Instance-level convenience: delegates to the class-level refresh! and returns its result.
def refresh!
  registry_class = self.class
  registry_class.refresh!
end
|
129
|
+
|
130
|
+
private
|
131
|
+
|
132
|
+
# Finds a model for a specific provider. The literal ID is tried first, then the
# provider-specific alias resolution.
#
# @raise [ModelNotFoundError] when neither the ID nor its alias matches for that provider
def find_with_provider(model_id, provider)
  resolved_id = Aliases.resolve(model_id, provider)
  matching = lambda do |wanted_id|
    all.find { |m| m.id == wanted_id && m.provider == provider.to_s }
  end

  matching.call(model_id) ||
    matching.call(resolved_id) ||
    raise(ModelNotFoundError, "Unknown model: #{model_id} for provider: #{provider}")
end
|
138
|
+
|
139
|
+
# Finds a model across all providers. The literal ID is tried first, then its
# resolved alias.
#
# The alias is resolved at most once, and only when the exact lookup misses.
# Previously it was recomputed for every model inside the search block.
#
# @param model_id [String] model ID or alias
# @return the first model whose ID matches
# @raise [ModelNotFoundError] when neither the ID nor its alias matches
def find_without_provider(model_id)
  exact = all.find { |m| m.id == model_id }
  return exact if exact

  resolved_id = Aliases.resolve(model_id)
  all.find { |m| m.id == resolved_id } ||
    raise(ModelNotFoundError, "Unknown model: #{model_id}")
end
|
115
144
|
end
|
116
145
|
end
|
data/lib/ruby_llm/provider.rb
CHANGED
@@ -7,7 +7,9 @@ module RubyLLM
|
|
7
7
|
module Provider
|
8
8
|
# Common functionality for all LLM providers. Implements the core provider
|
9
9
|
# interface so specific providers only need to implement a few key methods.
|
10
|
-
module Methods
|
10
|
+
module Methods
|
11
|
+
extend Streaming
|
12
|
+
|
11
13
|
def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength
|
12
14
|
normalized_temperature = if capabilities.respond_to?(:normalize_temperature)
|
13
15
|
capabilities.normalize_temperature(temperature, model)
|
@@ -49,24 +51,35 @@ module RubyLLM
|
|
49
51
|
parse_image_response(response)
|
50
52
|
end
|
51
53
|
|
54
|
+
# Whether every configuration key this provider requires is set
# (i.e. missing_configs found nothing).
def configured?
  missing = missing_configs
  missing.empty?
end
|
57
|
+
|
52
58
|
private
|
53
59
|
|
54
|
-
def
|
55
|
-
|
56
|
-
|
60
|
+
# Returns the configuration keys this provider requires that are still unset.
#
# A key counts as missing when its value is nil, or responds to `empty?` and is
# empty (e.g. ""). Values with no notion of emptiness (numbers, credential objects)
# count as present rather than raising NoMethodError.
#
# @return [Array] the subset of configuration_requirements that is missing
def missing_configs
  configuration_requirements.select do |key|
    # public_send: config readers are part of the public config API; don't bypass visibility
    value = RubyLLM.config.public_send(key)
    value.nil? || (value.respond_to?(:empty?) && value.empty?)
  end
end
|
58
66
|
|
59
|
-
def
|
60
|
-
|
67
|
+
def ensure_configured!
|
68
|
+
return if configured?
|
61
69
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
block.call chunk
|
70
|
+
config_block = <<~RUBY
|
71
|
+
RubyLLM.configure do |config|
|
72
|
+
#{missing_configs.map { |key| "config.#{key} = ENV['#{key.to_s.upcase}']" }.join("\n ")}
|
66
73
|
end
|
67
|
-
|
74
|
+
RUBY
|
75
|
+
|
76
|
+
raise ConfigurationError,
|
77
|
+
"#{slug} provider is not configured. Add this to your initialization:\n\n#{config_block}"
|
78
|
+
end
|
68
79
|
|
69
|
-
|
80
|
+
# Performs a non-streaming completion: POSTs the payload to completion_url and
# parses the response.
def sync_response(payload)
  parse_completion_response(post(completion_url, payload))
end
|
71
84
|
|
72
85
|
def post(url, payload)
|
@@ -77,6 +90,8 @@ module RubyLLM
|
|
77
90
|
end
|
78
91
|
|
79
92
|
def connection # rubocop:disable Metrics/MethodLength,Metrics/AbcSize
|
93
|
+
ensure_configured!
|
94
|
+
|
80
95
|
@connection ||= Faraday.new(api_base) do |f| # rubocop:disable Metrics/BlockLength
|
81
96
|
f.options.timeout = RubyLLM.config.request_timeout
|
82
97
|
|
@@ -115,33 +130,6 @@ module RubyLLM
|
|
115
130
|
f.use :llm_errors, provider: self
|
116
131
|
end
|
117
132
|
end
|
118
|
-
|
119
|
-
def to_json_stream(&block) # rubocop:disable Metrics/MethodLength
|
120
|
-
buffer = String.new
|
121
|
-
parser = EventStreamParser::Parser.new
|
122
|
-
|
123
|
-
proc do |chunk, _bytes, env|
|
124
|
-
if env && env.status != 200
|
125
|
-
# Accumulate error chunks
|
126
|
-
buffer << chunk
|
127
|
-
begin
|
128
|
-
error_data = JSON.parse(buffer)
|
129
|
-
error_response = env.merge(body: error_data)
|
130
|
-
ErrorMiddleware.parse_error(provider: self, response: error_response)
|
131
|
-
rescue JSON::ParserError
|
132
|
-
# Keep accumulating if we don't have complete JSON yet
|
133
|
-
RubyLLM.logger.debug "Accumulating error chunk: #{chunk}"
|
134
|
-
end
|
135
|
-
else
|
136
|
-
parser.feed(chunk) do |_type, data|
|
137
|
-
unless data == '[DONE]'
|
138
|
-
parsed_data = JSON.parse(data)
|
139
|
-
block.call(parsed_data)
|
140
|
-
end
|
141
|
-
end
|
142
|
-
end
|
143
|
-
end
|
144
|
-
end
|
145
133
|
end
|
146
134
|
|
147
135
|
def try_parse_json(maybe_json)
|
@@ -181,6 +169,7 @@ module RubyLLM
|
|
181
169
|
class << self
|
182
170
|
# Hook run when a provider module extends Provider. It extends the provider with
# Methods and then Streaming, in that order, so Streaming ends up earlier in the
# lookup chain. Returns base.
def extended(base)
  [Methods, Streaming].reduce(base) { |target, mixin| target.extend(mixin) }
end
|
185
174
|
|
186
175
|
def register(name, provider_module)
|
@@ -195,6 +184,10 @@ module RubyLLM
|
|
195
184
|
# Registry of provider modules, created lazily on first access.
def providers
  return @providers if @providers

  @providers = {}
end
|
187
|
+
|
188
|
+
# Registered provider modules whose configured? check passes, in registration order.
def configured_providers
  providers.values.select(&:configured?)
end
|
198
191
|
end
|
199
192
|
end
|
200
193
|
end
|
@@ -20,8 +20,8 @@ module RubyLLM
|
|
20
20
|
# @return [Integer] the maximum output tokens
|
21
21
|
def determine_max_tokens(model_id)
|
22
22
|
case model_id
|
23
|
-
when /claude-3-
|
24
|
-
else 4_096
|
23
|
+
when /claude-3-7-sonnet/, /claude-3-5/ then 8_192
|
24
|
+
else 4_096
|
25
25
|
end
|
26
26
|
end
|
27
27
|
|
@@ -92,13 +92,12 @@ module RubyLLM
|
|
92
92
|
|
93
93
|
# Pricing information for Anthropic models (per million tokens)
|
94
94
|
PRICES = {
|
95
|
-
claude37_sonnet: { input: 3.0, output: 15.0 },
|
96
|
-
claude35_sonnet: { input: 3.0, output: 15.0 },
|
97
|
-
claude35_haiku: { input: 0.80, output: 4.0 },
|
98
|
-
claude3_opus: { input: 15.0, output: 75.0 },
|
99
|
-
|
100
|
-
|
101
|
-
claude2: { input: 3.0, output: 15.0 } # Default pricing for Claude 2.x models
|
95
|
+
claude37_sonnet: { input: 3.0, output: 15.0 },
|
96
|
+
claude35_sonnet: { input: 3.0, output: 15.0 },
|
97
|
+
claude35_haiku: { input: 0.80, output: 4.0 },
|
98
|
+
claude3_opus: { input: 15.0, output: 75.0 },
|
99
|
+
claude3_haiku: { input: 0.25, output: 1.25 },
|
100
|
+
claude2: { input: 3.0, output: 15.0 }
|
102
101
|
}.freeze
|
103
102
|
|
104
103
|
# Default input price if model not found in PRICES
|
@@ -12,15 +12,42 @@ module RubyLLM
|
|
12
12
|
end
|
13
13
|
|
14
14
|
# Builds the Anthropic Messages API request body.
#
# System-role messages are pulled out into the top-level `system` field; the rest
# become `messages`. Tools and system content are added only when present.
#
# @return [Hash] the request payload
def render_payload(messages, tools:, temperature:, model:, stream: false)
  system_part, conversation = separate_messages(messages)
  system_text = build_system_content(system_part)

  payload = build_base_payload(conversation, temperature, model, stream)
  add_optional_fields(payload, system_content: system_text, tools: tools)
  payload
end
|
22
|
+
|
23
|
+
# Splits messages into [system-role messages, all other messages], preserving order.
def separate_messages(messages)
  messages.partition { |message| message.role == :system }
end
|
26
|
+
|
27
|
+
# Merges system messages into one system prompt, joined by blank lines.
# Logs a warning when more than one system message is supplied.
#
# @return [String] combined content; '' when there are no system messages
def build_system_content(system_messages)
  if system_messages.size > 1
    RubyLLM.logger.warn(
      "Anthropic's Claude implementation only supports a single system message. " \
      'Multiple system messages will be combined into one.'
    )
  end

  contents = system_messages.map { |message| format_message(message)[:content] }
  contents.join("\n\n")
end
|
37
|
+
|
38
|
+
def build_base_payload(chat_messages, temperature, model, stream)
|
15
39
|
{
|
16
40
|
model: model,
|
17
|
-
messages:
|
41
|
+
messages: chat_messages.map { |msg| format_message(msg) },
|
18
42
|
temperature: temperature,
|
19
43
|
stream: stream,
|
20
44
|
max_tokens: RubyLLM.models.find(model).max_tokens
|
21
|
-
}
|
22
|
-
|
23
|
-
|
45
|
+
}
|
46
|
+
end
|
47
|
+
|
48
|
+
# Adds `tools` (converted via function_for) and `system` to the payload, but only
# when each is non-empty. Mutates payload in place.
def add_optional_fields(payload, system_content:, tools:)
  if tools.any?
    payload[:tools] = tools.values.map { |tool| function_for(tool) }
  end
  return if system_content.empty?

  payload[:system] = system_content
end
|
25
52
|
|
26
53
|
def parse_completion_response(response)
|
@@ -11,12 +11,6 @@ module RubyLLM
|
|
11
11
|
completion_url
|
12
12
|
end
|
13
13
|
|
14
|
-
def handle_stream(&block)
|
15
|
-
to_json_stream do |data|
|
16
|
-
block.call(build_chunk(data))
|
17
|
-
end
|
18
|
-
end
|
19
|
-
|
20
14
|
def build_chunk(data)
|
21
15
|
Chunk.new(
|
22
16
|
role: :assistant,
|
@@ -31,6 +25,18 @@ module RubyLLM
|
|
31
25
|
# True for streaming events that carry a partial tool-input JSON fragment
# (content_block_delta events whose delta type is input_json_delta).
def json_delta?(data)
  return false unless data['type'] == 'content_block_delta'

  data.dig('delta', 'type') == 'input_json_delta'
end
|
28
|
+
|
29
|
+
# Interprets a raw streaming event payload as an Anthropic error event.
#
# @param data [String] JSON text of a single stream event
# @return [Array(Integer, String), nil] [status, message] for error events, or nil
#   for non-error events. Status is 529 for "overloaded_error" and 500 for any other
#   error type. The message is nil when the event carries no `error` object, instead
#   of raising NoMethodError.
# @raise [JSON::ParserError] if data is not valid JSON
def parse_streaming_error(data)
  error_data = JSON.parse(data)
  return unless error_data['type'] == 'error'

  message = error_data.dig('error', 'message')
  case error_data.dig('error', 'type')
  when 'overloaded_error' then [529, message]
  else [500, message]
  end
end
|
34
40
|
end
|
35
41
|
end
|
36
42
|
end
|
@@ -0,0 +1,168 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module RubyLLM
  module Providers
    module Bedrock
      # Determines capabilities and pricing for AWS Bedrock models.
      #
      # Every method is a module function, so it can be called directly on the module
      # (e.g. `Capabilities.input_price_for(id)`). Only IDs containing
      # "anthropic.claude" are reported as chat/streaming/vision/function capable.
      module Capabilities
        module_function

        # Returns the context window size for the given model ID.
        #
        # Bedrock names Claude 2.x "anthropic.claude-v2" / "anthropic.claude-v2:1" and
        # Claude Instant "anthropic.claude-instant-v1" (the same IDs MODEL_FAMILIES
        # matches). Claude 2.1 offers 200K tokens; Claude 2.0 and Instant offer 100K.
        # The legacy "claude-2" spelling is still accepted.
        # @param model_id [String] the model identifier
        # @return [Integer] the context window size in tokens
        def context_window_for(model_id)
          case model_id
          when /anthropic\.claude-v2:1/ then 200_000
          when /anthropic\.claude-(?:v2|2|instant)/ then 100_000
          else 200_000
          end
        end

        # Returns the maximum output tokens for the given model ID
        # @param model_id [String] the model identifier
        # @return [Integer] the maximum output tokens
        def max_tokens_for(_model_id)
          4_096
        end

        # Returns the input price per million tokens for the given model ID
        # @param model_id [String] the model identifier
        # @return [Float] the price per million tokens for input
        def input_price_for(model_id)
          PRICES.dig(model_family(model_id), :input) || default_input_price
        end

        # Returns the output price per million tokens for the given model ID
        # @param model_id [String] the model identifier
        # @return [Float] the price per million tokens for output
        def output_price_for(model_id)
          PRICES.dig(model_family(model_id), :output) || default_output_price
        end

        # Determines if the model supports chat capabilities
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports chat
        def supports_chat?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Determines if the model supports streaming capabilities
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports streaming
        def supports_streaming?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Determines if the model supports image input/output
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports images
        def supports_images?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Determines if the model supports vision capabilities
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports vision
        def supports_vision?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Determines if the model supports function calling
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports functions
        def supports_functions?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Determines if the model supports audio input/output
        # @param model_id [String] the model identifier
        # @return [Boolean] always false for Bedrock models handled here
        def supports_audio?(_model_id)
          false
        end

        # Determines if the model supports JSON mode
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports JSON mode
        def supports_json_mode?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Formats the model ID into a human-readable display name,
        # e.g. "anthropic.claude-3-sonnet-20240229-v1:0" -> "Claude 3 Sonnet 20240229 V1:0"
        # @param model_id [String] the model identifier
        # @return [String] the formatted display name
        def format_display_name(model_id)
          model_id.then { |id| humanize(id) }
        end

        # Determines the type of model
        # @param model_id [String] the model identifier
        # @return [String] always 'chat'
        def model_type(_model_id)
          'chat'
        end

        # Determines if the model supports structured output
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports structured output
        def supports_structured_output?(model_id)
          model_id.match?(/anthropic\.claude/)
        end

        # Model family patterns for capability lookup (first match wins)
        MODEL_FAMILIES = {
          /anthropic\.claude-3-opus/ => :claude3_opus,
          /anthropic\.claude-3-sonnet/ => :claude3_sonnet,
          /anthropic\.claude-3-5-sonnet/ => :claude3_sonnet,
          /anthropic\.claude-3-7-sonnet/ => :claude3_sonnet,
          /anthropic\.claude-3-haiku/ => :claude3_haiku,
          /anthropic\.claude-3-5-haiku/ => :claude3_5_haiku,
          /anthropic\.claude-v2/ => :claude2,
          /anthropic\.claude-instant/ => :claude_instant
        }.freeze

        # Determines the model family for pricing and capability lookup
        # @param model_id [String] the model identifier
        # @return [Symbol] the model family identifier, or :other when nothing matches
        def model_family(model_id)
          MODEL_FAMILIES.find { |pattern, _family| model_id.match?(pattern) }&.last || :other
        end

        # Pricing information for Bedrock models (per million tokens)
        PRICES = {
          claude3_opus: { input: 15.0, output: 75.0 },
          claude3_sonnet: { input: 3.0, output: 15.0 },
          claude3_haiku: { input: 0.25, output: 1.25 },
          claude3_5_haiku: { input: 0.8, output: 4.0 },
          claude2: { input: 8.0, output: 24.0 },
          claude_instant: { input: 0.8, output: 2.4 }
        }.freeze

        # Default input price when model-specific pricing is not available
        # @return [Float] the default price per million tokens
        def default_input_price
          0.1
        end

        # Default output price when model-specific pricing is not available
        # @return [Float] the default price per million tokens
        def default_output_price
          0.2
        end

        # Converts a model ID to a human-readable format.
        #
        # A bare `private` here would end module_function mode, leaving humanize as an
        # instance method only, and Capabilities.format_display_name would then raise
        # NoMethodError. Instead humanize is defined as a module function and its
        # module-level copy is made private, so it stays callable from
        # format_display_name but is not part of the public API.
        # @param id [String] the model identifier
        # @return [String] the humanized model name
        def humanize(id)
          id.tr('-', ' ')
            .split('.')
            .last
            .split
            .map(&:capitalize)
            .join(' ')
        end
        private_class_method :humanize
      end
    end
  end
end
|
@@ -0,0 +1,108 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module RubyLLM
  module Providers
    module Bedrock
      # Chat methods for the AWS Bedrock API implementation.
      #
      # Request bodies use the Anthropic Messages shape tagged with
      # anthropic_version 'bedrock-2023-05-31'. Message content is formatted with
      # Anthropic::Media.
      module Chat
        private

        # @return [String] InvokeModel path for the model recorded by #render_payload
        def completion_url
          "model/#{@model_id}/invoke"
        end

        # Builds the request body for a chat completion.
        #
        # Side effect: stores `model` in @model_id so completion_url (and stream URLs)
        # can build the request path.
        # NOTE(review): @model_id lives on whatever object mixes this module in; if that
        # object is shared across threads, concurrent requests could overwrite it — confirm.
        #
        # @param messages [Array] conversation; system-role entries move to `system`
        # @param tools [Hash] tools to expose; values are converted with function_for
        # @param temperature passed through as the `temperature` field
        # @param model [String] Bedrock model ID
        # @param stream accepted for interface parity; not included in the body
        # @return [Hash] the request body
        def render_payload(messages, tools:, temperature:, model:, stream: false) # rubocop:disable Lint/UnusedMethodArgument
          @model_id = model

          system_part, conversation = separate_messages(messages)
          system_text = build_system_content(system_part)

          payload = build_base_payload(conversation, temperature, model)
          add_optional_fields(payload, system_content: system_text, tools: tools)
          payload
        end

        # Splits messages into [system-role messages, all others], preserving order.
        def separate_messages(messages)
          messages.partition { |message| message.role == :system }
        end

        # Joins system message contents with blank lines, warning when there is more
        # than one.
        # @return [String] combined system prompt ('' when there are none)
        def build_system_content(system_messages)
          if system_messages.size > 1
            RubyLLM.logger.warn(
              "Amazon Bedrock's Claude implementation only supports a single system message. " \
              'Multiple system messages will be combined into one.'
            )
          end

          contents = system_messages.map { |message| format_message(message)[:content] }
          contents.join("\n\n")
        end

        # Core request fields. max_tokens comes from the model registry entry.
        def build_base_payload(chat_messages, temperature, model)
          formatted = chat_messages.map { |message| format_message(message) }
          {
            anthropic_version: 'bedrock-2023-05-31',
            messages: formatted,
            temperature: temperature,
            max_tokens: RubyLLM.models.find(model).max_tokens
          }
        end

        # Adds `tools` and `system` only when they are non-empty (mutates payload).
        def add_optional_fields(payload, system_content:, tools:)
          if tools.any?
            payload[:tools] = tools.values.map { |tool| function_for(tool) }
          end
          return if system_content.empty?

          payload[:system] = system_content
        end

        # Dispatches on message kind: tool call, tool result, or plain message.
        def format_message(msg)
          return format_tool_call(msg) if msg.tool_call?
          return format_tool_result(msg) if msg.tool_result?

          format_basic_message(msg)
        end

        # Plain message: mapped role plus Anthropic-formatted content.
        def format_basic_message(msg)
          role = convert_role(msg.role)
          content = Anthropic::Media.format_content(msg.content)
          { role:, content: }
        end

        # Maps RubyLLM roles to API roles; tool output is sent as 'user'.
        def convert_role(role)
          case role
          in :tool | :user then 'user'
          in :system then 'system'
          else 'assistant'
          end
        end

        # Turns a response body into an assistant Message (text + tool calls + usage).
        def parse_completion_response(response)
          body = response.body
          blocks = body['content'] || []
          build_message(body, extract_text_content(blocks), find_tool_use(blocks))
        end

        # Concatenates the text of every 'text' content block, in order.
        def extract_text_content(blocks)
          blocks.each_with_object([]) { |block, texts| texts << block['text'] if block['type'] == 'text' }.join
        end

        # Assembles the assistant Message from response data.
        def build_message(data, content, tool_use)
          Message.new(
            role: :assistant,
            content:,
            tool_calls: parse_tool_calls(tool_use),
            input_tokens: data.dig('usage', 'input_tokens'),
            output_tokens: data.dig('usage', 'output_tokens'),
            model_id: data['model']
          )
        end
      end
    end
  end
end
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module RubyLLM
  module Providers
    module Bedrock
      # Models methods for the AWS Bedrock API implementation.
      #
      # The foundation-model listing is served from https://bedrock.<region>.amazonaws.com,
      # a different base URL from the provider's other requests. list_models therefore
      # swaps the base URL for the duration of the call and restores it afterwards.
      module Models
        # Lists the Claude foundation models available in the configured region.
        #
        # Temporarily points @api_base at the model-listing host and clears the memoized
        # @connection so connection is rebuilt for that host. In `ensure`, even when
        # signing, the request, or parsing raises, @api_base is restored and @connection
        # is cleared. Later requests (e.g. chat completions) then rebuild a connection
        # for their own host instead of reusing the listing host.
        #
        # @return [Array<ModelInfo>] one entry per Claude model in the listing
        def list_models
          previous_api_base = @api_base
          @connection = nil # reset connection since base url is different
          @api_base = "https://bedrock.#{RubyLLM.config.bedrock_region}.amazonaws.com"

          full_models_url = "#{@api_base}/#{models_url}"
          signature = sign_request(full_models_url, method: :get)
          response = connection.get(models_url) do |req|
            req.headers.merge! signature.headers
          end

          parse_list_models_response(response, slug, capabilities)
        ensure
          @api_base = previous_api_base
          @connection = nil
        end

        module_function

        # @return [String] listing path, relative to the base URL
        def models_url
          'foundation-models'
        end

        # Converts a model-listing response into ModelInfo objects, keeping only models
        # whose ID contains "claude".
        # @param response response whose body is the parsed JSON hash
        # @param slug [String] provider slug recorded on each model
        # @param capabilities capability/pricing lookup (e.g. Bedrock::Capabilities)
        # @return [Array<ModelInfo>]
        def parse_list_models_response(response, slug, capabilities)
          data = response.body['modelSummaries'] || []
          data.filter { |model| model['modelId'].include?('claude') }
              .map { |model| create_model_info(model, slug, capabilities) }
        end

        # Builds a ModelInfo from one model summary.
        def create_model_info(model, slug, capabilities)
          model_id = model['modelId']
          ModelInfo.new(
            **base_model_attributes(model_id, model, slug, capabilities),
            **capability_attributes(model_id, capabilities),
            **pricing_attributes(model_id, capabilities),
            metadata: build_metadata(model)
          )
        end

        # Identity attributes. When the summary has no `modelName`, a display name is
        # derived from the ID via capabilities.format_display_name.
        #
        # `capabilities` is an explicit parameter. Previously it was an unresolved bare
        # name here, which raised NameError when called as a module function. It
        # defaults to Bedrock::Capabilities so existing three-argument callers still work.
        def base_model_attributes(model_id, model, slug, capabilities = Capabilities)
          {
            id: model_id,
            created_at: nil,
            display_name: model['modelName'] || capabilities.format_display_name(model_id),
            provider: slug
          }
        end

        # Context window, output limit, type/family and feature flags.
        def capability_attributes(model_id, capabilities)
          {
            context_window: capabilities.context_window_for(model_id),
            max_tokens: capabilities.max_tokens_for(model_id),
            type: capabilities.model_type(model_id),
            family: capabilities.model_family(model_id).to_s,
            supports_vision: capabilities.supports_vision?(model_id),
            supports_functions: capabilities.supports_functions?(model_id),
            supports_json_mode: capabilities.supports_json_mode?(model_id)
          }
        end

        # Per-million-token prices.
        def pricing_attributes(model_id, capabilities)
          {
            input_price_per_million: capabilities.input_price_for(model_id),
            output_price_per_million: capabilities.output_price_for(model_id)
          }
        end

        # Raw listing fields kept for reference, with empty/false fallbacks.
        def build_metadata(model)
          {
            provider_name: model['providerName'],
            customizations_supported: model['customizationsSupported'] || [],
            inference_configurations: model['inferenceTypesSupported'] || [],
            response_streaming_supported: model['responseStreamingSupported'] || false,
            input_modalities: model['inputModalities'] || [],
            output_modalities: model['outputModalities'] || []
          }
        end
      end
    end
  end
end
|