ask-llm-providers 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +21 -0
- data/README.md +70 -0
- data/lib/ask/llm/config.rb +33 -0
- data/lib/ask/llm/http.rb +47 -0
- data/lib/ask/llm/version.rb +7 -0
- data/lib/ask/provider/anthropic.rb +230 -0
- data/lib/ask/provider/bedrock.rb +180 -0
- data/lib/ask/provider/cloudflare.rb +123 -0
- data/lib/ask/provider/google.rb +216 -0
- data/lib/ask/provider/mistral.rb +70 -0
- data/lib/ask/provider/ollama.rb +107 -0
- data/lib/ask/provider/openai.rb +155 -0
- data/lib/ask-llm-providers.rb +30 -0
- metadata +195 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: 677ec905a0f11d7072c4574d03193b85720065778678e940b38252d2adc2f1a0
|
|
4
|
+
data.tar.gz: 9cb65bb51e2ea18e6b7c1b92e0d7fcce64aab4b4d8e9d5493215200928e35eb9
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: cf49fac238b8ce8a9a8df31dcab9a3854d35401eb45699bbb964b01bf50417e544ba5095a9c9b1e40c68b11095f1d3e35bad386b5549a6a6a8793e28ebb0b85b
|
|
7
|
+
data.tar.gz: 99e31531be1bbc2b0930f957630118d77de620e3e668750f04ea966a1bcd0c627339623ba59fae69464c7aa07bfe39be354f5934f145d2b6dcb3d8fb24c73c81
|
data/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Kaka Ruto
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
data/README.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# ask-llm-providers
|
|
2
|
+
|
|
3
|
+
All LLM providers for the ask-rb ecosystem in one gem. Implements `Ask::Provider`
|
|
4
|
+
from `ask-core` with a capabilities-based interface.
|
|
5
|
+
|
|
6
|
+
## Supported Providers
|
|
7
|
+
|
|
8
|
+
| Provider | Auth | Implementation |
|
|
9
|
+
|---|---|---|
|
|
10
|
+
| **OpenAI** + all OpenAI-compatible | `Ask::Auth.resolve(:openai_api_key)` | `Ask::Provider::OpenAI` |
|
|
11
|
+
| **Anthropic** (Claude) | `Ask::Auth.resolve(:anthropic_api_key)` | `Ask::Provider::Anthropic` |
|
|
12
|
+
| **Google Gemini** | `Ask::Auth.resolve(:gemini_api_key)` | `Ask::Provider::Google` |
|
|
13
|
+
| **Vertex AI** | GCP service account | `Ask::Provider::VertexAI` |
|
|
14
|
+
| **Amazon Bedrock** | AWS credentials chain | `Ask::Provider::Bedrock` |
|
|
15
|
+
| **Ollama** (local) | None needed | `Ask::Provider::Ollama` |
|
|
16
|
+
| **Mistral AI** | `Ask::Auth.resolve(:mistral_api_key)` | `Ask::Provider::Mistral` |
|
|
17
|
+
| **Cloudflare Workers AI** | `Ask::Auth.resolve(:cloudflare_api_key)` | `Ask::Provider::Cloudflare` |
|
|
18
|
+
|
|
19
|
+
## Installation
|
|
20
|
+
|
|
21
|
+
```ruby
|
|
22
|
+
gem "ask-llm-providers"
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
## Usage
|
|
26
|
+
|
|
27
|
+
```ruby
|
|
28
|
+
require "ask-llm-providers"
|
|
29
|
+
|
|
30
|
+
# All providers are auto-registered with Ask::Models
|
|
31
|
+
model = Ask::Models.find("gpt-4o")
|
|
32
|
+
# => { provider: :openai, capabilities: [...] }
|
|
33
|
+
|
|
34
|
+
# Use a provider directly
|
|
35
|
+
provider = Ask::Provider::OpenAI.new
|
|
36
|
+
provider.chat(conversation, tools: [], model: "gpt-4o", stream: true) do |chunk|
|
|
37
|
+
print chunk.content
|
|
38
|
+
end
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Capabilities
|
|
42
|
+
|
|
43
|
+
Each provider and model exposes its capabilities:
|
|
44
|
+
|
|
45
|
+
```ruby
|
|
46
|
+
provider = Ask::Provider::OpenAI.new
|
|
47
|
+
provider.capabilities
|
|
48
|
+
# => [:chat, :streaming, :tool_calls, :vision, :thinking,
|
|
49
|
+
# :structured_output, :embed, :transcribe, :paint, :moderate]
|
|
50
|
+
|
|
51
|
+
model = Ask::Models.find("claude-sonnet-4-5")
|
|
52
|
+
model[:capabilities]
|
|
53
|
+
# => [:chat, :streaming, :tool_calls, :vision, :thinking, :prompt_caching]
|
|
54
|
+
|
|
55
|
+
# Unsupported capabilities raise a helpful error
|
|
56
|
+
provider = Ask::Provider::Anthropic.new
|
|
57
|
+
provider.embed(["text"], model: "claude-sonnet-4-5")
|
|
58
|
+
# => Ask::UnsupportedFeature: Anthropic does not support embeddings
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Development
|
|
62
|
+
|
|
63
|
+
```bash
|
|
64
|
+
bin/setup
|
|
65
|
+
bundle exec rake test
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
## License
|
|
69
|
+
|
|
70
|
+
MIT
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ask
  module LLM
    # Lightweight configuration object that wraps a Hash and exposes its keys as
    # reader/writer methods (e.g. `config.api_key`) without requiring ostruct.
    #
    #   config = Ask::LLM::Config.new("api_key" => "sk-...")
    #   config.api_key        # => "sk-..."
    #   config.api_base       # => nil (unset keys read as nil)
    #   config.api_base = "x" # setters add or overwrite keys
    class Config
      # @param hash [Hash, nil] initial settings. String and Symbol keys are both
      #   accepted and stored as Symbols; when both spellings of a key are present
      #   the String-keyed value wins. nil is treated as an empty Hash.
      def initialize(hash = {})
        source = hash || {}
        @hash = source.transform_keys(&:to_sym)
        # transform_keys keeps whichever duplicate came last; re-apply String keys
        # so they deterministically take precedence over Symbol keys.
        source.each { |key, value| @hash[key.to_sym] = value if key.is_a?(String) }
      end

      # `name=` writes a key; any other zero-argument call reads one (nil when
      # unset). Calls with arguments fall through to NoMethodError.
      def method_missing(name, *args, &block)
        key = name.to_s
        if key.end_with?("=")
          @hash[key.chomp("=").to_sym] = args.first
        elsif args.empty?
          @hash[name]
        else
          super
        end
      end

      # Reports true only for keys that are currently set (reader or writer form).
      def respond_to_missing?(name, include_private = false)
        @hash.key?(name.to_s.chomp("=").to_sym) || super
      end

      # @return [Hash{Symbol => Object}] a shallow copy of the settings
      def to_h
        @hash.dup
      end
    end
  end
