ask-llm-providers 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +21 -0
- data/README.md +70 -0
- data/lib/ask/llm/config.rb +33 -0
- data/lib/ask/llm/http.rb +47 -0
- data/lib/ask/llm/version.rb +7 -0
- data/lib/ask/provider/anthropic.rb +230 -0
- data/lib/ask/provider/bedrock.rb +180 -0
- data/lib/ask/provider/cloudflare.rb +123 -0
- data/lib/ask/provider/google.rb +216 -0
- data/lib/ask/provider/mistral.rb +70 -0
- data/lib/ask/provider/ollama.rb +107 -0
- data/lib/ask/provider/openai.rb +155 -0
- data/lib/ask-llm-providers.rb +30 -0
- metadata +195 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"

module Ask
  module Providers
    # Cloudflare Workers AI provider. Supports both direct Workers AI and AI Gateway.
    #
    # With a +gateway_id+ configured, chat requests are posted to "chat/completions"
    # under the gateway base URL and parsed as OpenAI-style responses. Without one,
    # they are posted to "run/<model>" under the account base URL and the reply text
    # is read from +result.response+.
    class Cloudflare < Ask::Provider
      # @param config [Hash, Object] a Hash of options (see +configuration_options+),
      #   or an already-built config object which is used as-is
      def initialize(config = {})
        config = normalize_config(config)
        super(config)
        @http = build_http
      end

      # Base URL for every request; depends on whether a gateway is configured.
      #
      # NOTE(review): Cloudflare's public docs appear to list the native run endpoint
      # as ".../ai/run/<model>" and gateway URLs with a provider segment after the
      # gateway id — confirm that these bases combined with the endpoints used in
      # #chat are actually accepted by the service.
      #
      # @return [String]
      def api_base
        if @config.gateway_id
          "https://gateway.ai.cloudflare.com/v1/#{@config.account_id}/#{@config.gateway_id}"
        else
          "https://api.cloudflare.com/client/v4/accounts/#{@config.account_id}/ai/v1"
        end
      end

      # @return [Hash{String => String}] default request headers
      def headers
        # Both values are always strings, so there is nothing to compact away.
        { "Content-Type" => "application/json", "Authorization" => "Bearer #{@config.api_key}" }
      end

      # Sends a chat request.
      #
      # +tools+ and +schema+ are accepted for interface parity but are not used by
      # this provider. Streaming is only performed in gateway mode; without a
      # gateway a +stream+ request is served by the non-streaming path and the
      # block is never called.
      #
      # @param messages [Ask::Conversation, Array<Hash>] messages with :role/:content
      #   (symbol or string keys)
      # @param model [String] model identifier
      # @param temperature [Numeric, nil] added to the payload when given
      # @param stream [Boolean, nil] request a streamed response (gateway mode only)
      # @param params [Hash] extra payload entries; they override generated keys
      # @yield [Ask::Chunk] each streamed chunk (gateway streaming only)
      # @return [Ask::Message, Ask::Stream]
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        msgs = messages.is_a?(Ask::Conversation) ? messages.to_a : messages
        formatted = msgs.map { |m| { role: (m[:role] || m["role"]).to_s, content: m[:content] || m["content"] } }
        gateway = @config.gateway_id

        endpoint = gateway ? "chat/completions" : "run/#{model}"
        payload = if gateway
                    { model: model, messages: formatted, stream: stream || false }
                  else
                    { messages: formatted }
                  end
        payload[:temperature] = temperature if temperature
        # Hash#merge returns a new hash; the result must be kept or params are lost.
        payload = payload.merge(params)

        if stream && gateway
          chat_stream_gateway(endpoint, payload, model, &block)
        else
          chat_nonstream(endpoint, payload, model)
        end
      end

      # @return [Array] always empty; no model listing request is made
      def list_models
        # Workers AI lists models differently — rely on model catalog
        []
      end

      # Extracts a human-readable error message from a response.
      #
      # @return [String, nil]
      def parse_error(response)
        body = response.body rescue nil
        body&.dig("errors", 0, "message") || body&.dig("error", "message")
      end

      class << self
        def capabilities
          { chat: true, streaming: true, vision: true }
        end

        def configuration_options; %i[api_key account_id gateway_id]; end
        def configuration_requirements; %i[api_key account_id]; end
        def slug; "cloudflare"; end
      end

      private

      # Builds an Ask::LLM::Config from a Hash; any other object passes through.
      def normalize_config(config)
        return config unless config.is_a?(Hash)

        Ask::LLM::Config.new(
          api_key: config[:api_key] || config["api_key"] || config[:cloudflare_api_key],
          account_id: config[:account_id] || config["account_id"] || config[:cf_account_id],
          gateway_id: config[:gateway_id] || config["gateway_id"] || config[:cf_gateway_id]
        )
      end

      def build_http
        LLM::HTTP.connection(api_base, headers: headers, request: { open_timeout: 30, timeout: 120 })
      end

      # Posts the payload and converts the reply into an Ask::Message.
      #
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def chat_nonstream(endpoint, payload, model)
        response = @http.post(endpoint) { |r| r.body = payload }
        raise LLM::HTTP.map_error(response.status, response.body, provider: "Cloudflare") unless response.success?

        body = response.body
        return parse_openai_response(body, model) if @config.gateway_id

        result = body["result"] || {}
        Ask::Message.new(role: :assistant, content: result["response"], metadata: { model: model, raw: body })
      end

      # Converts an OpenAI-style completion body into an Ask::Message.
      def parse_openai_response(body, model)
        choice = body.dig("choices", 0)
        return Ask::Message.new(role: :assistant, content: nil) unless choice

        msg = choice["message"] || {} # tolerate a choice without a message object
        Ask::Message.new(
          role: :assistant,
          content: msg["content"],
          metadata: { model: model, finish_reason: choice["finish_reason"], raw: body }
        )
      end

      # Streams a gateway chat completion, feeding parsed chunks into an Ask::Stream.
      #
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      # @return [Ask::Stream] the finished stream
      def chat_stream_gateway(endpoint, payload, model, &block)
        stream = Ask::Stream.new
        buffer = String.new # holds the trailing partial line between callbacks
        response = @http.post(endpoint) do |req|
          req.body = payload.merge(stream: true)
          req.options.on_data = proc { |data, _bytes, _env| process_stream_chunk(take_complete_lines(buffer, data), stream, model, &block) }
        end
        raise LLM::HTTP.map_error(response.status, error_body(response.body), provider: "Cloudflare") unless response.success?

        # A final event without a trailing newline is still sitting in the buffer.
        process_stream_chunk(buffer, stream, model, &block) unless buffer.empty?
        stream.finish!
        stream
      end

      # Appends +data+ to +buffer+ and removes/returns everything up to the last
      # newline. Data slices are not guaranteed to end on a line boundary, so the
      # incomplete tail stays buffered until the rest of the line arrives.
      def take_complete_lines(buffer, data)
        buffer << data.to_s.b
        cut = buffer.rindex("\n")
        return "" unless cut

        buffer.slice!(0..cut)
      end

      # Normalizes an error response body to a Hash without raising: an already
      # parsed Hash is returned as-is, an empty body becomes {}, and text that is
      # not valid JSON is wrapped in the "error"/"message" shape #parse_error reads.
      def error_body(body)
        return body if body.is_a?(Hash)

        text = body.to_s
        return {} if text.strip.empty?

        JSON.parse(text)
      rescue JSON::ParserError
        { "error" => { "message" => text } }
      end

      # Parses "data: ..." event lines and adds one Ask::Chunk per event to +stream+.
      # Lines without the "data: " prefix, the "[DONE]" marker and unparseable
      # payloads are skipped.
      def process_stream_chunk(raw, stream, model)
        raw.each_line do |line|
          line = line.strip
          next unless line.start_with?("data: ")

          data = line[6..]
          next if data == "[DONE]"

          parsed = begin
            JSON.parse(data)
          rescue JSON::ParserError
            next
          end
          next unless parsed.is_a?(Hash)

          delta = parsed.dig("choices", 0, "delta") || {}
          chunk = Ask::Chunk.new(content: delta["content"])
          stream.add(chunk)
          yield chunk if block_given?
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
require "securerandom"

