ask-core 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/ask/models.rb ADDED
@@ -0,0 +1,438 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "net/http"
5
+ require "time"
6
+
7
+ module Ask
8
# Model metadata: capabilities, pricing, context window, modalities.
# Immutable value object representing a single model entry.
#
# String and collection arguments are copied and deep-frozen on construction.
# Neither later mutation of the caller's objects nor mutation through the
# readers can change a ModelInfo after it has been built.
class ModelInfo
  # @return [String] model identifier (e.g. "gpt-4o")
  attr_reader :id

  # @return [String] human-readable model name
  attr_reader :name

  # @return [String] provider slug (e.g. "openai")
  attr_reader :provider

  # @return [String, nil] model family (e.g. "gpt", "claude")
  attr_reader :family

  # @return [Array<String>] capability strings
  attr_reader :capabilities

  # @return [Integer, nil] maximum context window in tokens
  attr_reader :context_window

  # @return [Integer, nil] maximum output tokens
  attr_reader :max_output_tokens

  # @return [Hash] input/output modalities
  attr_reader :modalities

  # @return [Hash] pricing information
  attr_reader :pricing

  # @return [Date, nil] knowledge cutoff date
  attr_reader :knowledge_cutoff

  # @return [String, nil] creation/publication date
  attr_reader :created_at

  # @return [Hash] additional metadata
  attr_reader :metadata

  # @param id [String] model identifier
  # @param name [String, nil] display name; defaults to +id+
  # @param provider [String, Symbol] provider slug, stored via +to_s+
  # @param family [String, nil] model family
  # @param capabilities [Array<String, Symbol>, String, Symbol, nil] capability names, stored as Strings
  # @param context_window [Integer, nil] maximum context window in tokens
  # @param max_output_tokens [Integer, nil] maximum output tokens
  # @param modalities [Hash, nil] e.g. +{ input: [...], output: [...] }+; nil is treated as {}
  # @param pricing [Hash, nil] pricing information; nil is treated as {}
  # @param knowledge_cutoff [Date, nil] knowledge cutoff date
  # @param created_at [String, nil] creation/publication date
  # @param metadata [Hash, nil] extra metadata; nil is treated as {}.
  #   A +:type+ entry overrides {#type} inference.
  def initialize(id:, name: nil, provider:, family: nil, capabilities: [],
                 context_window: nil, max_output_tokens: nil,
                 modalities: {}, pricing: {}, knowledge_cutoff: nil,
                 created_at: nil, metadata: {})
    @id = deep_freeze(id)
    @name = deep_freeze(name || id)
    @provider = deep_freeze(provider.to_s)
    @family = deep_freeze(family)
    @capabilities = deep_freeze(Array(capabilities).map(&:to_s))
    @context_window = context_window
    @max_output_tokens = max_output_tokens
    # nil collections are normalized so #type and the predicates never call methods on nil.
    @modalities = deep_freeze(modalities || {})
    @pricing = deep_freeze(pricing || {})
    @knowledge_cutoff = knowledge_cutoff
    @created_at = deep_freeze(created_at)
    @metadata = deep_freeze(metadata || {})
    freeze
  end

  # @return [Boolean] true if this is a chat model
  def chat? = type == "chat"

  # @return [Boolean] true if this is an embedding model
  def embedding? = type == "embedding" || output_modality?("embeddings")

  # @return [Boolean] true if this model supports audio output
  def audio? = output_modality?("audio")

  # @return [Boolean] true if this model supports image output
  def image? = output_modality?("image")

  # Check if this model supports a given capability.
  # @param capability [String, Symbol] capability name
  # @return [Boolean]
  def supports?(capability)
    capabilities.include?(capability.to_s)
  end

  # @return [String] model type ("chat", "embedding", "audio", "image");
  #   metadata[:type] wins, otherwise the type is inferred
  def type
    @metadata[:type] || infer_type
  end

  # @return [Hash] serialized model info (nil-valued fields omitted)
  def to_h
    {
      id: @id,
      name: @name,
      provider: @provider,
      family: @family,
      capabilities: @capabilities,
      context_window: @context_window,
      max_output_tokens: @max_output_tokens,
      modalities: @modalities,
      pricing: @pricing,
      knowledge_cutoff: @knowledge_cutoff,
      created_at: @created_at,
      metadata: @metadata
    }.compact
  end

  # @return [String]
  def inspect
    "#<Ask::ModelInfo id=#{@id.inspect} provider=#{@provider.inspect}>"
  end

  private

  # Heuristic type inference. Chat capabilities take precedence, then the
  # "embedding" id substring, then output modalities; the fallback is "chat".
  def infer_type
    return "chat" if supports?(:function_calling) || supports?(:structured_output)
    return "embedding" if @id.to_s.include?("embedding")
    return "audio" if audio?
    return "image" if image?
    "chat"
  end

  # @param modality [String] output modality name
  # @return [Boolean] true if +modality+ is listed under modalities[:output]
  def output_modality?(modality)
    Array(@modalities[:output]).include?(modality)
  end

  # Return a frozen copy of +value+, recursing into Hashes and Arrays.
  # Strings are duplicated before freezing so the caller's object is untouched.
  # Other values (Symbols, Integers, Dates, nil, ...) are returned as-is.
  def deep_freeze(value)
    case value
    when Hash then value.to_h { |key, item| [key, deep_freeze(item)] }.freeze
    when Array then value.map { |item| deep_freeze(item) }.freeze
    when String then value.frozen? ? value : value.dup.freeze
    else value
    end
  end