end
|
data/lib/ask/llm/http.rb
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ask
  module LLM
    # Shared HTTP infrastructure for all providers: building Faraday connections
    # and translating error responses into Ask exceptions.
    module HTTP
      # Build a Faraday connection that JSON-encodes request bodies and decodes
      # JSON responses.
      #
      # @param base_url [String] provider API root
      # @param headers [Hash] default headers sent with every request
      # @param request [Hash] Faraday request options (e.g. timeouts)
      # @return [Faraday::Connection]
      def self.connection(base_url, headers: {}, request: {})
        Faraday.new(url: base_url, headers: headers, request: request) do |f|
          f.request :json
          # Only bodies whose Content-Type ends in "json" are decoded; anything
          # else (HTML error pages, event streams) stays a raw String.
          f.response :json, content_type: /\bjson$/
          f.adapter Faraday.default_adapter
        end
      end

      # Map an error response to the matching Ask exception instance. The error
      # is returned, not raised, so callers can `raise HTTP.map_error(...)`.
      #
      # @param status [Integer] HTTP status code
      # @param body [Hash, String, nil] response body; Hash keys may be Strings or Symbols
      # @param provider [String] provider name used to prefix the message
      # @return [StandardError] an Ask error instance
      def self.map_error(status, body, provider:)
        detail = extract_error_message(body, status) || "HTTP #{status} from #{provider}"
        message = "#{provider}: #{detail}"

        # An error.code of "context_length_exceeded" wins regardless of status.
        if lookup(lookup(body, "error"), "code") == "context_length_exceeded"
          return Ask::ContextLengthExceeded.new(message)
        end

        case status
        when 401, 403 then Ask::Unauthorized.new(message)
        when 429 then Ask::RateLimitError.new(message)
        when 500 then Ask::ServerError.new(message)
        when 503 then Ask::ServiceUnavailable.new(message)
        else Ask::ProviderError.new(message, status_code: status, response_body: body&.to_json)
        end
      end

      # Extract a human-readable message from the error shapes this gem handles:
      #   {"error" => {"message" | "msg" | "error" => "..."}}
      #   {"error" => "..."}
      #   {"message" => "..."}
      # Keys may be Strings or Symbols. Non-blank String bodies are returned as-is;
      # other bodies fall back to their `to_s`.
      #
      # @param body [Hash, String, nil]
      # @param _status [Integer, nil] unused; kept for call-site compatibility
      # @return [String, nil] nil when the body is nil or blank
      def self.extract_error_message(body, _status = nil)
        return nil if body.nil?
        return (body.strip.empty? ? nil : body) if body.is_a?(String)
        return body.to_s unless body.is_a?(Hash)

        error = lookup(body, "error")
        nested =
          case error
          when Hash then %w[message msg error].map { |key| lookup(error, key) }
          when String then [error]
          else []
          end
        (nested + [lookup(body, "message")]).compact.first || body.to_s
      end

      # Read `key` from a Hash using either its String or Symbol spelling.
      # Returns nil for non-Hash containers rather than raising like Hash#dig.
      def self.lookup(container, key)
        return nil unless container.is_a?(Hash)

        container.key?(key) ? container[key] : container[key.to_sym]
      end
      private_class_method :lookup
    end
  end
