ask-core 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +21 -0
- data/README.md +284 -0
- data/lib/ask/conversation.rb +235 -0
- data/lib/ask/errors.rb +66 -0
- data/lib/ask/models.rb +438 -0
- data/lib/ask/provider.rb +232 -0
- data/lib/ask/result.rb +109 -0
- data/lib/ask/stream.rb +123 -0
- data/lib/ask/tool_def.rb +114 -0
- data/lib/ask/version.rb +5 -0
- data/lib/ask.rb +23 -0
- metadata +95 -0
data/lib/ask/models.rb
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "net/http"
|
|
5
|
+
require "time"
|
|
6
|
+
|
|
7
|
+
module Ask
|
|
8
|
+
# Model metadata: capabilities, pricing, context window, modalities.
#
# Immutable value object representing a single model entry. Besides the
# object itself, the collection attributes (capabilities, modalities,
# pricing, metadata) are copied and frozen on construction, so neither the
# stored entry nor the caller's original hashes can affect each other later.
class ModelInfo
  # @return [String] model identifier (e.g. "gpt-4o")
  attr_reader :id

  # @return [String] human-readable model name (defaults to {#id})
  attr_reader :name

  # @return [String] provider slug (e.g. "openai")
  attr_reader :provider

  # @return [String, nil] model family (e.g. "gpt", "claude")
  attr_reader :family

  # @return [Array<String>] capability strings (frozen)
  attr_reader :capabilities

  # @return [Integer, nil] maximum context window in tokens
  attr_reader :context_window

  # @return [Integer, nil] maximum output tokens
  attr_reader :max_output_tokens

  # @return [Hash] input/output modalities, keyed by :input / :output (frozen)
  attr_reader :modalities

  # @return [Hash] pricing information (frozen)
  attr_reader :pricing

  # @return [Date, nil] knowledge cutoff date
  attr_reader :knowledge_cutoff

  # @return [String, nil] creation/publication date
  attr_reader :created_at

  # @return [Hash] additional metadata (frozen)
  attr_reader :metadata

  # @param id [String] model identifier
  # @param name [String, nil] display name; falls back to +id+
  # @param provider [String, Symbol] provider slug, stored as a String
  # @param family [String, nil] model family
  # @param capabilities [Array<String, Symbol>, String, Symbol, nil] stored as Strings
  # @param context_window [Integer, nil] maximum context window in tokens
  # @param max_output_tokens [Integer, nil] maximum output tokens
  # @param modalities [Hash, nil] +nil+ is treated as an empty Hash
  # @param pricing [Hash, nil] +nil+ is treated as an empty Hash
  # @param knowledge_cutoff [Date, nil] knowledge cutoff date
  # @param created_at [String, nil] creation/publication date
  # @param metadata [Hash, nil] +nil+ is treated as an empty Hash;
  #   a +:type+ entry overrides {#type} inference
  def initialize(id:, name: nil, provider:, family: nil, capabilities: [],
                 context_window: nil, max_output_tokens: nil,
                 modalities: {}, pricing: {}, knowledge_cutoff: nil,
                 created_at: nil, metadata: {})
    @id = id
    @name = name || id
    @provider = provider.to_s
    @family = family
    @capabilities = Array(capabilities).map(&:to_s).freeze
    @context_window = context_window
    @max_output_tokens = max_output_tokens
    # Copy before freezing so the caller's own hashes are never frozen or aliased.
    @modalities = frozen_copy(modalities || {})
    @pricing = frozen_copy(pricing || {})
    @knowledge_cutoff = knowledge_cutoff
    @created_at = created_at
    @metadata = frozen_copy(metadata || {})
    freeze
  end

  # @return [Boolean] true if this is a chat model
  def chat? = type == "chat"

  # @return [Boolean] true if this is an embedding model, either by type or
  #   because it lists "embeddings" among its output modalities
  def embedding? = type == "embedding" || outputs?("embeddings")

  # @return [Boolean] true if this model supports audio output
  def audio? = outputs?("audio")

  # @return [Boolean] true if this model supports image output
  def image? = outputs?("image")

  # Check if this model supports a given capability.
  # @param capability [String, Symbol] capability name
  # @return [Boolean]
  def supports?(capability)
    capabilities.include?(capability.to_s)
  end

  # An explicit metadata[:type] wins over inference; it is returned as a
  # String so comparisons like {#chat?} work for Symbol values too.
  # @return [String] model type ("chat", "embedding", "audio", "image")
  def type
    explicit = @metadata[:type]
    explicit ? explicit.to_s : infer_type
  end

  # @return [Hash] serialized model info (nil-valued attributes omitted)
  def to_h
    {
      id: @id,
      name: @name,
      provider: @provider,
      family: @family,
      capabilities: @capabilities,
      context_window: @context_window,
      max_output_tokens: @max_output_tokens,
      modalities: @modalities,
      pricing: @pricing,
      knowledge_cutoff: @knowledge_cutoff,
      created_at: @created_at,
      metadata: @metadata
    }.compact
  end

  # @return [String]
  def inspect
    "#<Ask::ModelInfo id=#{@id.inspect} provider=#{@provider.inspect}>"
  end

  private

  # @param kind [String] output modality name
  # @return [Boolean] whether +kind+ is listed among the output modalities
  def outputs?(kind)
    Array(@modalities[:output]).include?(kind)
  end

  # Recursively copy Hash/Array containers and freeze the copies. Leaf
  # values are kept as-is.
  # @param value [Object]
  # @return [Object] a frozen copy for containers, otherwise +value+ itself
  def frozen_copy(value)
    case value
    when Hash then value.to_h { |key, item| [key, frozen_copy(item)] }.freeze
    when Array then value.map { |item| frozen_copy(item) }.freeze
    else value
    end
  end

  # Heuristic type when metadata does not specify one: tool/structured-output
  # capable models are chat models, then the ID and output modalities decide.
  # @return [String]
  def infer_type
    return "chat" if supports?(:function_calling) || supports?(:structured_output)
    return "embedding" if @id.to_s.include?("embedding")
    return "audio" if audio?
    return "image" if image?

    "chat"
  end