module Ask
  module Providers
    # Google Gemini API provider. Also supports Vertex AI via GCP service account auth.
    #
    # Authentication: when an +api_key+ is configured it is sent as the "key"
    # query parameter; otherwise +access_token+ or +vertex_token+ is sent as a
    # Bearer Authorization header.
    class Google < Ask::Provider
      # @param config [Hash, Object] a Hash of options (see +configuration_options+),
      #   or an already-built config object which is used as-is
      def initialize(config = {})
        config = normalize_config(config)
        super(config)
        @http = build_http
        @project_id = config.project_id
      end

      # @return [String] configured base URL or the public v1beta endpoint
      def api_base
        @config.api_base || "https://generativelanguage.googleapis.com/v1beta"
      end

      # @return [Hash{String => String}] default request headers
      def headers
        h = { "Content-Type" => "application/json" }
        # Gemini uses query param auth by default, so no header is needed with a key.
        return h if @config.api_key

        token = @config.access_token || @config.vertex_token
        h["Authorization"] = "Bearer #{token}" if token
        h
      end

      # Sends a chat request.
      #
      # @param messages [Ask::Conversation, Array<Hash>] messages (symbol or string keys)
      # @param model [String, #id] model identifier or object responding to +id+
      # @param tools [Array, nil] tool definitions (objects or Hashes)
      # @param temperature [Numeric, nil]
      # @param stream [Boolean, nil] when truthy the streaming endpoint is used
      # @param schema [Hash, nil] JSON schema for structured output
      # @param params [Hash] extra top-level payload entries; they override generated keys
      # @yield [Ask::Chunk] each streamed chunk
      # @return [Ask::Message, Ask::Stream]
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        msgs = messages.is_a?(Ask::Conversation) ? messages.to_a : messages
        payload = build_chat_payload(msgs, model, tools, temperature, stream, schema, **params)
        if stream
          chat_stream(chat_path(model, stream: true), payload, model, &block)
        else
          chat_nonstream(chat_path(model), payload, model)
        end
      end

      # Embeds one or more texts in a single batch request.
      #
      # @return [Ask::Result] success wrapping an Array with one vector per text
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def embed(texts, model:)
        requests = Array(texts).map { |t| { model: "models/#{model}", content: { parts: [{ text: t }] } } }
        response = @http.post("models/#{model}:batchEmbedContents") { |r| r.body = { requests: requests } }
        raise LLM::HTTP.map_error(response.status, response.body, provider: "Google") unless response.success?

        embeddings = response.body["embeddings"] || []
        Ask::Result.success(embeddings.map { |e| e["values"] })
      end

      # @return [Array<Ask::ModelInfo>] empty when the request fails
      def list_models
        response = @http.get("models") { |r| r.params["key"] = @config.api_key if @config.api_key }
        return [] unless response.success?

        (response.body["models"] || []).map { |m| Ask::ModelInfo.new(id: m["name"].sub("models/", ""), provider: slug) }
      end

      # @return [String, nil] the message found under "error" in the response body
      def parse_error(response)
        body = response.body rescue nil
        body&.dig("error", "message")
      end

      class << self
        def capabilities
          { chat: true, streaming: true, tool_calls: true, vision: true, structured_output: true, embed: true, file_upload: true }
        end

        def configuration_options; %i[api_key access_token vertex_token project_id api_base]; end
        def configuration_requirements; %i[api_key]; end
        def slug; "gemini"; end
      end

      private

      # Builds an Ask::LLM::Config from a Hash; any other object passes through.
      def normalize_config(config)
        return config unless config.is_a?(Hash)

        key = config[:api_key] || config["api_key"] || config[:gemini_api_key]
        Ask::LLM::Config.new(
          api_key: key,
          access_token: config[:access_token] || config["access_token"],
          vertex_token: config[:vertex_token] || config["vertex_token"],
          project_id: config[:project_id] || config["project_id"],
          api_base: config[:api_base] || config["api_base"]
        )
      end

      def build_http
        LLM::HTTP.connection(api_base, headers: headers, request: { open_timeout: 30, timeout: 120 })
      end

      # Relative request path for a chat call. Streaming must target the
      # streamGenerateContent action; generateContent returns a single JSON
      # document and never produces "data:" event lines.
      def chat_path(model, stream: false)
        model_id = model.respond_to?(:id) ? model.id : model.to_s
        action = stream ? "streamGenerateContent" : "generateContent"
        "models/#{model_id}:#{action}"
      end

      # Assembles the request payload. +model+ and +stream+ are accepted for
      # signature parity but are not part of the payload.
      def build_chat_payload(messages, model, tools, temperature, stream, schema, **params)
        payload = { contents: format_contents(messages) }
        system = format_system(messages)
        payload[:systemInstruction] = system if system # omit rather than send null

        payload[:tools] = [{ functionDeclarations: tools.map { |t| format_tool(t) } }] if tools&.any?

        generation = {}
        if schema
          generation[:response_mime_type] = "application/json"
          generation[:response_schema] = schema
        end
        generation[:temperature] = temperature if temperature
        payload[:generationConfig] = generation

        payload.merge(params)
      end

      # Every non-system message, converted to a contents entry.
      def format_contents(messages)
        messages.reject { |m| (m[:role] || m["role"]).to_s == "system" }.map { |m| format_content(m) }
      end

      # Collects the content of all system messages into one systemInstruction.
      #
      # @return [Hash, nil] nil when there is no system text
      def format_system(messages)
        texts = messages
                .select { |m| (m[:role] || m["role"]).to_s == "system" }
                .map { |m| m[:content] || m["content"] }
                .compact
        return nil if texts.empty?

        { parts: texts.map { |t| { text: t } } }
      end

      # Converts one message into a contents entry.
      #
      # A message carrying +tool_call_id+ is a tool result: it is sent as a single
      # functionResponse part under the "user" role, because the API only accepts
      # "user" and "model" roles and the result text already lives inside the
      # functionResponse. "assistant" is mapped to "model".
      def format_content(msg)
        role = (msg[:role] || msg["role"]).to_s
        content = msg[:content] || msg["content"]
        tool_calls = msg[:tool_calls] || msg["tool_calls"]
        tool_result = msg[:tool_call_id] || msg["tool_call_id"]

        parts = []
        if tool_result
          parts << {
            functionResponse: {
              name: msg[:name] || msg["name"] || "function",
              response: { content: content || "" }
            }
          }
        elsif content
          parts << { text: content }
        end

        Array(tool_calls).each do |tc|
          parts << {
            functionCall: {
              name: tc.dig(:function, :name) || tc.dig("function", "name") || tc[:name],
              args: parse_json(tc.dig(:function, :arguments) || tc.dig("function", "arguments") || tc[:arguments] || "{}")
            }
          }
        end

        google_role = if tool_result || role == "tool"
                        "user"
                      elsif role == "assistant"
                        "model"
                      else
                        role
                      end
        { role: google_role, parts: parts }
      end

      # Converts a tool (object with accessors, or Hash) into a function declaration.
      def format_tool(t)
        {
          name: t.respond_to?(:name) ? t.name : t[:name],
          description: t.respond_to?(:description) ? t.description : t[:description],
          parameters: t.respond_to?(:parameters) ? t.parameters : (t[:parameters] || {})
        }
      end

      # Decodes tool-call arguments. A Hash is returned unchanged; anything else is
      # parsed as JSON text, yielding {} when it cannot be parsed.
      def parse_json(str)
        return str if str.is_a?(Hash)

        JSON.parse(str.to_s)
      rescue JSON::ParserError
        {}
      end

      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def chat_nonstream(path, payload, model)
        response = @http.post(path) do |req|
          req.body = payload
          req.params["key"] = @config.api_key if @config.api_key
        end
        raise LLM::HTTP.map_error(response.status, response.body, provider: "Google") unless response.success?

        parse_response(response.body, model)
      end

      # Converts a generateContent body into an Ask::Message. Text parts of the
      # first candidate are joined; functionCall parts become tool calls with a
      # locally generated id and JSON-encoded arguments.
      def parse_response(body, model)
        candidate = body.dig("candidates", 0)
        return Ask::Message.new(role: :assistant, content: nil) unless candidate

        parts = candidate.dig("content", "parts")
        content = parts&.map { |part| part["text"] }&.compact&.join
        tool_calls = (parts || []).select { |part| part["functionCall"] }.map do |part|
          call = part["functionCall"]
          { id: SecureRandom.hex(8), type: "function", name: call["name"], arguments: JSON.generate(call["args"] || {}) }
        end

        usage = body["usageMetadata"] || {}
        Ask::Message.new(
          role: :assistant,
          content: content,
          tool_calls: tool_calls.empty? ? nil : tool_calls,
          metadata: {
            model: model,
            finish_reason: candidate["finishReason"],
            input_tokens: usage["promptTokenCount"],
            output_tokens: usage["candidatesTokenCount"],
            raw: body
          }
        )
      end

      # Streams a chat response, feeding parsed chunks into an Ask::Stream.
      # +path+ is expected to be the streaming path from #chat_path; "alt=sse" is
      # added so the reply arrives as "data:" event lines.
      #
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      # @return [Ask::Stream] the finished stream
      def chat_stream(path, payload, model, &block)
        stream = Ask::Stream.new
        buffer = String.new # holds the trailing partial line between callbacks
        response = @http.post(path) do |req|
          req.body = payload
          req.params["key"] = @config.api_key if @config.api_key
          req.params["alt"] = "sse"
          req.options.on_data = proc { |data, _bytes, _env| process_google_chunk(take_complete_lines(buffer, data), stream, model, &block) }
        end
        raise LLM::HTTP.map_error(response.status, error_body(response.body), provider: "Google") unless response.success?

        # A final event without a trailing newline is still sitting in the buffer.
        process_google_chunk(buffer, stream, model, &block) unless buffer.empty?
        stream.finish!
        stream
      end

      # Appends +data+ to +buffer+ and removes/returns everything up to the last
      # newline. Data slices are not guaranteed to end on a line boundary, so the
      # incomplete tail stays buffered until the rest of the line arrives.
      def take_complete_lines(buffer, data)
        buffer << data.to_s.b
        cut = buffer.rindex("\n")
        return "" unless cut

        buffer.slice!(0..cut)
      end

      # Normalizes an error response body to a Hash without raising: an already
      # parsed Hash is returned as-is, an empty body becomes {}, and text that is
      # not valid JSON is wrapped in the "error"/"message" shape #parse_error reads.
      def error_body(body)
        return body if body.is_a?(Hash)

        text = body.to_s
        return {} if text.strip.empty?

        JSON.parse(text)
      rescue JSON::ParserError
        { "error" => { "message" => text } }
      end

      # Parses "data: ..." event lines and adds one Ask::Chunk per event that has
      # at least one part. The chunk content is the joined text of the parts, or
      # nil when no part carries text.
      def process_google_chunk(raw, stream, model)
        raw.each_line do |line|
          line = line.strip
          next unless line.start_with?("data: ")

          data = line[6..]
          next if data == "[DONE]"

          parsed = begin
            JSON.parse(data)
          rescue JSON::ParserError
            next
          end
          next unless parsed.is_a?(Hash)

          parts = parsed.dig("candidates", 0, "content", "parts")
          next if parts.nil? || parts.empty?

          texts = parts.map { |part| part["text"] }.compact
          chunk = Ask::Chunk.new(content: texts.empty? ? nil : texts.join)
          stream.add(chunk)
          yield chunk if block_given?
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ask
  module Providers
    # Mistral AI provider. Uses OpenAI-compatible wire format.
    #
    # Chat is delegated to Providers::OpenAI pointed at the Mistral base URL;
    # embeddings and model listing are issued through this provider's own
    # connection.
    class Mistral < Ask::Provider
      # @param config [Hash, Object] a Hash of options (see +configuration_options+),
      #   or an already-built config object which is used as-is
      def initialize(config = {})
        config = normalize_config(config)
        super(config)
        @http = build_http
      end

      # @return [String] configured base URL or the public Mistral endpoint
      def api_base
        @config.api_base || "https://api.mistral.ai/v1"
      end

      # @return [Hash{String => String}] default request headers; Authorization is
      #   only included when an API key is configured
      def headers
        h = { "Content-Type" => "application/json" }
        # Interpolating a missing key would send a bare "Bearer " header.
        h["Authorization"] = "Bearer #{@config.api_key}" if @config.api_key
        h
      end

      # Sends a chat request through the OpenAI-compatible delegate. All arguments
      # and the block are forwarded unchanged.
      #
      # @return whatever Providers::OpenAI#chat returns
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        # Reuse OpenAI provider's logic since Mistral is OpenAI-compatible
        openai_delegate.chat(messages, model: model, tools: tools, temperature: temperature, stream: stream, schema: schema, **params, &block)
      end

      # Embeds one or more texts.
      #
      # @return [Ask::Result] success wrapping a single vector when exactly one
      #   embedding came back, otherwise an Array of vectors
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def embed(texts, model:)
        inputs = Array(texts)
        response = @http.post("embeddings") { |r| r.body = { model: model, input: inputs } }
        raise LLM::HTTP.map_error(response.status, response.body, provider: "Mistral") unless response.success?

        vectors = (response.body["data"] || []).map { |d| d["embedding"] }
        Ask::Result.success(vectors.one? ? vectors.first : vectors)
      end

      # @return [Array<Ask::ModelInfo>] empty when the request fails
      def list_models
        response = @http.get("models")
        return [] unless response.success?

        (response.body["data"] || []).map { |m| Ask::ModelInfo.new(id: m["id"], provider: slug) }
      end

      # @return [String, nil] the error message, or the error type as a fallback
      def parse_error(response)
        body = response.body rescue nil
        body&.dig("error", "message") || body&.dig("error", "type")
      end

      class << self
        def capabilities
          { chat: true, streaming: true, tool_calls: true, structured_output: true, embed: true }
        end

        def configuration_options; %i[api_key api_base]; end
        def configuration_requirements; %i[api_key]; end
        def slug; "mistral"; end
      end

      private

      # Builds an Ask::LLM::Config from a Hash; any other object passes through.
      def normalize_config(config)
        return config unless config.is_a?(Hash)

        Ask::LLM::Config.new(
          api_key: config[:api_key] || config["api_key"] || config[:mistral_api_key],
          api_base: config[:api_base] || config["api_base"]
        )
      end

      def build_http
        LLM::HTTP.connection(api_base, headers: headers, request: { open_timeout: 30, timeout: 120 })
      end

      # OpenAI provider configured with this provider's key and base URL. Built
      # once and reused so each chat call does not construct a new provider and
      # connection.
      def openai_delegate
        @openai_delegate ||= Providers::OpenAI.new(api_key: @config.api_key, base_url: api_base)
      end
    end
  end