end
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ask
  module Providers
    # Anthropic Claude provider built on the Messages API (POST v1/messages).
    class Anthropic < Ask::Provider
      # HTTP status Anthropic documents for each error `type`. Used for errors that
      # arrive as SSE `error` events after the stream already returned HTTP 200.
      STREAM_ERROR_STATUS = {
        "invalid_request_error" => 400,
        "authentication_error" => 401,
        "permission_error" => 403,
        "not_found_error" => 404,
        "rate_limit_error" => 429,
        "api_error" => 500,
        "overloaded_error" => 529
      }.freeze

      # Page size requested from GET v1/models (the API defaults to 20, max 1000).
      MODELS_PAGE_LIMIT = 1000

      # @param config [Hash, Object] settings Hash (:api_key or :anthropic_api_key,
      #   :api_base; String or Symbol keys) or an already-built config object
      def initialize(config = {})
        config = normalize_config(config)
        super(config)
        @http = build_http
      end

      # @return [String] API root; the public endpoint unless overridden
      def api_base
        @config.api_base || "https://api.anthropic.com"
      end

      # @return [Hash] headers sent with every request
      def headers
        {
          "x-api-key" => @config.api_key,
          "anthropic-version" => "2023-06-01",
          "Content-Type" => "application/json"
        }
      end

      # Send a chat request.
      #
      # @param messages [Ask::Conversation, Array<Hash>] role/content Hashes; "system"
      #   messages are lifted into the top-level `system` field
      # @param model [String]
      # @param tools [Array, nil] tool objects or Hashes with name/description/parameters
      # @param temperature [Numeric, nil] omitted from the request when nil
      # @param stream [Boolean, nil] stream over SSE, yielding chunks to the block
      # @param schema [Object, nil] accepted but not currently forwarded to the API
      # @param params [Hash] extra request fields; :max_tokens defaults to 4096
      # @return [Ask::Message] when not streaming
      # @return [Ask::Stream] when streaming
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        msgs = messages.is_a?(Ask::Conversation) ? messages.to_a : messages
        payload = build_chat_payload(msgs, model, tools, temperature, stream, schema, **params)
        stream ? chat_stream(payload, model, &block) : chat_nonstream(payload, model)
      end

      # @raise [Ask::UnsupportedFeature] always; this provider has no embeddings support
      def embed(_texts, model: nil)
        raise Ask::UnsupportedFeature, "Anthropic does not support embeddings"
      end

      # List every model visible to the API key, following the `has_more` /
      # `last_id` pagination of GET v1/models.
      #
      # @return [Array<Ask::ModelInfo>] empty when the first page fails; the models
      #   collected so far when a later page fails
      def list_models
        models = []
        after_id = nil

        loop do
          query = { limit: MODELS_PAGE_LIMIT }
          query[:after_id] = after_id if after_id
          response = @http.get("v1/models", query)
          break unless response.success? && response.body.is_a?(Hash)

          body = response.body
          (body["data"] || []).each do |m|
            models << Ask::ModelInfo.new(id: m["id"], provider: self.class.slug)
          end
          # Stop when there are no more pages, or the cursor fails to advance.
          break unless body["has_more"] && body["last_id"] && body["last_id"] != after_id

          after_id = body["last_id"]
        end

        models
      end

      # @return [String, nil] error.message (or error.type) from an Anthropic error body
      def parse_error(response)
        body = response.respond_to?(:body) ? response.body : nil
        error = body["error"] if body.is_a?(Hash)
        return nil unless error.is_a?(Hash)

        error["message"] || error["type"]
      end

      class << self
        def capabilities
          { chat: true, streaming: true, tool_calls: true, vision: true, thinking: true, prompt_caching: true, structured_output: true }
        end
        def configuration_options; %i[api_key api_base]; end
        def configuration_requirements; %i[api_key]; end
        def slug; "anthropic"; end
      end

      private

      # Turn a settings Hash into an Ask::LLM::Config; other objects pass through.
      def normalize_config(config)
        config ||= {}
        return config unless config.is_a?(Hash)

        Ask::LLM::Config.new(
          api_key: config[:api_key] || config["api_key"] || config[:anthropic_api_key] || config["anthropic_api_key"],
          api_base: config[:api_base] || config["api_base"]
        )
      end

      def build_http
        LLM::HTTP.connection(api_base, headers: headers, request: { open_timeout: 30, timeout: 120 })
      end

      # Build the Messages API request body.
      # NOTE(review): `schema` is not sent even though capabilities advertise
      # structured_output — TODO confirm the intended mechanism (e.g. a forced tool).
      def build_chat_payload(messages, model, tools, temperature, stream, _schema, **params)
        system_msgs, chat_msgs = messages.partition { |m| (m[:role] || m["role"]).to_s == "system" }

        payload = {
          model: model,
          messages: chat_msgs.map { |m| format_message(m) },
          max_tokens: params.delete(:max_tokens) || 4096,
          stream: stream || false
        }

        system_content = format_system_content(system_msgs)
        payload[:system] = system_content if system_content
        payload[:tools] = format_tools(tools) if tools&.any?
        payload[:temperature] = temperature if temperature
        payload.merge(params)
      end

      # Join the content of all system messages with newlines; nil when there is none.
      def format_system_content(messages)
        return nil if messages.empty?

        texts = messages.map { |m| m[:content] || m["content"] }.compact
        return nil if texts.empty?

        texts.join("\n")
      end

      # Convert one chat message Hash into the Messages API shape. Tool calls become
      # `tool_use` content blocks (the API has no top-level `tool_calls` field) and
      # tool results become a user turn holding a `tool_result` block.
      def format_message(msg)
        role = (msg[:role] || msg["role"]).to_s
        content = msg[:content] || msg["content"]
        tool_calls = msg[:tool_calls] || msg["tool_calls"]
        tool_call_id = msg[:tool_call_id] || msg["tool_call_id"]

        return format_tool_use_message(role, content, tool_calls) if tool_calls && !tool_calls.empty?
        return format_tool_result_message(tool_call_id, content) if tool_call_id

        { role: role, content: content }.compact
      end

      # Turn whose content is optional text followed by one tool_use block per call.
      def format_tool_use_message(role, content, tool_calls)
        blocks = []
        case content
        when String
          blocks << { type: "text", text: content } unless content.empty?
        when Array
          blocks.concat(content)
        end

        tool_calls.each do |call|
          function = call[:function] || call["function"]
          function = {} unless function.is_a?(Hash)
          blocks << {
            type: "tool_use",
            id: call[:id] || call["id"],
            name: function[:name] || function["name"] || call[:name] || call["name"],
            input: parse_json(function[:arguments] || function["arguments"] || call[:arguments] || call["arguments"] || "{}")
          }
        end

        { role: role, content: blocks }
      end

      def format_tool_result_message(tool_call_id, content)
        {
          role: "user",
          content: [{ type: "tool_result", tool_use_id: tool_call_id, content: content || "" }]
        }
      end

      # Decode tool-call arguments; Hashes pass through, unparsable input becomes {}.
      def parse_json(value)
        return value if value.is_a?(Hash)

        JSON.parse(value.to_s)
      rescue JSON::ParserError
        {}
      end

      def format_tools(tools)
        tools.map do |t|
          {
            name: tool_attribute(t, :name),
            description: tool_attribute(t, :description),
            input_schema: tool_attribute(t, :parameters) || { type: "object", properties: {} }
          }.compact
        end
      end

      # Read a tool attribute from an object reader or a String/Symbol-keyed Hash.
      def tool_attribute(tool, attr)
        return tool.public_send(attr) if tool.respond_to?(attr)

        tool[attr] || tool[attr.to_s]
      end

      def chat_nonstream(payload, model)
        response = @http.post("v1/messages") { |r| r.body = payload }
        raise LLM::HTTP.map_error(response.status, error_payload(response.body), provider: "Anthropic") unless response.success?

        parse_response(response.body, model)
      end

      # Convert a Messages API response body into an assistant Ask::Message.
      def parse_response(body, model)
        content_blocks = body["content"] || []
        text_content = content_blocks.select { |c| c["type"] == "text" }.map { |c| c["text"] }.join
        tool_blocks = content_blocks.select { |c| c["type"] == "tool_use" }
        thinking_blocks = content_blocks.select { |c| %w[thinking redacted_thinking].include?(c["type"]) }
        thinking = thinking_blocks.map { |b| b["thinking"] || b["text"] }.compact.join("\n")
        usage = body["usage"] || {}

        tool_calls = tool_blocks.map do |tb|
          { id: tb["id"], type: "function", name: tb["name"], arguments: JSON.generate(tb["input"] || {}) }
        end

        metadata = {
          model: body["model"] || model,
          stop_reason: body["stop_reason"],
          stop_sequence: body["stop_sequence"],
          input_tokens: usage["input_tokens"],
          output_tokens: usage["output_tokens"],
          # nil (and so dropped by compact) when the reply had no thinking blocks
          thinking: thinking.empty? ? nil : thinking,
          raw: body
        }.compact

        Ask::Message.new(
          role: :assistant,
          content: text_content.empty? ? nil : text_content,
          tool_calls: tool_calls.empty? ? nil : tool_calls,
          metadata: metadata
        )
      end

      # Stream a response over SSE, adding text chunks to an Ask::Stream and
      # yielding them to the block as they arrive.
      #
      # @return [Ask::Stream] the finished stream
      # @raise mapped Ask error for non-2xx responses and mid-stream `error` events
      def chat_stream(payload, _model, &block)
        stream = Ask::Stream.new
        # Binary buffers: bytes are accumulated across network chunks and only
        # decoded once a full line is available.
        state = { buffer: String.new, error_body: String.new, usage: {}, error: nil }

        response = @http.post("v1/messages") do |req|
          req.body = payload.merge(stream: true)
          req.options.on_data = proc do |data, _bytes, env|
            if env&.status && !(200..299).cover?(env.status)
              # Error bodies are plain JSON rather than SSE; keep them for map_error.
              state[:error_body] << data.b
            else
              process_anthropic_chunk(data, stream, state, &block)
            end
          end
        end

        unless response.success?
          body = state[:error_body].empty? ? response.body : state[:error_body]
          raise LLM::HTTP.map_error(response.status, error_payload(body), provider: "Anthropic")
        end

        # Flush a final event that arrived without a trailing newline.
        process_anthropic_chunk("\n", stream, state, &block) unless state[:buffer].empty?
        raise_stream_error(state[:error]) if state[:error]

        stream.finish!
        stream
      end

      # Append raw SSE bytes to the buffer and handle every complete line in it,
      # so events split across network chunks are not lost.
      def process_anthropic_chunk(raw, stream, state, &block)
        buffer = state[:buffer]
        buffer << raw.b
        while (newline = buffer.index("\n"))
          line = buffer.slice!(0..newline).force_encoding(Encoding::UTF_8).strip
          handle_sse_line(line, stream, state, &block)
        end
      end

      # Handle one SSE line. Only `data:` lines carry payloads; `event:` lines and
      # `:` comments are ignored because every payload repeats its own `type`.
      def handle_sse_line(line, stream, state, &block)
        return unless line.start_with?("data:")

        event = parse_sse_data(line)
        return unless event.is_a?(Hash)

        case event["type"]
        when "message_start"
          merge_usage(state, event.dig("message", "usage"))
        when "content_block_delta"
          # Only text deltas carry user-visible content.
          text = event.dig("delta", "text")
          emit_chunk(stream, Ask::Chunk.new(content: text), &block) if text
        when "message_delta"
          merge_usage(state, event["usage"])
        when "message_stop"
          emit_chunk(stream, Ask::Chunk.new(finish_reason: "stop", usage: state[:usage].dup), &block)
        when "error"
          state[:error] = event
        end
      end

      def parse_sse_data(line)
        JSON.parse(line.delete_prefix("data:").lstrip)
      rescue JSON::ParserError
        nil
      end

      # Record token counts from message_start / message_delta. message_delta's
      # output_tokens is cumulative, so later values overwrite earlier ones.
      def merge_usage(state, usage)
        return unless usage.is_a?(Hash)

        %w[input_tokens output_tokens].each do |key|
          state[:usage][key.to_sym] = usage[key] if usage[key]
        end
      end

      def emit_chunk(stream, chunk)
        stream.add(chunk)
        yield chunk if block_given?
      end

      # Raise the Ask error for an SSE `error` event, using the documented status
      # for its error type (500 when unknown).
      def raise_stream_error(event)
        error = event["error"]
        type = error.is_a?(Hash) ? error["type"].to_s : ""
        raise LLM::HTTP.map_error(STREAM_ERROR_STATUS.fetch(type, 500), event, provider: "Anthropic")
      end

      # Coerce an error payload (decoded Hash or raw JSON text) into a Hash, or nil
      # when it cannot be decoded.
      def error_payload(body)
        return body if body.is_a?(Hash)

        parsed = JSON.parse(body.to_s.dup.force_encoding(Encoding::UTF_8))
        parsed.is_a?(Hash) ? parsed : nil
      rescue JSON::ParserError, EncodingError
        nil
      end
    end
  end