end
123
+
124
# Parses raw models.dev API response JSON into {ModelInfo} objects.
# Extracted into a module for independent unit testing.
#
# Parsing tolerates malformed input. Any of the following is skipped rather
# than raising and aborting the entire parse:
# - a provider or model entry that is not a JSON object;
# - a model entry without an "id".
# Nested fields of the wrong shape ("limit", "cost", "modalities") are
# treated as absent.
module ModelsDevParser
  # Maps models.dev provider keys to ask-rb provider slugs.
  PROVIDER_MAP = {
    "openai" => "openai",
    "anthropic" => "anthropic",
    "google" => "gemini",
    "google-vertex" => "vertexai",
    "amazon-bedrock" => "bedrock",
    "deepseek" => "deepseek",
    "mistral" => "mistral",
    "openrouter" => "openrouter",
    "perplexity" => "perplexity",
    "xai" => "xai",
    "github" => "github"
  }.freeze

  # Input modalities kept by {normalize_modalities}; others are dropped.
  INPUT_MODALITIES = %w[text image audio pdf video file].freeze
  # Output modalities kept by {normalize_modalities}; others are dropped.
  OUTPUT_MODALITIES = %w[text image audio video embeddings moderation].freeze

  module_function

  # Parse a raw models.dev API response into {ModelInfo} objects.
  # Providers missing from {PROVIDER_MAP} are ignored.
  # @param api_response [Hash, nil] the parsed JSON from models.dev/api.json
  # @return [Array<Ask::ModelInfo>] empty when +api_response+ is not a Hash
  def parse(api_response)
    return [] unless api_response.is_a?(Hash)

    api_response.flat_map do |provider_key, provider_data|
      provider_slug = PROVIDER_MAP[provider_key.to_s]
      next [] unless provider_slug && provider_data.is_a?(Hash)

      models_data = provider_data["models"]
      next [] unless models_data.is_a?(Hash)

      models_data.values.filter_map do |model_data|
        build_model(model_data, provider_slug, provider_key.to_s) if valid_model_entry?(model_data)
      end
    end
  end

  # @param model_data [Object] one value from a provider's "models" object
  # @return [Boolean] true if the entry is a Hash with a non-blank "id"
  def valid_model_entry?(model_data)
    model_data.is_a?(Hash) && !model_data["id"].to_s.strip.empty?
  end

  # Build a {ModelInfo} from a single model entry in the models.dev response.
  # @param model_data [Hash] the model data from the API
  # @param provider_slug [String] normalized provider slug
  # @param provider_key [String] original provider key from the API
  # @return [Ask::ModelInfo]
  def build_model(model_data, provider_slug, provider_key)
    modalities = normalize_modalities(model_data["modalities"])
    capabilities = extract_capabilities(model_data, modalities)
    pricing = build_pricing(model_data["cost"])
    # A non-Hash "limit" would make #dig raise TypeError; treat it as absent.
    limits = model_data["limit"].is_a?(Hash) ? model_data["limit"] : {}
    # Prefer release_date; fall back to last_updated when it is missing or blank.
    created_date = [model_data["release_date"], model_data["last_updated"]]
                   .find { |v| v && !v.to_s.strip.empty? }

    ModelInfo.new(
      id: model_data["id"],
      name: model_data["name"] || model_data["id"],
      provider: provider_slug,
      family: model_data["family"],
      capabilities: capabilities,
      context_window: limits["context"],
      max_output_tokens: limits["output"],
      modalities: modalities,
      pricing: pricing,
      knowledge_cutoff: parse_date(model_data["knowledge"]),
      created_at: created_date,
      metadata: {
        source: "models.dev",
        provider_id: provider_key,
        open_weights: model_data["open_weights"],
        status: model_data["status"],
        reasoning_options: model_data["reasoning_options"]
      }.compact
    )
  end

  # @param modalities [Hash, nil] raw modalities hash
  # @return [Hash{Symbol => Array<String>}] normalized with known modality filters;
  #   both lists are empty when +modalities+ is not a Hash
  def normalize_modalities(modalities)
    return { input: [], output: [] } unless modalities.is_a?(Hash)

    {
      input: Array(modalities["input"]).compact & INPUT_MODALITIES,
      output: Array(modalities["output"]).compact & OUTPUT_MODALITIES
    }
  end

  # Extract capability strings from model data and modalities.
  # @param model_data [Hash] raw model data
  # @param modalities [Hash] normalized modalities
  # @return [Array<String>]
  def extract_capabilities(model_data, modalities)
    caps = []
    caps << "function_calling" if model_data["tool_call"]
    caps << "structured_output" if model_data["structured_output"]
    caps << "reasoning" if model_data["reasoning"] || model_data["reasoning_options"]
    caps << "vision" if modalities[:input].intersect?(%w[image video pdf])
    caps.uniq
  end

  # Build a pricing hash from raw cost data.
  # @param cost [Hash, nil] cost object from the API
  # @return [Hash] empty when +cost+ is not a Hash
  def build_pricing(cost)
    return {} unless cost.is_a?(Hash)

    text_standard = {
      input_per_million: cost["input"],
      output_per_million: cost["output"],
      cache_read_input_per_million: cost["cache_read"],
      cache_write_input_per_million: cost["cache_write"],
      reasoning_output_per_million: cost["reasoning"]
    }.compact

    audio_standard = {
      input_per_million: cost["input_audio"],
      output_per_million: cost["output_audio"]
    }.compact

    pricing = {}
    pricing[:text_tokens] = { standard: text_standard } if text_standard.any?
    pricing[:audio_tokens] = { standard: audio_standard } if audio_standard.any?
    pricing
  end

  # Parse a date from a string, returning nil on failure.
  # @param value [String, Date, nil]
  # @return [Date, nil]
  def parse_date(value)
    return nil if value.nil?
    return value if value.is_a?(Date)

    Date.parse(value.to_s)
  rescue ArgumentError
    nil
  end