end
|
|
123
|
+
|
|
124
|
+
# Parses raw models.dev API response JSON into {ModelInfo} objects.
# Extracted into a module for independent unit testing.
module ModelsDevParser
  # Maps models.dev provider keys to ask-rb provider slugs. Providers that
  # are not listed here are skipped by {.parse}.
  PROVIDER_MAP = {
    "openai" => "openai",
    "anthropic" => "anthropic",
    "google" => "gemini",
    "google-vertex" => "vertexai",
    "amazon-bedrock" => "bedrock",
    "deepseek" => "deepseek",
    "mistral" => "mistral",
    "openrouter" => "openrouter",
    "perplexity" => "perplexity",
    "xai" => "xai",
    "github" => "github"
  }.freeze

  # Modalities kept by {.normalize_modalities}; any other value is dropped.
  INPUT_MODALITIES = %w[text image audio pdf video file].freeze
  OUTPUT_MODALITIES = %w[text image audio video embeddings moderation].freeze

  # Reduced-precision dates ("YYYY" or "YYYY-MM") that {.parse_date} maps to
  # the first day of the period, because Date.parse does not reliably read
  # them as year/month.
  PARTIAL_DATE = /\A(\d{4})(?:-(\d{1,2}))?\z/

  module_function

  # Parse a raw models.dev API response into {ModelInfo} objects.
  # Unknown providers and malformed (non-Hash) provider or model entries are
  # skipped instead of raising.
  # @param api_response [Hash] the parsed JSON from models.dev/api.json
  # @return [Array<Ask::ModelInfo>] empty when +api_response+ is not a Hash
  def parse(api_response)
    return [] unless api_response.is_a?(Hash)

    api_response.flat_map do |provider_key, provider_data|
      provider_slug = PROVIDER_MAP[provider_key.to_s]
      next [] unless provider_slug && provider_data.is_a?(Hash)

      models_data = provider_data["models"]
      next [] unless models_data.is_a?(Hash)

      models_data.filter_map do |model_key, model_data|
        next unless model_data.is_a?(Hash)

        entry = model_data
        # The "models" hash is keyed by model ID; use the key when the entry
        # itself carries no usable "id".
        entry = entry.merge("id" => model_key.to_s) if entry["id"].to_s.strip.empty?
        build_model(entry, provider_slug, provider_key.to_s)
      end
    end
  end

  # Build a {ModelInfo} from a single model entry in the models.dev response.
  # @param model_data [Hash] the model data from the API
  # @param provider_slug [String] normalized provider slug
  # @param provider_key [String] original provider key from the API
  # @return [Ask::ModelInfo]
  def build_model(model_data, provider_slug, provider_key)
    modalities = normalize_modalities(model_data["modalities"])
    capabilities = extract_capabilities(model_data, modalities)
    pricing = build_pricing(model_data["cost"])
    # Prefer the release date; fall back to the last-updated date.
    created_date = [model_data["release_date"], model_data["last_updated"]]
                   .find { |v| v && !v.to_s.strip.empty? }

    ModelInfo.new(
      id: model_data["id"],
      name: model_data["name"] || model_data["id"],
      provider: provider_slug,
      family: model_data["family"],
      capabilities: capabilities,
      context_window: model_data.dig("limit", "context"),
      max_output_tokens: model_data.dig("limit", "output"),
      modalities: modalities,
      pricing: pricing,
      knowledge_cutoff: parse_date(model_data["knowledge"]),
      created_at: created_date,
      metadata: {
        source: "models.dev",
        provider_id: provider_key,
        open_weights: model_data["open_weights"],
        status: model_data["status"],
        reasoning_options: model_data["reasoning_options"]
      }.compact
    )
  end

  # @param modalities [Hash, nil] raw modalities hash
  # @return [Hash{Symbol => Array<String>}] normalized with known modality filters;
  #   empty lists when +modalities+ is missing or not a Hash
  def normalize_modalities(modalities)
    return { input: [], output: [] } unless modalities.is_a?(Hash)

    {
      input: Array(modalities["input"]).compact & INPUT_MODALITIES,
      output: Array(modalities["output"]).compact & OUTPUT_MODALITIES
    }
  end

  # Extract capability strings from model data and modalities.
  # @param model_data [Hash] raw model data
  # @param modalities [Hash] normalized modalities
  # @return [Array<String>]
  def extract_capabilities(model_data, modalities)
    caps = []
    caps << "function_calling" if model_data["tool_call"]
    caps << "structured_output" if model_data["structured_output"]
    caps << "reasoning" if model_data["reasoning"] || model_data["reasoning_options"]
    # Any visual/document input counts as vision support.
    caps << "vision" if modalities[:input].intersect?(%w[image video pdf])
    caps.uniq
  end

  # Build a pricing hash from raw cost data.
  # @param cost [Hash, nil] cost object from the API
  # @return [Hash] empty when +cost+ is missing or not a Hash
  def build_pricing(cost)
    return {} unless cost.is_a?(Hash)

    text_standard = {
      input_per_million: cost["input"],
      output_per_million: cost["output"],
      cache_read_input_per_million: cost["cache_read"],
      cache_write_input_per_million: cost["cache_write"],
      reasoning_output_per_million: cost["reasoning"]
    }.compact

    audio_standard = {
      input_per_million: cost["input_audio"],
      output_per_million: cost["output_audio"]
    }.compact

    pricing = {}
    pricing[:text_tokens] = { standard: text_standard } if text_standard.any?
    pricing[:audio_tokens] = { standard: audio_standard } if audio_standard.any?
    pricing
  end

  # Parse a date from a string, returning nil on failure.
  # "YYYY" and "YYYY-MM" map to the first day of that year/month; anything
  # else goes through Date.parse.
  # @param value [String, Date, nil]
  # @return [Date, nil]
  def parse_date(value)
    return nil if value.nil?
    return value if value.is_a?(Date)

    text = value.to_s.strip
    return nil if text.empty?

    if (partial = PARTIAL_DATE.match(text))
      Date.new(partial[1].to_i, (partial[2] || 1).to_i, 1)
    else
      Date.parse(text)
    end
  rescue ArgumentError
    # Unparseable text or an out-of-range month/day.
    nil
  end