end
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
require "securerandom"

module Ask
  module Providers
    # Ollama provider for local LLM inference.
    # Connects to a local Ollama server (default: http://localhost:11434).
    class Ollama < Ask::Provider
      # @param config [Hash, Object] a Hash of options (see +configuration_options+),
      #   or an already-built config object which is used as-is
      def initialize(config = {})
        config = normalize_config(config)
        super(config)
        @http = build_http
      end

      # @return [String] configured base URL or the local default
      def api_base
        @config.api_base || "http://localhost:11434"
      end

      # @return [Hash{String => String}] default request headers (no auth)
      def headers
        { "Content-Type" => "application/json" }
      end

      # Sends a chat request to "api/chat".
      #
      # +schema+ is accepted for interface parity but is not used by this provider.
      # Only :role and :content of each message are forwarded.
      #
      # @param messages [Ask::Conversation, Array<Hash>] messages (symbol or string keys)
      # @param model [String] model name
      # @param tools [Array, nil] tool definitions (objects or Hashes)
      # @param temperature [Numeric, nil] placed under +options+ when given
      # @param stream [Boolean, nil] request a streamed response
      # @param params [Hash] extra top-level payload entries; they override generated keys
      # @yield [Ask::Chunk] each streamed chunk
      # @return [Ask::Message, Ask::Stream]
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        msgs = messages.is_a?(Ask::Conversation) ? messages.to_a : messages
        formatted = msgs.map { |m| { role: (m[:role] || m["role"]).to_s, content: m[:content] || m["content"] } }

        payload = { model: model, messages: formatted, stream: stream || false, options: {} }
        payload[:options][:temperature] = temperature if temperature
        payload[:tools] = tools.map { |t| format_tool(t) } if tools&.any?
        # Hash#merge returns a new hash; the result must be kept or params are lost.
        payload = payload.merge(params)

        if stream
          chat_stream(payload, model, &block)
        else
          chat_nonstream(payload, model)
        end
      end

      # Embeds one or more texts. "api/embeddings" takes a single prompt, so one
      # request is issued per text instead of silently embedding only the first.
      #
      # @return [Ask::Result] success wrapping a single vector for one input,
      #   otherwise an Array with one vector per input
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def embed(texts, model:)
        vectors = Array(texts).map do |text|
          response = @http.post("api/embeddings") { |r| r.body = { model: model, prompt: text } }
          raise LLM::HTTP.map_error(response.status, response.body, provider: "Ollama") unless response.success?

          response.body["embedding"]
        end
        Ask::Result.success(vectors.size == 1 ? vectors.first : vectors)
      end

      # @return [Array<Ask::ModelInfo>] empty when the request fails
      def list_models
        response = @http.get("api/tags")
        return [] unless response.success?

        (response.body["models"] || []).map { |m| Ask::ModelInfo.new(id: m["name"], provider: slug) }
      end

      class << self
        def capabilities
          { chat: true, streaming: true, tool_calls: true, embed: true, local: true }
        end

        def configuration_options; %i[api_base]; end
        def configuration_requirements; %i[]; end
        def slug; "ollama"; end
        def local?; true; end
        def assume_models_exist?; true; end
      end

      private

      # Builds an Ask::LLM::Config from a Hash; any other object passes through.
      def normalize_config(config)
        return config unless config.is_a?(Hash)

        Ask::LLM::Config.new(api_base: config[:api_base] || config["api_base"])
      end

      def build_http
        LLM::HTTP.connection(api_base, headers: headers, request: { open_timeout: 5, timeout: 600 })
      end

      # Converts a tool (object with accessors, or Hash) into a function entry.
      def format_tool(t)
        {
          type: "function",
          function: {
            name: t.respond_to?(:name) ? t.name : t[:name],
            description: t.respond_to?(:description) ? t.description : t[:description],
            parameters: t.respond_to?(:parameters) ? t.parameters : (t[:parameters] || {})
          }
        }
      end

      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def chat_nonstream(payload, model)
        response = @http.post("api/chat") { |r| r.body = payload }
        raise LLM::HTTP.map_error(response.status, response.body, provider: "Ollama") unless response.success?

        body = response.body
        msg = body["message"] || {}
        Ask::Message.new(
          role: :assistant,
          content: msg["content"],
          tool_calls: parse_tool_calls(msg["tool_calls"]),
          metadata: { model: body["model"] || model, done: body["done"], total_duration: body["total_duration"], raw: body }
        )
      end

      # Converts tool calls from a reply message into Hashes with :id, :type,
      # :name and :arguments. Arguments that are not already a String are
      # JSON-encoded; a missing id is replaced by a locally generated one.
      #
      # @return [Array<Hash>, nil] nil when the message carries no tool calls
      def parse_tool_calls(calls)
        return nil unless calls.is_a?(Array) && calls.any?

        calls.map do |call|
          fn = call["function"] || {}
          args = fn["arguments"]
          args = JSON.generate(args || {}) unless args.is_a?(String)
          { id: call["id"] || SecureRandom.hex(8), type: "function", name: fn["name"], arguments: args }
        end
      end

      # Streams a chat response, feeding parsed chunks into an Ask::Stream.
      #
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      # @return [Ask::Stream] the finished stream
      def chat_stream(payload, model, &block)
        stream = Ask::Stream.new
        buffer = String.new # holds the trailing partial line between callbacks
        response = @http.post("api/chat") do |req|
          req.body = payload.merge(stream: true)
          req.options.on_data = proc { |data, _bytes, _env| process_ollama_chunk(take_complete_lines(buffer, data), stream, model, &block) }
        end
        raise LLM::HTTP.map_error(response.status, response.body, provider: "Ollama") unless response.success?

        # A final line without a trailing newline is still sitting in the buffer.
        process_ollama_chunk(buffer, stream, model, &block) unless buffer.empty?
        stream.finish!
        stream
      end

      # Appends +data+ to +buffer+ and removes/returns everything up to the last
      # newline. Data slices are not guaranteed to end on a line boundary, so the
      # incomplete tail stays buffered until the rest of the line arrives.
      def take_complete_lines(buffer, data)
        buffer << data.to_s.b
        cut = buffer.rindex("\n")
        return "" unless cut

        buffer.slice!(0..cut)
      end

      # Parses newline-delimited JSON objects. Each object yields a content chunk;
      # an object with a truthy "done" additionally yields a finish chunk carrying
      # the total duration. Lines that are not valid JSON objects are skipped.
      def process_ollama_chunk(raw, stream, model)
        raw.each_line do |line|
          parsed = begin
            JSON.parse(line)
          rescue JSON::ParserError
            next
          end
          next unless parsed.is_a?(Hash)

          msg = parsed["message"] || {}
          chunk = Ask::Chunk.new(content: msg["content"])
          stream.add(chunk)
          yield chunk if block_given?

          next unless parsed["done"]

          finish = Ask::Chunk.new(finish_reason: "stop", usage: { total_duration: parsed["total_duration"] })
          stream.add(finish)
          yield finish if block_given?
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"