end
257
+
258
# Catalog of available AI models. Provides model resolution by name/ID,
# filtering by capability, and refresh from the models.dev API.
#
#   Ask::ModelCatalog.find("gpt-4o")
#   Ask::ModelCatalog.chat_models
#   Ask::ModelCatalog.refresh!
#
class ModelCatalog
  include Enumerable

  # @return [String] URL for the models.dev API
  MODELS_DEV_URL = "https://models.dev/api.json".freeze

  # Ordered provider preference for disambiguation.
  PROVIDER_PREFERENCE = %w[
    openai anthropic gemini vertexai bedrock
    openrouter deepseek mistral perplexity xai
    azure ollama gpustack github
  ].freeze

  # Methods delegated to the singleton instance.
  DELEGATES = %i[all each find chat_models embedding_models
                 audio_models image_models by_family by_provider
                 refresh!].freeze

  class << self
    # Define class-level shortcuts (e.g. ModelCatalog.find) forwarding to {instance}.
    DELEGATES.each do |method|
      define_method(method) do |*args, **kwargs, &block|
        instance.public_send(method, *args, **kwargs, &block)
      end
    end

    # @return [Ask::ModelCatalog] the process-wide singleton instance
    def instance
      @instance ||= new
    end

    # Reset the singleton instance (useful for testing).
    def reset_instance!
      @instance = nil
    end
  end

  # @param models [Array<Ask::ModelInfo>, nil] initial model list
  def initialize(models = nil)
    @models = models || []
  end

  # --- Querying ---

  # @return [Array<Ask::ModelInfo>] all models in the catalog
  def all
    @models
  end

  # @yield [Ask::ModelInfo]
  # @return [Enumerator]
  def each(&block)
    @models.each(&block)
  end

  # Find a model by ID, optionally scoped to a provider.
  # Without a provider, ties between providers are broken by {PROVIDER_PREFERENCE}.
  # @param model_id [String] model identifier
  # @param provider [String, Symbol, nil] provider slug
  # @return [Ask::ModelInfo]
  # @raise [ModelNotFound] if the model is not found
  def find(model_id, provider = nil)
    if provider
      find_with_provider(model_id, provider.to_s)
    else
      find_without_provider(model_id)
    end
  end

  # @return [Ask::ModelCatalog] new catalog containing only chat models
  def chat_models
    self.class.new(@models.select(&:chat?))
  end

  # @return [Ask::ModelCatalog] new catalog containing only embedding models
  def embedding_models
    self.class.new(@models.select(&:embedding?))
  end

  # @return [Ask::ModelCatalog] new catalog containing only audio models
  def audio_models
    self.class.new(@models.select(&:audio?))
  end

  # @return [Ask::ModelCatalog] new catalog containing only image models
  def image_models
    self.class.new(@models.select(&:image?))
  end

  # @param family [String] family name
  # @return [Ask::ModelCatalog] new catalog filtered by family
  def by_family(family)
    self.class.new(@models.select { |m| m.family.to_s == family.to_s })
  end

  # @param provider [String] provider slug
  # @return [Ask::ModelCatalog] new catalog filtered by provider
  def by_provider(provider)
    self.class.new(@models.select { |m| m.provider == provider.to_s })
  end

  # @return [Integer] number of models
  def length = @models.length
  alias size length

  # --- Refresh from models.dev ---

  # Fetch the latest model data from the models.dev API.
  # The current models are kept, with a warning, when any of these happens:
  # - the API is unreachable;
  # - it answers with a non-200 status;
  # - it returns malformed JSON;
  # - it yields no models for any recognized provider.
  # @param timeout [Integer] HTTP open/read timeout in seconds
  # @return [self]
  def refresh!(timeout: 10)
    @models = fetch_from_models_dev(timeout: timeout)
    self
  end

  # --- Registration ---

  # Register a single model, skipping duplicates (same id and provider).
  # @param model [Ask::ModelInfo]
  # @return [self]
  def register(model)
    @models << model unless @models.any? { |m| m.id == model.id && m.provider == model.provider }
    self
  end

  private

  # Exact id + provider lookup.
  def find_with_provider(model_id, provider)
    @models.find { |m| m.id == model_id && m.provider == provider } ||
      raise(ModelNotFound, "Model #{model_id.inspect} not found for provider #{provider.inspect}. " \
                           "Try ModelCatalog.refresh! to update the catalog.")
  end

  # Id lookup across all providers, preferring providers by PROVIDER_PREFERENCE.
  def find_without_provider(model_id)
    matches = @models.select { |m| m.id == model_id }
    return preferred_match(matches) if matches.any?

    raise ModelNotFound, "Unknown model: #{model_id.inspect}. " \
                         "Try ModelCatalog.refresh! to update the catalog."
  end

  # Providers absent from PROVIDER_PREFERENCE rank after every listed provider.
  def preferred_match(candidates)
    return candidates.first if candidates.size == 1

    candidates.min_by do |model|
      PROVIDER_PREFERENCE.index(model.provider) || PROVIDER_PREFERENCE.length
    end
  end

  # @return [Array<Ask::ModelInfo>] freshly parsed models, or the current ones on any failure
  def fetch_from_models_dev(timeout: 10)
    response = http_get(URI(MODELS_DEV_URL), timeout)

    unless response.is_a?(Net::HTTPOK)
      warn "Failed to fetch models.dev: HTTP #{response.code}. Keeping existing models."
      return @models
    end

    models_from_payload(JSON.parse(response.body))
  rescue StandardError => e
    warn "Failed to fetch models.dev: #{e.class}: #{e.message}. Keeping existing models."
    @models
  end

  # Perform a single GET with the given open/read timeout.
  # Net::HTTP#request opens and closes the connection itself.
  def http_get(uri, timeout)
    http = Net::HTTP.new(uri.host, uri.port)
    http.use_ssl = uri.scheme == "https"
    http.open_timeout = timeout
    http.read_timeout = timeout
    http.request(Net::HTTP::Get.new(uri))
  end

  # Parse a decoded models.dev payload. A payload with no recognized models
  # (e.g. after an upstream schema change) must not wipe a working catalog.
  def models_from_payload(providers_data)
    parsed = providers_data.is_a?(Hash) ? ModelsDevParser.parse(providers_data) : []
    return parsed unless parsed.empty?

    warn "models.dev returned no recognized models. Keeping existing models."
    @models
  end