end
|
|
257
|
+
|
|
258
|
+
# Catalog of available AI models. Provides model resolution by name/ID,
# filtering by capability, and refresh from the models.dev API.
#
#   Ask::ModelCatalog.find("gpt-4o")
#   Ask::ModelCatalog.chat_models
#   Ask::ModelCatalog.refresh!
#
# The process-wide instance returned by {.instance} starts out empty.
# Populate it with {.refresh!} or {.register} before calling {.find}.
class ModelCatalog
  include Enumerable

  # @return [String] URL for the models.dev API
  MODELS_DEV_URL = "https://models.dev/api.json".freeze

  # Ordered provider preference for disambiguation. When the same model ID
  # is offered by several providers, the earliest provider listed here wins.
  # Providers that are not listed rank after all listed ones.
  PROVIDER_PREFERENCE = %w[
    openai anthropic gemini vertexai bedrock
    openrouter deepseek mistral perplexity xai
    azure ollama gpustack github
  ].freeze

  # Instance methods that are also exposed as class methods forwarding to
  # the singleton {.instance}.
  DELEGATES = %i[all each find chat_models embedding_models
                 audio_models image_models by_family by_provider
                 length size register refresh!].freeze

  class << self
    DELEGATES.each do |method_name|
      define_method(method_name) do |*args, **kwargs, &block|
        instance.public_send(method_name, *args, **kwargs, &block)
      end
    end

    # @return [Ask::ModelCatalog] the process-wide singleton instance
    def instance
      @instance ||= new
    end

    # Reset the singleton instance (useful for testing).
    def reset_instance!
      @instance = nil
    end
  end

  # @param models [Array<Ask::ModelInfo>, nil] initial model list
  def initialize(models = nil)
    @models = models || []
  end

  # --- Querying ---

  # @return [Array<Ask::ModelInfo>] all models in the catalog
  def all
    @models
  end

  # @yield [Ask::ModelInfo]
  # @return [Enumerator] when no block is given
  def each(&block)
    @models.each(&block)
  end

  # Find a model by ID, optionally scoped to a provider. IDs and provider
  # slugs are compared as Strings, so Symbols are accepted as well. Without a
  # provider, ties across providers are broken by {PROVIDER_PREFERENCE}.
  # @param model_id [String, Symbol] model identifier
  # @param provider [String, Symbol, nil] provider slug
  # @return [Ask::ModelInfo]
  # @raise [ModelNotFound] if the model is not found
  def find(model_id, provider = nil)
    wanted = model_id.to_s
    provider ? find_with_provider(wanted, provider.to_s) : find_without_provider(wanted)
  end

  # @return [Ask::ModelCatalog] new catalog containing only chat models
  def chat_models
    self.class.new(@models.select(&:chat?))
  end

  # @return [Ask::ModelCatalog] new catalog containing only embedding models
  def embedding_models
    self.class.new(@models.select(&:embedding?))
  end

  # @return [Ask::ModelCatalog] new catalog containing only audio models
  def audio_models
    self.class.new(@models.select(&:audio?))
  end

  # @return [Ask::ModelCatalog] new catalog containing only image models
  def image_models
    self.class.new(@models.select(&:image?))
  end

  # @param family [String] family name
  # @return [Ask::ModelCatalog] new catalog filtered by family
  def by_family(family)
    self.class.new(@models.select { |m| m.family.to_s == family.to_s })
  end

  # @param provider [String] provider slug
  # @return [Ask::ModelCatalog] new catalog filtered by provider
  def by_provider(provider)
    self.class.new(@models.select { |m| m.provider == provider.to_s })
  end

  # @return [Integer] number of models
  def length = @models.length
  alias size length

  # --- Refresh from models.dev ---

  # Fetch the latest model data from the models.dev API and replace the
  # catalog contents with it. Keeps the current models if the request fails,
  # the API answers with a non-200 status, or the payload yields no models.
  # A successful refresh also replaces models added through {#register}.
  # @param timeout [Integer] HTTP timeout in seconds
  # @return [self]
  def refresh!(timeout: 10)
    @models = fetch_from_models_dev(timeout: timeout)
    self
  end

  # --- Registration ---

  # Register a single model, skipping duplicates (same id and provider).
  # @param model [Ask::ModelInfo]
  # @return [self]
  def register(model)
    @models << model unless @models.any? { |m| m.id == model.id && m.provider == model.provider }
    self
  end

  private

  # @param model_id [String]
  # @param provider [String]
  # @return [Ask::ModelInfo]
  # @raise [ModelNotFound]
  def find_with_provider(model_id, provider)
    match = @models.find { |m| m.id.to_s == model_id && m.provider == provider }
    return match if match

    raise ModelNotFound, "Model #{model_id.inspect} not found for provider #{provider.inspect}. " \
                         "Try ModelCatalog.refresh! to update the catalog."
  end

  # @param model_id [String]
  # @return [Ask::ModelInfo]
  # @raise [ModelNotFound]
  def find_without_provider(model_id)
    matches = @models.select { |m| m.id.to_s == model_id }
    return preferred_match(matches) unless matches.empty?

    raise ModelNotFound, "Unknown model: #{model_id.inspect}. " \
                         "Try ModelCatalog.refresh! to update the catalog."
  end

  # Pick the candidate whose provider ranks highest in PROVIDER_PREFERENCE.
  # @param candidates [Array<Ask::ModelInfo>] non-empty
  # @return [Ask::ModelInfo]
  def preferred_match(candidates)
    return candidates.first if candidates.size == 1

    candidates.min_by do |model|
      PROVIDER_PREFERENCE.index(model.provider) || PROVIDER_PREFERENCE.length
    end
  end

  # Download and parse the models.dev catalog.
  # @param timeout [Integer] open/read timeout in seconds
  # @return [Array<Ask::ModelInfo>] freshly parsed models, or the current
  #   models (after a warning) when fetching or parsing fails
  def fetch_from_models_dev(timeout: 10)
    uri = URI(MODELS_DEV_URL)
    http = Net::HTTP.new(uri.host, uri.port)
    http.use_ssl = uri.scheme == "https"
    http.open_timeout = timeout
    http.read_timeout = timeout

    response = http.request(Net::HTTP::Get.new(uri))

    unless response.is_a?(Net::HTTPOK)
      warn "Failed to fetch models.dev: HTTP #{response.code}. Keeping existing models."
      return @models
    end

    fetched = ModelsDevParser.parse(JSON.parse(response.body))
    if fetched.empty?
      # An empty result most likely means an unexpected payload shape; do not
      # wipe a working catalog because of it.
      warn "models.dev returned no recognizable models. Keeping existing models."
      return @models
    end

    fetched
  rescue StandardError => e
    warn "Failed to fetch models.dev: #{e.class}: #{e.message}. Keeping existing models."
    @models
  end