end
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ask
  module Providers
    # Amazon Bedrock provider using the Converse / ConverseStream APIs of the AWS
    # SDK for Ruby. Explicit keys from the config (which default to the AWS_* env
    # vars) are used when present; otherwise the SDK's default credentials chain
    # applies.
    class Bedrock < Ask::Provider
      # Generation options the Converse API expects inside `inference_config`
      # rather than at the top level of the request.
      INFERENCE_PARAMS = %i[max_tokens top_p stop_sequences].freeze

      def initialize(config = {})
        config = normalize_config(config)
        super(config)
      end

      # Bedrock has no configurable base URL; the AWS region plays that role.
      def api_base
        @config.region || "us-east-1"
      end

      # Send a chat request via Converse (or ConverseStream when `stream` is set).
      #
      # @param messages [Ask::Conversation, Array<Hash>] role/content Hashes
      # @param model [String] Bedrock model id
      # @param tools [Array, nil] tool objects or Hashes with name/description/parameters
      # @param temperature [Numeric, nil] omitted from the request when nil
      # @param schema [Object, nil] accepted for interface parity; Converse has no
      #   JSON-mode switch, so it is not sent
      # @param params [Hash] :max_tokens/:top_p/:stop_sequences go to inference_config;
      #   anything else is merged into the request as-is
      # @return [Ask::Message] when not streaming
      # @return [Ask::Stream] when streaming
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        msgs = messages.is_a?(Ask::Conversation) ? messages.to_a : messages
        payload = build_converse_payload(msgs, model, tools, temperature, schema, **params)
        stream ? chat_stream(payload, model, &block) : chat_nonstream(payload, model)
      end

      def embed(_texts, model: nil)
        raise Ask::UnsupportedFeature, "Bedrock does not support embeddings via Converse API"
      end

      def list_models
        # The Bedrock runtime client has no list-models operation — rely on the model catalog.
        []
      end

      # @return [String, nil] the "message" field of a Hash error body
      def parse_error(response)
        body = response.respond_to?(:body) ? response.body : nil
        body["message"] || body[:message] if body.is_a?(Hash)
      end

      class << self
        def capabilities
          { chat: true, streaming: true, tool_calls: true, vision: true }
        end
        def configuration_options; %i[region access_key_id secret_access_key session_token]; end
        def configuration_requirements; %i[]; end
        def slug; "bedrock"; end
      end

      private

      # Turn a settings Hash into an Ask::LLM::Config, falling back to the AWS_*
      # environment variables; other objects pass through unchanged.
      def normalize_config(config)
        config ||= {}
        return config unless config.is_a?(Hash)

        Ask::LLM::Config.new(
          region: config[:region] || config["region"] || ENV["AWS_REGION"] || "us-east-1",
          access_key_id: config[:access_key_id] || config["access_key_id"] || ENV["AWS_ACCESS_KEY_ID"],
          secret_access_key: config[:secret_access_key] || config["secret_access_key"] || ENV["AWS_SECRET_ACCESS_KEY"],
          session_token: config[:session_token] || config["session_token"] || ENV["AWS_SESSION_TOKEN"]
        )
      end

      # Build Converse request params. The Ruby SDK validates params and expects
      # snake_case member names (model_id, inference_config, tool_config, ...).
      def build_converse_payload(messages, model, tools, temperature, _schema, **params)
        system_msgs, chat_msgs = messages.partition { |m| (m[:role] || m["role"]).to_s == "system" }

        inference = { temperature: temperature }
        INFERENCE_PARAMS.each { |key| inference[key] = params.delete(key) if params.key?(key) }
        inference.compact!

        payload = {
          model_id: model,
          messages: merge_consecutive_roles(chat_msgs.map { |m| format_bedrock_msg(m) })
        }
        payload[:inference_config] = inference unless inference.empty?

        system_texts = system_msgs.map { |m| m[:content] || m["content"] }.compact
        payload[:system] = system_texts.map { |s| { text: s } } if system_texts.any?
        payload[:tool_config] = { tools: format_bedrock_tools(tools) } if tools&.any?
        payload.merge(params)
      end

      # Convert one chat message Hash into a Converse message.
      def format_bedrock_msg(msg)
        role = (msg[:role] || msg["role"]).to_s
        content = msg[:content] || msg["content"]
        tool_call_id = msg[:tool_call_id] || msg["tool_call_id"]

        # Tool results go back on a user turn, carrying the content only once.
        if tool_call_id
          return {
            role: "user",
            content: [{ tool_result: { tool_use_id: tool_call_id, content: [{ text: content || "" }] } }]
          }
        end

        parts = []
        # Skip empty text blocks (e.g. an assistant turn that only calls tools).
        parts << { text: content } unless content.nil? || content == ""

        (msg[:tool_calls] || msg["tool_calls"] || []).each do |tc|
          function = tc[:function] || tc["function"]
          function = {} unless function.is_a?(Hash)
          parts << {
            tool_use: {
              tool_use_id: tc[:id] || tc["id"],
              name: function[:name] || function["name"] || tc[:name] || tc["name"],
              input: parse_json(function[:arguments] || function["arguments"] || tc[:arguments] || tc["arguments"] || "{}")
            }
          }
        end

        { role: role == "assistant" ? "assistant" : "user", content: parts }
      end

      # Fold adjacent messages with the same role into one turn (e.g. several
      # parallel tool results) so user and assistant turns alternate.
      def merge_consecutive_roles(formatted)
        formatted.each_with_object([]) do |msg, merged|
          if merged.last && merged.last[:role] == msg[:role]
            merged.last[:content].concat(msg[:content])
          else
            merged << msg
          end
        end
      end

      def format_bedrock_tools(tools)
        tools.map do |t|
          spec = {
            name: tool_attribute(t, :name),
            description: tool_attribute(t, :description),
            input_schema: { json: tool_attribute(t, :parameters) || { type: "object", properties: {} } }
          }.compact
          { tool_spec: spec }
        end
      end

      # Read a tool attribute from an object reader or a String/Symbol-keyed Hash.
      def tool_attribute(tool, attr)
        return tool.public_send(attr) if tool.respond_to?(attr)

        tool[attr] || tool[attr.to_s]
      end

      # Decode tool-call arguments; Hashes pass through, unparsable input becomes {}.
      def parse_json(value)
        return value if value.is_a?(Hash)

        JSON.parse(value.to_s)
      rescue JSON::ParserError
        {}
      end

      # Lazily load the SDK and build one client per provider instance.
      def bedrock_client
        @bedrock_client ||= begin
          require "aws-sdk-bedrockruntime"
          options = { region: api_base }
          if @config.access_key_id && @config.secret_access_key
            options[:credentials] = Aws::Credentials.new(
              @config.access_key_id, @config.secret_access_key, @config.session_token
            )
          end
          Aws::BedrockRuntime::Client.new(**options)
        end
      end

      def chat_nonstream(payload, model)
        client = bedrock_client
        begin
          resp = client.converse(payload)
        rescue Aws::Errors::ServiceError => e
          raise_bedrock_error(e)
        end
        parse_bedrock_response(resp, model)
      end

      # Convert a Converse response into an assistant Ask::Message.
      def parse_bedrock_response(resp, model)
        message = resp.output&.message
        return Ask::Message.new(role: :assistant, content: nil) unless message

        parts = message.content || []
        text = parts.map(&:text).compact.join
        tool_calls = parts.select(&:tool_use).map do |part|
          use = part.tool_use
          { id: use.tool_use_id, type: "function", name: use.name, arguments: JSON.generate(use.input.to_h) }
        end

        usage = resp.usage
        Ask::Message.new(
          role: :assistant,
          content: text.empty? ? nil : text,
          tool_calls: tool_calls.empty? ? nil : tool_calls,
          metadata: {
            model: model,
            stop_reason: resp.stop_reason,
            input_tokens: usage&.input_tokens,
            output_tokens: usage&.output_tokens,
            raw: resp.to_h
          }
        )
      end

      # Stream via ConverseStream, yielding text chunks as events arrive.
      # Usage comes from the metadata event, which follows message_stop, so the
      # finish chunk is emitted once the stream has completed.
      def chat_stream(payload, _model, &block)
        client = bedrock_client
        stream = Ask::Stream.new
        usage = nil
        stopped = false
        stream_error = nil

        begin
          client.converse_stream(payload) do |events|
            events.on_content_block_delta_event do |event|
              delta = event.delta
              text = delta.respond_to?(:text) ? delta.text : nil
              emit_chunk(stream, Ask::Chunk.new(content: text), &block) unless text.nil? || text.empty?
            end
            events.on_message_stop_event { |_event| stopped = true }
            events.on_metadata_event { |event| usage = event.usage }
            events.on_error_event { |error| stream_error = error }
          end
        rescue Aws::Errors::ServiceError => e
          raise_bedrock_error(e)
        end
        raise_bedrock_error(stream_error) if stream_error

        if stopped
          totals = { input_tokens: usage&.input_tokens, output_tokens: usage&.output_tokens }
          emit_chunk(stream, Ask::Chunk.new(finish_reason: "stop", usage: totals), &block)
        end

        stream.finish!
        stream
      end

      def emit_chunk(stream, chunk)
        stream.add(chunk)
        yield chunk if block_given?
      end

      # Re-raise an AWS SDK or event-stream error as the matching Ask error. The
      # HTTP status comes from the request context when available, else 500.
      def raise_bedrock_error(error)
        status = error.context&.http_response&.status_code if error.respond_to?(:context)
        status = 500 unless status.is_a?(Integer) && status.positive?
        raise LLM::HTTP.map_error(status, { "message" => error.message }, provider: "Bedrock")
      end
    end
  end
end
|