end
438
+ end
@@ -0,0 +1,232 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ask
4
# Abstract base class for all LLM providers. Defines the interface that
# provider gems (ask-openai, ask-anthropic, etc.) must implement.
#
# Provider gems subclass this and implement the abstract methods:
# - {#chat} — send a chat completion request
# - {#embed} — generate embeddings
# - {#list_models} — list available models
# - {#api_base} — provider API base URL
#
# Providers register themselves via {.register} so that
# {.resolve} returns the correct class by name. The registry is shared by
# Provider and all of its subclasses, so calling +register+ on a subclass
# is equivalent to calling it on Ask::Provider.
#
# @example Defining a custom provider
#   class MyProvider < Ask::Provider
#     def api_base = "https://api.example.com/v1"
#     def headers = { "Authorization" => "Bearer #{@config.api_key}" }
#     def chat(messages, model:, **opts, &block) = # ...
#     def embed(text, model:) = # ...
#     def list_models = # ...
#   end
#   Ask::Provider.register(:my, MyProvider)
#
class Provider
  # Global mutex protecting the provider registry.
  REGISTRY_MUTEX = Mutex.new
  private_constant :REGISTRY_MUTEX

  # Process-wide provider table. It is a constant rather than a class-level
  # instance variable, so every subclass reads and writes the same Hash.
  # Access must be synchronized via REGISTRY_MUTEX.
  REGISTRY = {}
  private_constant :REGISTRY

  # @return [Object] the configuration object passed to the constructor
  attr_reader :config

  # @param config [Hash, Object] provider configuration. Either a Hash keyed by
  #   requirement names (Symbol or String keys), or an object responding to
  #   each name in {.configuration_requirements}.
  # @raise [ConfigurationError] if a required key is missing or falsy
  def initialize(config = {})
    @config = config
    ensure_configured!
  end

  # --- Abstract interface (provider gems implement these) ---

  # Send a chat completion request.
  # @param messages [Array<Ask::Message>] conversation messages
  # @param model [String] model ID to use
  # @param tools [Array<Ask::ToolDef>, nil] tool definitions
  # @param temperature [Float, nil] sampling temperature
  # @param stream [Boolean, nil] if true, yield {Ask::Chunk}s to the block
  # @param schema [Hash, nil] JSON schema for structured output
  # @yield [Ask::Chunk] yields chunks when streaming
  # @return [Ask::Message] the assistant's response
  def chat(messages, model:, tools: nil, temperature: nil, stream: nil, schema: nil, **params, &block)
    raise NotImplementedError, "#{self.class} must implement #chat"
  end

  # Generate embeddings for the given text.
  # @param text [String] input text
  # @param model [String] embedding model ID
  # @return [Array<Float>] embedding vector
  def embed(text, model:)
    raise NotImplementedError, "#{self.class} must implement #embed"
  end

  # List available models from this provider.
  # @return [Array<Ask::ModelInfo>] available models
  def list_models
    raise NotImplementedError, "#{self.class} must implement #list_models"
  end

  # @abstract The base URL for this provider's API.
  # @return [String]
  def api_base
    raise NotImplementedError, "#{self.class} must implement #api_base"
  end

  # --- Optional overrides ---

  # Additional HTTP headers for API requests.
  # @return [Hash<String, String>]
  def headers
    {}
  end

  # Parse an error response body into a human-readable message.
  # @param response [Object] the error response
  # @return [String, nil]
  def parse_error(response)
    nil
  end

  # @return [Boolean] true if the provider runs locally (e.g., Ollama)
  def local? = self.class.local?

  # @return [Boolean] true if the provider requires a remote API
  def remote? = !local?

  # @return [Boolean] true if all models can be assumed to exist
  def assume_models_exist? = self.class.assume_models_exist?

  # --- Slug / name / capabilities ---

  # @return [String] lowercased provider slug
  def slug
    self.class.slug
  end

  # @return [String] provider name (demodulized class name)
  def name
    self.class.name
  end

  # @return [Hash, nil] provider capabilities metadata
  def capabilities
    self.class.capabilities
  end

  # --- Registry (class-level) ---

  class << self
    # Register a provider class so it can be resolved by name.
    # Thread-safe via {REGISTRY_MUTEX}.
    # @param name [Symbol] short name for the provider
    # @param provider_class [Class<Ask::Provider>] the provider class
    def register(name, provider_class)
      REGISTRY_MUTEX.synchronize do
        registry[name.to_sym] = provider_class
      end
    end

    # Resolve a registered provider by name.
    # Thread-safe via {REGISTRY_MUTEX}.
    # @param name [Symbol, String] provider name
    # @return [Class<Ask::Provider>]
    # @raise [UnknownProvider] if not registered
    def resolve(name)
      REGISTRY_MUTEX.synchronize do
        registry[name.to_sym] || raise(UnknownProvider,
                                       "Unknown provider: #{name.inspect}. " \
                                       "Available: #{registry.keys.join(', ')}")
      end
    end

    # Return a shallow copy of all registered providers.
    # Thread-safe via {REGISTRY_MUTEX}.
    # @return [Hash{Symbol => Class<Ask::Provider>}]
    def providers
      REGISTRY_MUTEX.synchronize do
        registry.dup
      end
    end

    # Clear all registered providers (used in testing).
    def clear_providers!
      REGISTRY_MUTEX.synchronize do
        registry.clear
      end
    end

    # @return [String] lowercased, underscored slug from the class name
    def slug
      name.split("::").last.gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
          .gsub(/([a-z\d])([A-Z])/, '\1_\2')
          .downcase
    end

    # @return [String] class name without module prefix
    def name
      to_s.split("::").last
    end

    # @return [Hash, nil] capabilities
    def capabilities
      nil
    end

    # @return [Array<Symbol>] config keys this provider supports
    def configuration_options
      []
    end

    # @return [Array<Symbol>] config keys this provider requires
    def configuration_requirements
      []
    end

    # Required config keys that are absent or falsy in +config+.
    # @param config [Hash, Object] configuration Hash or object with reader methods
    # @return [Array<Symbol, String>] entries of {.configuration_requirements}
    def missing_configuration(config)
      configuration_requirements.reject { |req| config_value(config, req) }
    end

    # Check if this provider is fully configured.
    # @param config [Hash, Object] configuration Hash or object with reader methods
    # @return [Boolean]
    def configured?(config)
      missing_configuration(config).empty?
    end

    # @return [Boolean] true if this provider runs locally
    def local?
      false
    end

    # @return [Boolean] true if this provider requires a remote API
    def remote?
      !local?
    end

    # @return [Boolean] whether all models can be assumed to exist
    def assume_models_exist?
      false
    end

    private

    # The shared registry Hash (access must be synchronized via REGISTRY_MUTEX).
    def registry
      REGISTRY
    end

    # Read +key+ from a Hash (Symbol key first, then String key) or from an
    # object exposing a reader method. Returns nil when the value is absent.
    def config_value(config, key)
      if config.is_a?(Hash)
        config.fetch(key.to_sym) { config[key.to_s] }
      elsif config.respond_to?(key)
        config.public_send(key)
      end
    end
  end

  private

  # Hook for subclasses; the base class does not call it.
  def normalize_config(config)
    config
  end

  # @raise [ConfigurationError] listing every missing required key
  def ensure_configured!
    missing = self.class.missing_configuration(@config)
    return if missing.empty?

    raise ConfigurationError,
          "Missing configuration for #{self.class.name}: #{missing.join(', ')}. " \
          "Set these keys on your provider config before using this provider."
  end
end
232
+ end