end
|
|
438
|
+
end
|
data/lib/ask/provider.rb
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ask
|
|
4
|
+
# Abstract base class for all LLM providers. Defines the interface that
# provider gems (ask-openai, ask-anthropic, etc.) must implement.
#
# Provider gems subclass this and implement the abstract methods:
# - {#chat} — send a chat completion request
# - {#embed} — generate embeddings
# - {#list_models} — list available models
# - {#api_base} — provider API base URL
#
# Providers register themselves via {.register} so that {.resolve} returns
# the correct class by name. The registry is shared by Provider and every
# subclass, so +register+ and +resolve+ behave the same no matter which
# class in the hierarchy they are called on.
#
# @example Defining a custom provider
#   class MyProvider < Ask::Provider
#     def api_base = "https://api.example.com/v1"
#     def headers = { "Authorization" => "Bearer #{@config.api_key}" }
#
#     def chat(messages, model:, **opts, &block)
#       # ...
#     end
#
#     def embed(text, model:)
#       # ...
#     end
#
#     def list_models
#       # ...
#     end
#   end
#   Ask::Provider.register(:my, MyProvider)
#
class Provider
  # Global mutex protecting the provider registry.
  REGISTRY_MUTEX = Mutex.new
  private_constant :REGISTRY_MUTEX

  # Name => provider class table. It is a constant rather than a class-level
  # instance variable so that subclasses see the same table. All access goes
  # through REGISTRY_MUTEX.
  REGISTRY = {}
  private_constant :REGISTRY

  # @return [Object] the configuration passed to the constructor, after {#normalize_config}
  attr_reader :config

  # @param config [Object, Hash] provider configuration. Each key from
  #   {.configuration_requirements} is read through a reader method on an
  #   object, or through a Symbol/String key on a Hash.
  # @raise [ConfigurationError] if a required key is missing or falsy
  def initialize(config = {})
    @config = normalize_config(config)
    ensure_configured!
  end

  # --- Abstract interface (provider gems implement these) ---

  # Send a chat completion request.
  # @param messages [Array<Ask::Message>] conversation messages
  # @param model [String] model ID to use
  # @param tools [Array<Ask::ToolDef>, nil] tool definitions
  # @param temperature [Float, nil] sampling temperature
  # @param stream [Boolean, nil] if true, yield {Ask::Chunk}s to the block
  # @param schema [Hash, nil] JSON schema for structured output
  # @yield [Ask::Chunk] yields chunks when streaming
  # @return [Ask::Message] the assistant's response
  # @raise [NotImplementedError] unless overridden
  def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
    raise NotImplementedError, "#{self.class} must implement #chat"
  end

  # Generate embeddings for the given text.
  # @param text [String] input text
  # @param model [String] embedding model ID
  # @return [Array<Float>] embedding vector
  # @raise [NotImplementedError] unless overridden
  def embed(text, model:)
    raise NotImplementedError, "#{self.class} must implement #embed"
  end

  # List available models from this provider.
  # @return [Array<Ask::ModelInfo>] available models
  # @raise [NotImplementedError] unless overridden
  def list_models
    raise NotImplementedError, "#{self.class} must implement #list_models"
  end

  # @abstract The base URL for this provider's API.
  # @return [String]
  # @raise [NotImplementedError] unless overridden
  def api_base
    raise NotImplementedError, "#{self.class} must implement #api_base"
  end

  # --- Optional overrides ---

  # Additional HTTP headers for API requests.
  # @return [Hash<String, String>] empty by default
  def headers
    {}
  end

  # Parse an error response body into a human-readable message.
  # @param response [Object] the error response
  # @return [String, nil] nil by default
  def parse_error(response)
    nil
  end

  # @return [Boolean] true if the provider runs locally (e.g., Ollama)
  def local? = self.class.local?

  # @return [Boolean] true if the provider requires a remote API
  def remote? = !local?

  # @return [Boolean] true if all models can be assumed to exist
  def assume_models_exist? = self.class.assume_models_exist?

  # --- Slug / name / capabilities ---

  # @return [String] lowercased provider slug
  def slug
    self.class.slug
  end

  # @return [String] provider name (demodulized class name)
  def name
    self.class.name
  end

  # @return [Hash, nil] provider capabilities metadata
  def capabilities
    self.class.capabilities
  end

  # --- Registry (class-level) ---

  class << self
    # Register a provider class so it can be resolved by name.
    # Thread-safe via REGISTRY_MUTEX.
    # @param name [Symbol, String] short name for the provider
    # @param provider_class [Class<Ask::Provider>] the provider class
    # @return [Class<Ask::Provider>] +provider_class+
    def register(name, provider_class)
      REGISTRY_MUTEX.synchronize { REGISTRY[name.to_sym] = provider_class }
    end

    # Resolve a registered provider by name.
    # Thread-safe via REGISTRY_MUTEX.
    # @param name [Symbol, String] provider name
    # @return [Class<Ask::Provider>]
    # @raise [UnknownProvider] if not registered
    def resolve(name)
      REGISTRY_MUTEX.synchronize do
        REGISTRY[name.to_sym] || raise(UnknownProvider,
                                       "Unknown provider: #{name.inspect}. " \
                                       "Available: #{REGISTRY.keys.join(', ')}")
      end
    end

    # Return a shallow copy of all registered providers.
    # Thread-safe via REGISTRY_MUTEX.
    # @return [Hash{Symbol => Class<Ask::Provider>}]
    def providers
      REGISTRY_MUTEX.synchronize { REGISTRY.dup }
    end

    # Clear all registered providers (used in testing).
    # @return [Hash] a fresh empty Hash (never the internal table)
    def clear_providers!
      REGISTRY_MUTEX.synchronize do
        REGISTRY.clear
        {}
      end
    end

    # @return [String] lowercased, underscored slug from the class name
    def slug
      name.split("::").last.gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
          .gsub(/([a-z\d])([A-Z])/, '\1_\2')
          .downcase
    end

    # @return [String] class name without module prefix
    def name
      to_s.split("::").last
    end

    # @return [Hash, nil] capabilities
    def capabilities
      nil
    end

    # @return [Array<Symbol>] config keys this provider supports
    def configuration_options
      []
    end

    # @return [Array<Symbol>] config keys this provider requires
    def configuration_requirements
      []
    end

    # Required configuration keys that +config+ does not provide (missing,
    # nil or false).
    # @param config [Object, Hash] configuration object or Hash
    # @return [Array<Symbol>]
    def missing_configuration(config)
      configuration_requirements.reject { |requirement| configuration_value(config, requirement) }
    end

    # Check if this provider is fully configured.
    # @param config [Object, Hash] configuration object or Hash
    # @return [Boolean]
    def configured?(config)
      missing_configuration(config).empty?
    end

    # @return [Boolean] true if this provider runs locally
    def local?
      false
    end

    # @return [Boolean] true if this provider requires a remote API
    def remote?
      !local?
    end

    # @return [Boolean] whether all models can be assumed to exist
    def assume_models_exist?
      false
    end

    private

    # Read one configuration key from a Hash (Symbol key first, then String
    # key) or from an object exposing a reader method with that name.
    # @return [Object, nil] nil when the key cannot be read
    def configuration_value(config, key)
      if config.is_a?(Hash)
        config.fetch(key.to_sym) { config[key.to_s] }
      elsif config.respond_to?(key)
        config.public_send(key)
      end
    end
  end

  private

  # Hook for subclasses to coerce the raw constructor argument before the
  # requirement check runs. Returns +config+ unchanged by default.
  # @param config [Object, Hash]
  # @return [Object, Hash]
  def normalize_config(config)
    config
  end

  # @raise [ConfigurationError] listing every missing required key
  def ensure_configured!
    missing = self.class.missing_configuration(@config)
    return if missing.empty?

    raise ConfigurationError,
          "Missing configuration for #{self.class.name}: #{missing.join(', ')}. " \
          "Set these keys on your provider config before using this provider."
  end
end
|
|
232
|
+
end
|