module Ask
  module Providers
    # OpenAI API provider. Also handles all OpenAI-compatible providers
    # (OpenRouter, DeepSeek, Azure, XAI, Perplexity, GPUStack, etc.) via
    # +base_url+ override.
    class OpenAI < Ask::Provider
      # @param config [Hash, Object] a Hash of options (see +configuration_options+),
      #   or an already-built config object which is used as-is
      def initialize(config = {})
        config = normalize_config(config)
        super(config)
        @http = build_http
      end

      # @return [String] configured +base_url+ or the public OpenAI endpoint
      def api_base
        @config.base_url || "https://api.openai.com/v1"
      end

      # @return [Hash{String => String}] default request headers; the auth,
      #   organization and project headers are only added when configured
      def headers
        key = @config.api_key || @config.openai_api_key
        h = { "Content-Type" => "application/json" }
        h["Authorization"] = "Bearer #{key}" if key
        h["OpenAI-Organization"] = @config.organization_id if @config.organization_id
        h["OpenAI-Project"] = @config.project_id if @config.project_id
        h
      end

      # Sends a chat completion request.
      #
      # @param messages [Ask::Conversation, Array<Hash>] messages (symbol or string keys)
      # @param model [String] model identifier
      # @param tools [Array, nil] tool definitions (objects or Hashes)
      # @param temperature [Numeric, nil]
      # @param stream [Boolean, nil] request a streamed response
      # @param schema [Hash, nil] JSON schema for a strict json_schema response format
      # @param params [Hash] extra top-level payload entries; they override generated keys
      # @yield [Ask::Chunk] each streamed chunk
      # @return [Ask::Message, Ask::Stream]
      def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
        msgs = messages.is_a?(Ask::Conversation) ? messages.to_a : messages
        payload = build_chat_payload(msgs, model, tools, temperature, stream, schema, **params)
        return chat_stream(payload, model, &block) if stream

        chat_nonstream(payload, model)
      end

      # Embeds one or more texts.
      #
      # @return [Ask::Result] success wrapping a single vector when exactly one
      #   embedding came back, otherwise an Array of vectors
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def embed(texts, model:)
        inputs = Array(texts)
        response = @http.post("embeddings") { |r| r.body = { model: model, input: inputs } }
        raise LLM::HTTP.map_error(response.status, response.body, provider: "OpenAI") unless response.success?

        vectors = (response.body["data"] || []).map { |d| d["embedding"] }
        Ask::Result.success(vectors.one? ? vectors.first : vectors)
      end

      # @return [Array<Ask::ModelInfo>] empty when the request fails
      def list_models
        response = @http.get("models")
        return [] unless response.success?

        (response.body["data"] || []).map do |m|
          Ask::ModelInfo.new(id: m["id"], provider: slug, metadata: { owned_by: m["owned_by"] })
        end
      end

      # @return [String, nil] the error message, or the error code as a fallback
      def parse_error(response)
        body = response.body rescue nil
        body&.dig("error", "message") || body&.dig("error", "code")
      end

      class << self
        def slug; "openai"; end

        def capabilities
          { chat: true, streaming: true, tool_calls: true, vision: true, thinking: true, structured_output: true, embed: true, transcribe: true, paint: true, moderate: true }
        end

        def configuration_options; %i[api_key base_url organization_id project_id]; end
        def configuration_requirements; %i[api_key]; end

        # @return [Boolean] true when the config exposes a non-empty +api_key+ or
        #   +openai_api_key+
        def configured?(config)
          (config.respond_to?(:api_key) && !config.api_key.to_s.empty?) ||
            (config.respond_to?(:openai_api_key) && !config.openai_api_key.to_s.empty?)
        end
      end

      private

      # Builds an Ask::LLM::Config from a Hash; any other object passes through.
      def normalize_config(config)
        return config unless config.is_a?(Hash)

        Ask::LLM::Config.new(
          api_key: config[:api_key] || config["api_key"] || config[:openai_api_key],
          base_url: config[:base_url] || config["base_url"],
          organization_id: config[:organization_id] || config["organization_id"],
          project_id: config[:project_id] || config["project_id"]
        )
      end

      def build_http
        LLM::HTTP.connection(api_base, headers: headers, request: { open_timeout: 30, timeout: 120 })
      end

      # Assembles the chat/completions payload; +params+ override generated keys.
      def build_chat_payload(messages, model, tools, temperature, stream, schema, **params)
        payload = { model: model, messages: format_messages(messages), stream: stream || false }
        payload[:temperature] = temperature if temperature
        payload[:tools] = format_tools(tools) if tools&.any?
        if schema
          payload[:response_format] = { type: "json_schema", json_schema: { name: "response", schema: schema, strict: true } }
        end
        payload.merge(params)
      end

      # Converts messages to the wire format. The role defaults to "user"; entries
      # whose value is nil (e.g. a nil content) are removed from each message.
      def format_messages(messages)
        messages.map do |msg|
          role = msg[:role] || msg["role"] || :user
          formatted = { role: role.to_s, content: msg[:content] || msg["content"] }

          calls = msg[:tool_calls] || msg["tool_calls"]
          formatted[:tool_calls] = calls.map { |call| format_tool_call(call) } if calls

          call_id = msg[:tool_call_id] || msg["tool_call_id"]
          formatted[:tool_call_id] = call_id if call_id

          formatted.compact
        end
      end

      # Converts one stored tool call (nested "function" form or flat :name /
      # :arguments form) into the wire format.
      def format_tool_call(call)
        {
          id: call[:id] || call["id"],
          type: "function",
          function: {
            name: call.dig(:function, :name) || call.dig("function", "name") || call[:name],
            arguments: call.dig(:function, :arguments) || call.dig("function", "arguments") || call[:arguments]
          }
        }
      end

      # Converts tools (objects with accessors, or Hashes) into function entries.
      def format_tools(tools)
        tools.map do |t|
          {
            type: "function",
            function: {
              name: t.respond_to?(:name) ? t.name : t[:name],
              description: t.respond_to?(:description) ? t.description : t[:description],
              parameters: t.respond_to?(:parameters) ? t.parameters : t[:parameters]
            }
          }
        end
      end

      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      def chat_nonstream(payload, model)
        response = @http.post("chat/completions") { |r| r.body = payload }
        raise LLM::HTTP.map_error(response.status, response.body, provider: "OpenAI") unless response.success?

        parse_response(response.body, model)
      end

      # Converts a completion body into an Ask::Message built from the first choice.
      def parse_response(body, model)
        choice = body.dig("choices", 0)
        return Ask::Message.new(role: :assistant, content: nil) unless choice

        msg = choice["message"] || {} # tolerate a choice without a message object
        usage = body["usage"] || {}
        Ask::Message.new(
          role: :assistant,
          content: msg["content"],
          tool_calls: parse_tool_calls(msg["tool_calls"]),
          metadata: {
            model: body["model"] || model,
            finish_reason: choice["finish_reason"],
            input_tokens: usage["prompt_tokens"],
            output_tokens: usage["completion_tokens"],
            raw: body
          }
        )
      end

      # @return [Array<Hash>, nil] flattened tool calls, nil when there are none
      def parse_tool_calls(calls)
        return nil unless calls&.any?

        calls.map do |tc|
          { id: tc["id"], type: "function", name: tc.dig("function", "name"), arguments: tc.dig("function", "arguments") }
        end
      end

      # Streams a chat completion, feeding parsed chunks into an Ask::Stream.
      #
      # @raise the error produced by LLM::HTTP.map_error on a non-success status
      # @return [Ask::Stream] the finished stream
      def chat_stream(payload, model, &block)
        stream = Ask::Stream.new
        buffer = String.new # holds the trailing partial line between callbacks
        response = @http.post("chat/completions") do |req|
          req.body = payload.merge(stream: true)
          req.options.on_data = proc { |data, _bytes, _env| process_chunk(take_complete_lines(buffer, data), stream, model, &block) }
        end
        raise LLM::HTTP.map_error(response.status, error_body(response.body), provider: "OpenAI") unless response.success?

        # A final event without a trailing newline is still sitting in the buffer.
        process_chunk(buffer, stream, model, &block) unless buffer.empty?
        stream.finish!
        stream
      end

      # Appends +data+ to +buffer+ and removes/returns everything up to the last
      # newline. Data slices are not guaranteed to end on a line boundary, so the
      # incomplete tail stays buffered until the rest of the line arrives.
      def take_complete_lines(buffer, data)
        buffer << data.to_s.b
        cut = buffer.rindex("\n")
        return "" unless cut

        buffer.slice!(0..cut)
      end

      # Normalizes an error response body to a Hash without raising: an already
      # parsed Hash is returned as-is, an empty body becomes {}, and text that is
      # not valid JSON is wrapped in the "error"/"message" shape #parse_error reads.
      def error_body(body)
        return body if body.is_a?(Hash)

        text = body.to_s
        return {} if text.strip.empty?

        JSON.parse(text)
      rescue JSON::ParserError
        { "error" => { "message" => text } }
      end

      # Parses "data: ..." event lines and adds one Ask::Chunk per event. Blank
      # lines, ":" comment lines, the "[DONE]" marker and unparseable payloads are
      # skipped. An event without choices is kept only when it carries usage
      # figures, so a usage-only final event is not lost.
      def process_chunk(raw, stream, model)
        raw.each_line do |line|
          line = line.strip
          next unless line.start_with?("data: ")

          data = line[6..]
          next if data == "[DONE]"

          parsed = begin
            JSON.parse(data)
          rescue JSON::ParserError
            next
          end
          next unless parsed.is_a?(Hash)

          choice = parsed.dig("choices", 0)
          chunk = if choice
                    delta = choice["delta"] || {}
                    Ask::Chunk.new(
                      content: delta["content"],
                      tool_calls: parse_stream_tool_calls(delta["tool_calls"]),
                      finish_reason: choice["finish_reason"],
                      usage: parsed["usage"]
                    )
                  elsif parsed["usage"]
                    Ask::Chunk.new(usage: parsed["usage"])
                  end
          next unless chunk

          stream.add(chunk)
          yield chunk if block_given?
        end
      end

      # @return [Array<Hash>, nil] partial tool calls from a delta, each with its
      #   +index+; nil when the delta has none
      def parse_stream_tool_calls(calls)
        return nil unless calls&.any?

        calls.map do |tc|
          { id: tc["id"], name: tc.dig("function", "name"), arguments: tc.dig("function", "arguments"), index: tc["index"] }
        end
      end
    end
  end
end
|