ruby_llm 0.1.0.pre35 → 0.1.0.pre37
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/.github/workflows/docs.yml +53 -0
 - data/.rspec_status +7 -35
 - data/.rubocop.yml +7 -2
 - data/.yardopts +12 -0
 - data/Gemfile +27 -0
 - data/bin/console +4 -4
 - data/docs/.gitignore +7 -0
 - data/docs/Gemfile +11 -0
 - data/docs/_config.yml +43 -0
 - data/docs/_data/navigation.yml +25 -0
 - data/docs/guides/chat.md +206 -0
 - data/docs/guides/embeddings.md +325 -0
 - data/docs/guides/error-handling.md +301 -0
 - data/docs/guides/getting-started.md +164 -0
 - data/docs/guides/image-generation.md +274 -0
 - data/docs/guides/index.md +45 -0
 - data/docs/guides/rails.md +401 -0
 - data/docs/guides/streaming.md +242 -0
 - data/docs/guides/tools.md +247 -0
 - data/docs/index.md +53 -0
 - data/docs/installation.md +98 -0
 - data/lib/ruby_llm/active_record/acts_as.rb +2 -2
 - data/lib/ruby_llm/chat.rb +7 -7
 - data/lib/ruby_llm/models.json +27 -27
 - data/lib/ruby_llm/providers/anthropic/capabilities.rb +56 -19
 - data/lib/ruby_llm/providers/anthropic/chat.rb +2 -3
 - data/lib/ruby_llm/providers/deepseek/capabilities.rb +39 -1
 - data/lib/ruby_llm/providers/gemini/capabilities.rb +70 -8
 - data/lib/ruby_llm/providers/openai/capabilities.rb +72 -24
 - data/lib/ruby_llm/providers/openai/embeddings.rb +1 -1
 - data/lib/ruby_llm/version.rb +1 -1
 - data/lib/tasks/models.rake +27 -5
 - data/ruby_llm.gemspec +10 -32
 - metadata +22 -296
 
| 
         @@ -7,45 +7,73 @@ module RubyLLM 
     | 
|
| 
       7 
7 
     | 
    
         
             
                  module Capabilities
         
     | 
| 
       8 
8 
     | 
    
         
             
                    module_function
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
     | 
    
         
            -
                     
     | 
| 
       11 
     | 
    
         
            -
             
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
     | 
    
         
            -
             
     | 
| 
       14 
     | 
    
         
            -
                       
     | 
| 
      
 10 
     | 
    
         
            +
                    # Determines the context window size for a given model
         
     | 
| 
      
 11 
     | 
    
         
            +
                    # @param _model_id [String] the model identifier (unused; the same window size is returned for every model)
         
     | 
| 
      
 12 
     | 
    
         
            +
                    # @return [Integer] the context window size in tokens
         
     | 
| 
      
 13 
     | 
    
         
            +
                    def determine_context_window(_model_id)
         
     | 
| 
      
 14 
     | 
    
         
            +
                      # All Claude 3, 3.5, and 3.7 models have a 200K-token context window
         
     | 
| 
      
 15 
     | 
    
         
            +
                      200_000
         
     | 
| 
       15 
16 
     | 
    
         
             
                    end
         
     | 
| 
       16 
17 
     | 
    
         | 
| 
      
 18 
     | 
    
         
            +
                    # Determines the maximum output tokens for a given model
         
     | 
| 
      
 19 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 20 
     | 
    
         
            +
                    # @return [Integer] the maximum output tokens
         
     | 
| 
       17 
21 
     | 
    
         
             
                    def determine_max_tokens(model_id)
         
     | 
| 
       18 
22 
     | 
    
         
             
                      case model_id
         
     | 
| 
       19 
     | 
    
         
            -
                      when /claude-3-5/ then 8_192
         
     | 
| 
       20 
     | 
    
         
            -
                      else 4_096
         
     | 
| 
      
 23 
     | 
    
         
            +
                      when /claude-3-(7-sonnet|5)/ then 8_192 # Can be increased to 64K with extended thinking
         
     | 
| 
      
 24 
     | 
    
         
            +
                      else 4_096 # Claude 3 Opus and Haiku
         
     | 
| 
       21 
25 
     | 
    
         
             
                      end
         
     | 
| 
       22 
26 
     | 
    
         
             
                    end
         
     | 
| 
       23 
27 
     | 
    
         | 
| 
      
 28 
     | 
    
         
            +
                    # Gets the input price per million tokens for a given model
         
     | 
| 
      
 29 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 30 
     | 
    
         
            +
                    # @return [Float] the price per million tokens for input
         
     | 
| 
       24 
31 
     | 
    
         
             
                    def get_input_price(model_id)
         
     | 
| 
       25 
32 
     | 
    
         
             
                      PRICES.dig(model_family(model_id), :input) || default_input_price
         
     | 
| 
       26 
33 
     | 
    
         
             
                    end
         
     | 
| 
       27 
34 
     | 
    
         | 
| 
      
 35 
     | 
    
         
            +
                    # Gets the output price per million tokens for a given model
         
     | 
| 
      
 36 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 37 
     | 
    
         
            +
                    # @return [Float] the price per million tokens for output
         
     | 
| 
       28 
38 
     | 
    
         
             
                    def get_output_price(model_id)
         
     | 
| 
       29 
39 
     | 
    
         
             
                      PRICES.dig(model_family(model_id), :output) || default_output_price
         
     | 
| 
       30 
40 
     | 
    
         
             
                    end
         
     | 
| 
       31 
41 
     | 
    
         | 
| 
      
 42 
     | 
    
         
            +
                    # Determines if a model supports vision capabilities
         
     | 
| 
      
 43 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 44 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports vision
         
     | 
| 
       32 
45 
     | 
    
         
             
                    def supports_vision?(model_id)
         
     | 
| 
       33 
     | 
    
         
            -
                       
     | 
| 
       34 
     | 
    
         
            -
                       
     | 
| 
       35 
     | 
    
         
            -
             
     | 
| 
       36 
     | 
    
         
            -
                      true
         
     | 
| 
      
 46 
     | 
    
         
            +
                      # Any ID not matching claude-1*/claude-2* is treated as vision-capable (covers all Claude 3, 3.5, and 3.7 models)
         
     | 
| 
      
 47 
     | 
    
         
            +
                      !model_id.match?(/claude-[12]/)
         
     | 
| 
       37 
48 
     | 
    
         
             
                    end
         
     | 
| 
       38 
49 
     | 
    
         | 
| 
      
 50 
     | 
    
         
            +
                    # Determines if a model supports function calling
         
     | 
| 
      
 51 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 52 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports functions
         
     | 
| 
       39 
53 
     | 
    
         
             
                    def supports_functions?(model_id)
         
     | 
| 
       40 
     | 
    
         
            -
                      model_id. 
     | 
| 
      
 54 
     | 
    
         
            +
                      model_id.match?(/claude-3/)
         
     | 
| 
       41 
55 
     | 
    
         
             
                    end
         
     | 
| 
       42 
56 
     | 
    
         | 
| 
      
 57 
     | 
    
         
            +
                    # Determines if a model supports JSON mode
         
     | 
| 
      
 58 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 59 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports JSON mode
         
     | 
| 
       43 
60 
     | 
    
         
             
                    def supports_json_mode?(model_id)
         
     | 
| 
       44 
     | 
    
         
            -
                      model_id. 
     | 
| 
      
 61 
     | 
    
         
            +
                      model_id.match?(/claude-3/)
         
     | 
| 
      
 62 
     | 
    
         
            +
                    end
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
                    # Determines if a model supports extended thinking
         
     | 
| 
      
 65 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 66 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports extended thinking
         
     | 
| 
      
 67 
     | 
    
         
            +
                    def supports_extended_thinking?(model_id)
         
     | 
| 
      
 68 
     | 
    
         
            +
                      model_id.match?(/claude-3-7-sonnet/)
         
     | 
| 
       45 
69 
     | 
    
         
             
                    end
         
     | 
| 
       46 
70 
     | 
    
         | 
| 
      
 71 
     | 
    
         
            +
                    # Determines the model family for a given model ID
         
     | 
| 
      
 72 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 73 
     | 
    
         
            +
                    # @return [Symbol] the model family identifier
         
     | 
| 
       47 
74 
     | 
    
         
             
                    def model_family(model_id)
         
     | 
| 
       48 
75 
     | 
    
         
             
                      case model_id
         
     | 
| 
      
 76 
     | 
    
         
            +
                      when /claude-3-7-sonnet/  then :claude37_sonnet
         
     | 
| 
       49 
77 
     | 
    
         
             
                      when /claude-3-5-sonnet/  then :claude35_sonnet
         
     | 
| 
       50 
78 
     | 
    
         
             
                      when /claude-3-5-haiku/   then :claude35_haiku
         
     | 
| 
       51 
79 
     | 
    
         
             
                      when /claude-3-opus/      then :claude3_opus
         
     | 
| 
         @@ -55,23 +83,32 @@ module RubyLLM 
     | 
|
| 
       55 
83 
     | 
    
         
             
                      end
         
     | 
| 
       56 
84 
     | 
    
         
             
                    end
         
     | 
| 
       57 
85 
     | 
    
         | 
| 
      
 86 
     | 
    
         
            +
                    # Returns the model type
         
     | 
| 
      
 87 
     | 
    
         
            +
                    # @param _ [String] the model identifier (unused but kept for API consistency)
         
     | 
| 
      
 88 
     | 
    
         
            +
                    # @return [String] the model type, always 'chat' for Anthropic models
         
     | 
| 
       58 
89 
     | 
    
         
             
                    def model_type(_)
         
     | 
| 
       59 
90 
     | 
    
         
             
                      'chat'
         
     | 
| 
       60 
91 
     | 
    
         
             
                    end
         
     | 
| 
       61 
92 
     | 
    
         | 
| 
      
 93 
     | 
    
         
            +
                    # Pricing information for Anthropic models (per million tokens)
         
     | 
| 
       62 
94 
     | 
    
         
             
                    PRICES = {
         
     | 
| 
       63 
     | 
    
         
            -
                       
     | 
| 
       64 
     | 
    
         
            -
                       
     | 
| 
       65 
     | 
    
         
            -
                       
     | 
| 
       66 
     | 
    
         
            -
                       
     | 
| 
       67 
     | 
    
         
            -
                       
     | 
| 
       68 
     | 
    
         
            -
                       
     | 
| 
      
 95 
     | 
    
         
            +
                      claude37_sonnet: { input: 3.0, output: 15.0 },   # $3.00/$15.00 per million tokens
         
     | 
| 
      
 96 
     | 
    
         
            +
                      claude35_sonnet: { input: 3.0, output: 15.0 },   # $3.00/$15.00 per million tokens
         
     | 
| 
      
 97 
     | 
    
         
            +
                      claude35_haiku: { input: 0.80, output: 4.0 },    # $0.80/$4.00 per million tokens
         
     | 
| 
      
 98 
     | 
    
         
            +
                      claude3_opus: { input: 15.0, output: 75.0 },     # $15.00/$75.00 per million tokens
         
     | 
| 
      
 99 
     | 
    
         
            +
                      claude3_sonnet: { input: 3.0, output: 15.0 },    # $3.00/$15.00 per million tokens
         
     | 
| 
      
 100 
     | 
    
         
            +
                      claude3_haiku: { input: 0.25, output: 1.25 },    # $0.25/$1.25 per million tokens
         
     | 
| 
      
 101 
     | 
    
         
            +
                      claude2: { input: 3.0, output: 15.0 }            # Default pricing for Claude 2.x models
         
     | 
| 
       69 
102 
     | 
    
         
             
                    }.freeze
         
     | 
| 
       70 
103 
     | 
    
         | 
| 
      
 104 
     | 
    
         
            +
                    # Default input price if model not found in PRICES
         
     | 
| 
      
 105 
     | 
    
         
            +
                    # @return [Float] default price per million tokens for input
         
     | 
| 
       71 
106 
     | 
    
         
             
                    def default_input_price
         
     | 
| 
       72 
107 
     | 
    
         
             
                      3.0
         
     | 
| 
       73 
108 
     | 
    
         
             
                    end
         
     | 
| 
       74 
109 
     | 
    
         | 
| 
      
 110 
     | 
    
         
            +
                    # Default output price if model not found in PRICES
         
     | 
| 
      
 111 
     | 
    
         
            +
                    # @return [Float] default price per million tokens for output
         
     | 
| 
       75 
112 
     | 
    
         
             
                    def default_output_price
         
     | 
| 
       76 
113 
     | 
    
         
             
                      15.0
         
     | 
| 
       77 
114 
     | 
    
         
             
                    end
         
     | 
| 
         @@ -35,7 +35,7 @@ module RubyLLM 
     | 
|
| 
       35 
35 
     | 
    
         | 
| 
       36 
36 
     | 
    
         
             
                    def extract_text_content(blocks)
         
     | 
| 
       37 
37 
     | 
    
         
             
                      text_blocks = blocks.select { |c| c['type'] == 'text' }
         
     | 
| 
       38 
     | 
    
         
            -
                      text_blocks.map { |c| c['text'] }.join 
     | 
| 
      
 38 
     | 
    
         
            +
                      text_blocks.map { |c| c['text'] }.join
         
     | 
| 
       39 
39 
     | 
    
         
             
                    end
         
     | 
| 
       40 
40 
     | 
    
         | 
| 
       41 
41 
     | 
    
         
             
                    def build_message(data, content, tool_use)
         
     | 
| 
         @@ -68,8 +68,7 @@ module RubyLLM 
     | 
|
| 
       68 
68 
     | 
    
         | 
| 
       69 
69 
     | 
    
         
             
                    def convert_role(role)
         
     | 
| 
       70 
70 
     | 
    
         
             
                      case role
         
     | 
| 
       71 
     | 
    
         
            -
                      when :tool then 'user'
         
     | 
| 
       72 
     | 
    
         
            -
                      when :user then 'user'
         
     | 
| 
      
 71 
     | 
    
         
            +
                      when :tool, :user then 'user'
         
     | 
| 
       73 
72 
     | 
    
         
             
                      else 'assistant'
         
     | 
| 
       74 
73 
     | 
    
         
             
                      end
         
     | 
| 
       75 
74 
     | 
    
         
             
                    end
         
     | 
| 
         @@ -7,6 +7,9 @@ module RubyLLM 
     | 
|
| 
       7 
7 
     | 
    
         
             
                  module Capabilities
         
     | 
| 
       8 
8 
     | 
    
         
             
                    module_function
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
      
 10 
     | 
    
         
            +
                    # Returns the context window size for the given model
         
     | 
| 
      
 11 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 12 
     | 
    
         
            +
                    # @return [Integer] the context window size in tokens
         
     | 
| 
       10 
13 
     | 
    
         
             
                    def context_window_for(model_id)
         
     | 
| 
       11 
14 
     | 
    
         
             
                      case model_id
         
     | 
| 
       12 
15 
     | 
    
         
             
                      when /deepseek-(?:chat|reasoner)/ then 64_000
         
     | 
| 
         @@ -14,6 +17,9 @@ module RubyLLM 
     | 
|
| 
       14 
17 
     | 
    
         
             
                      end
         
     | 
| 
       15 
18 
     | 
    
         
             
                    end
         
     | 
| 
       16 
19 
     | 
    
         | 
| 
      
 20 
     | 
    
         
            +
                    # Returns the maximum number of tokens that can be generated
         
     | 
| 
      
 21 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 22 
     | 
    
         
            +
                    # @return [Integer] the maximum number of tokens
         
     | 
| 
       17 
23 
     | 
    
         
             
                    def max_tokens_for(model_id)
         
     | 
| 
       18 
24 
     | 
    
         
             
                      case model_id
         
     | 
| 
       19 
25 
     | 
    
         
             
                      when /deepseek-(?:chat|reasoner)/ then 8_192
         
     | 
| 
         @@ -21,30 +27,51 @@ module RubyLLM 
     | 
|
| 
       21 
27 
     | 
    
         
             
                      end
         
     | 
| 
       22 
28 
     | 
    
         
             
                    end
         
     | 
| 
       23 
29 
     | 
    
         | 
| 
      
 30 
     | 
    
         
            +
                    # Returns the price per million tokens for input (cache miss)
         
     | 
| 
      
 31 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 32 
     | 
    
         
            +
                    # @return [Float] the price per million tokens in USD
         
     | 
| 
       24 
33 
     | 
    
         
             
                    def input_price_for(model_id)
         
     | 
| 
       25 
34 
     | 
    
         
             
                      PRICES.dig(model_family(model_id), :input_miss) || default_input_price
         
     | 
| 
       26 
35 
     | 
    
         
             
                    end
         
     | 
| 
       27 
36 
     | 
    
         | 
| 
      
 37 
     | 
    
         
            +
                    # Returns the price per million tokens for output
         
     | 
| 
      
 38 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 39 
     | 
    
         
            +
                    # @return [Float] the price per million tokens in USD
         
     | 
| 
       28 
40 
     | 
    
         
             
                    def output_price_for(model_id)
         
     | 
| 
       29 
41 
     | 
    
         
             
                      PRICES.dig(model_family(model_id), :output) || default_output_price
         
     | 
| 
       30 
42 
     | 
    
         
             
                    end
         
     | 
| 
       31 
43 
     | 
    
         | 
| 
      
 44 
     | 
    
         
            +
                    # Returns the price per million tokens for input with cache hit
         
     | 
| 
      
 45 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 46 
     | 
    
         
            +
                    # @return [Float] the price per million tokens in USD
         
     | 
| 
       32 
47 
     | 
    
         
             
                    def cache_hit_price_for(model_id)
         
     | 
| 
       33 
48 
     | 
    
         
             
                      PRICES.dig(model_family(model_id), :input_hit) || default_cache_hit_price
         
     | 
| 
       34 
49 
     | 
    
         
             
                    end
         
     | 
| 
       35 
50 
     | 
    
         | 
| 
      
 51 
     | 
    
         
            +
                    # Determines if the model supports vision capabilities
         
     | 
| 
      
 52 
     | 
    
         
            +
                    # @param _model_id [String] the model identifier (unused; the result is the same for every model)
         
     | 
| 
      
 53 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports vision
         
     | 
| 
       36 
54 
     | 
    
         
             
                    def supports_vision?(_model_id)
         
     | 
| 
       37 
55 
     | 
    
         
             
                      false # DeepSeek models don't currently support vision
         
     | 
| 
       38 
56 
     | 
    
         
             
                    end
         
     | 
| 
       39 
57 
     | 
    
         | 
| 
      
 58 
     | 
    
         
            +
                    # Determines if the model supports function calling
         
     | 
| 
      
 59 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 60 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports function calling
         
     | 
| 
       40 
61 
     | 
    
         
             
                    def supports_functions?(model_id)
         
     | 
| 
       41 
62 
     | 
    
         
             
                      model_id.match?(/deepseek-chat/) # Only deepseek-chat supports function calling
         
     | 
| 
       42 
63 
     | 
    
         
             
                    end
         
     | 
| 
       43 
64 
     | 
    
         | 
| 
      
 65 
     | 
    
         
            +
                    # Determines if the model supports JSON mode
         
     | 
| 
      
 66 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 67 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports JSON mode
         
     | 
| 
       44 
68 
     | 
    
         
             
                    def supports_json_mode?(model_id)
         
     | 
| 
       45 
69 
     | 
    
         
             
                      model_id.match?(/deepseek-chat/) # Only deepseek-chat supports JSON mode
         
     | 
| 
       46 
70 
     | 
    
         
             
                    end
         
     | 
| 
       47 
71 
     | 
    
         | 
| 
      
 72 
     | 
    
         
            +
                    # Returns a formatted display name for the model
         
     | 
| 
      
 73 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 74 
     | 
    
         
            +
                    # @return [String] the formatted display name
         
     | 
| 
       48 
75 
     | 
    
         
             
                    def format_display_name(model_id)
         
     | 
| 
       49 
76 
     | 
    
         
             
                      case model_id
         
     | 
| 
       50 
77 
     | 
    
         
             
                      when 'deepseek-chat' then 'DeepSeek V3'
         
     | 
| 
         @@ -56,13 +83,18 @@ module RubyLLM 
     | 
|
| 
       56 
83 
     | 
    
         
             
                      end
         
     | 
| 
       57 
84 
     | 
    
         
             
                    end
         
     | 
| 
       58 
85 
     | 
    
         | 
| 
      
 86 
     | 
    
         
            +
                    # Returns the model type
         
     | 
| 
      
 87 
     | 
    
         
            +
                    # @param _model_id [String] the model identifier (unused; the result is the same for every model)
         
     | 
| 
      
 88 
     | 
    
         
            +
                    # @return [String] the model type, always 'chat'
         
     | 
| 
       59 
89 
     | 
    
         
             
                    def model_type(_model_id)
         
     | 
| 
       60 
90 
     | 
    
         
             
                      'chat' # All DeepSeek models are chat models
         
     | 
| 
       61 
91 
     | 
    
         
             
                    end
         
     | 
| 
       62 
92 
     | 
    
         | 
| 
      
 93 
     | 
    
         
            +
                    # Returns the model family
         
     | 
| 
      
 94 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 95 
     | 
    
         
            +
                    # @return [Symbol] the model family
         
     | 
| 
       63 
96 
     | 
    
         
             
                    def model_family(model_id)
         
     | 
| 
       64 
97 
     | 
    
         
             
                      case model_id
         
     | 
| 
       65 
     | 
    
         
            -
                      when /deepseek-chat/ then :chat
         
     | 
| 
       66 
98 
     | 
    
         
             
                      when /deepseek-reasoner/ then :reasoner
         
     | 
| 
       67 
99 
     | 
    
         
             
                      else :chat # Default to chat family
         
     | 
| 
       68 
100 
     | 
    
         
             
                      end
         
     | 
| 
         @@ -84,14 +116,20 @@ module RubyLLM 
     | 
|
| 
       84 
116 
     | 
    
         | 
| 
       85 
117 
     | 
    
         
             
                    private
         
     | 
| 
       86 
118 
     | 
    
         | 
| 
      
 119 
     | 
    
         
            +
                    # Default input price when model family can't be determined
         
     | 
| 
      
 120 
     | 
    
         
            +
                    # @return [Float] the default input price
         
     | 
| 
       87 
121 
     | 
    
         
             
                    def default_input_price
         
     | 
| 
       88 
122 
     | 
    
         
             
                      0.27 # Default to chat cache miss price
         
     | 
| 
       89 
123 
     | 
    
         
             
                    end
         
     | 
| 
       90 
124 
     | 
    
         | 
| 
      
 125 
     | 
    
         
            +
                    # Default output price when model family can't be determined
         
     | 
| 
      
 126 
     | 
    
         
            +
                    # @return [Float] the default output price
         
     | 
| 
       91 
127 
     | 
    
         
             
                    def default_output_price
         
     | 
| 
       92 
128 
     | 
    
         
             
                      1.10 # Default to chat output price
         
     | 
| 
       93 
129 
     | 
    
         
             
                    end
         
     | 
| 
       94 
130 
     | 
    
         | 
| 
      
 131 
     | 
    
         
            +
                    # Default cache hit price when model family can't be determined
         
     | 
| 
      
 132 
     | 
    
         
            +
                    # @return [Float] the default cache hit price
         
     | 
| 
       95 
133 
     | 
    
         
             
                    def default_cache_hit_price
         
     | 
| 
       96 
134 
     | 
    
         
             
                      0.07 # Default to chat cache hit price
         
     | 
| 
       97 
135 
     | 
    
         
             
                    end
         
     | 
| 
         @@ -7,25 +7,34 @@ module RubyLLM 
     | 
|
| 
       7 
7 
     | 
    
         
             
                  module Capabilities # rubocop:disable Metrics/ModuleLength
         
     | 
| 
       8 
8 
     | 
    
         
             
                    module_function
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
      
 10 
     | 
    
         
            +
                    # Returns the context window size (input token limit) for the given model
         
     | 
| 
      
 11 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 12 
     | 
    
         
            +
                    # @return [Integer] the context window size in tokens
         
     | 
| 
       10 
13 
     | 
    
         
             
                    def context_window_for(model_id)
         
     | 
| 
       11 
14 
     | 
    
         
             
                      case model_id
         
     | 
| 
       12 
15 
     | 
    
         
             
                      when /gemini-2\.0-flash/, /gemini-1\.5-flash/ then 1_048_576
         
     | 
| 
       13 
16 
     | 
    
         
             
                      when /gemini-1\.5-pro/ then 2_097_152
         
     | 
| 
       14 
     | 
    
         
            -
                      when /text-embedding/, /embedding-001/ then 2_048
         
     | 
| 
      
 17 
     | 
    
         
            +
                      when /text-embedding-004/, /embedding-001/ then 2_048
         
     | 
| 
       15 
18 
     | 
    
         
             
                      when /aqa/ then 7_168
         
     | 
| 
       16 
19 
     | 
    
         
             
                      else 32_768 # Sensible default for unknown models
         
     | 
| 
       17 
20 
     | 
    
         
             
                      end
         
     | 
| 
       18 
21 
     | 
    
         
             
                    end
         
     | 
| 
       19 
22 
     | 
    
         | 
| 
      
 23 
     | 
    
         
            +
                    # Returns the maximum output tokens for the given model
         
     | 
| 
      
 24 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 25 
     | 
    
         
            +
                    # @return [Integer] the maximum output tokens
         
     | 
| 
       20 
26 
     | 
    
         
             
                    def max_tokens_for(model_id)
         
     | 
| 
       21 
27 
     | 
    
         
             
                      case model_id
         
     | 
| 
       22 
28 
     | 
    
         
             
                      when /gemini-2\.0-flash/, /gemini-1\.5/ then 8_192
         
     | 
| 
       23 
     | 
    
         
            -
                      when /text-embedding/, /embedding-001/ then 768 # Output dimension size for embeddings
         
     | 
| 
      
 29 
     | 
    
         
            +
                      when /text-embedding-004/, /embedding-001/ then 768 # Output dimension size for embeddings
         
     | 
| 
       24 
30 
     | 
    
         
             
                      when /aqa/ then 1_024
         
     | 
| 
       25 
31 
     | 
    
         
             
                      else 4_096 # Sensible default
         
     | 
| 
       26 
32 
     | 
    
         
             
                      end
         
     | 
| 
       27 
33 
     | 
    
         
             
                    end
         
     | 
| 
       28 
34 
     | 
    
         | 
| 
      
 35 
     | 
    
         
            +
                    # Returns the input price per million tokens for the given model
         
     | 
| 
      
 36 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 37 
     | 
    
         
            +
                    # @return [Float] the price per million tokens in USD
         
     | 
| 
       29 
38 
     | 
    
         
             
                    def input_price_for(model_id)
         
     | 
| 
       30 
39 
     | 
    
         
             
                      base_price = PRICES.dig(pricing_family(model_id), :input) || default_input_price
         
     | 
| 
       31 
40 
     | 
    
         
             
                      return base_price unless long_context_model?(model_id)
         
     | 
| 
         @@ -34,6 +43,9 @@ module RubyLLM 
     | 
|
| 
       34 
43 
     | 
    
         
             
                      context_length(model_id) > 128_000 ? base_price * 2 : base_price
         
     | 
| 
       35 
44 
     | 
    
         
             
                    end
         
     | 
| 
       36 
45 
     | 
    
         | 
| 
      
 46 
     | 
    
         
            +
                    # Returns the output price per million tokens for the given model
         
     | 
| 
      
 47 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 48 
     | 
    
         
            +
                    # @return [Float] the price per million tokens in USD
         
     | 
| 
       37 
49 
     | 
    
         
             
                    def output_price_for(model_id)
         
     | 
| 
       38 
50 
     | 
    
         
             
                      base_price = PRICES.dig(pricing_family(model_id), :output) || default_output_price
         
     | 
| 
       39 
51 
     | 
    
         
             
                      return base_price unless long_context_model?(model_id)
         
     | 
| 
         @@ -42,6 +54,9 @@ module RubyLLM 
     | 
|
| 
       42 
54 
     | 
    
         
             
                      context_length(model_id) > 128_000 ? base_price * 2 : base_price
         
     | 
| 
       43 
55 
     | 
    
         
             
                    end
         
     | 
| 
       44 
56 
     | 
    
         | 
| 
      
 57 
     | 
    
         
            +
                    # Determines if the model supports vision (image/video) inputs
         
     | 
| 
      
 58 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 59 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports vision inputs
         
     | 
| 
       45 
60 
     | 
    
         
             
                    def supports_vision?(model_id)
         
     | 
| 
       46 
61 
     | 
    
         
             
                      return false if model_id.match?(/text-embedding|embedding-001|aqa/)
         
     | 
| 
       47 
62 
     | 
    
         
             
                      return false if model_id.match?(/gemini-1\.0/)
         
     | 
| 
         @@ -49,6 +64,9 @@ module RubyLLM 
     | 
|
| 
       49 
64 
     | 
    
         
             
                      model_id.match?(/gemini-[12]\.[05]/)
         
     | 
| 
       50 
65 
     | 
    
         
             
                    end
         
     | 
| 
       51 
66 
     | 
    
         | 
| 
      
 67 
     | 
    
         
            +
                    # Determines if the model supports function calling
         
     | 
| 
      
 68 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 69 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports function calling
         
     | 
| 
       52 
70 
     | 
    
         
             
                    def supports_functions?(model_id)
         
     | 
| 
       53 
71 
     | 
    
         
             
                      return false if model_id.match?(/text-embedding|embedding-001|aqa/)
         
     | 
| 
       54 
72 
     | 
    
         
             
                      return false if model_id.match?(/flash-lite/)
         
     | 
| 
         @@ -57,13 +75,20 @@ module RubyLLM 
     | 
|
| 
       57 
75 
     | 
    
         
             
                      model_id.match?(/gemini-[12]\.[05]-(?:pro|flash)(?!-lite)/)
         
     | 
| 
       58 
76 
     | 
    
         
             
                    end
         
     | 
| 
       59 
77 
     | 
    
         | 
| 
      
 78 
     | 
    
         
            +
                    # Determines if the model supports JSON mode
         
     | 
| 
      
 79 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 80 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports JSON mode
         
     | 
| 
       60 
81 
     | 
    
         
             
                    def supports_json_mode?(model_id)
         
     | 
| 
       61 
82 
     | 
    
         
             
                      return false if model_id.match?(/text-embedding|embedding-001|aqa/)
         
     | 
| 
       62 
83 
     | 
    
         
             
                      return false if model_id.match?(/gemini-1\.0/)
         
     | 
| 
      
 84 
     | 
    
         
            +
                      return false if model_id.match?(/gemini-2\.0-flash-lite/)
         
     | 
| 
       63 
85 
     | 
    
         | 
| 
       64 
86 
     | 
    
         
             
                      model_id.match?(/gemini-\d/)
         
     | 
| 
       65 
87 
     | 
    
         
             
                    end
         
     | 
| 
       66 
88 
     | 
    
         | 
| 
      
 89 
     | 
    
         
            +
                    # Formats the model ID into a human-readable display name
         
     | 
| 
      
 90 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 91 
     | 
    
         
            +
                    # @return [String] the formatted display name
         
     | 
| 
       67 
92 
     | 
    
         
             
                    def format_display_name(model_id)
         
     | 
| 
       68 
93 
     | 
    
         
             
                      model_id
         
     | 
| 
       69 
94 
     | 
    
         
             
                        .delete_prefix('models/')
         
     | 
| 
         @@ -72,24 +97,36 @@ module RubyLLM 
     | 
|
| 
       72 
97 
     | 
    
         
             
                        .join(' ')
         
     | 
| 
       73 
98 
     | 
    
         
             
                        .gsub(/(\d+\.\d+)/, ' \1') # Add space before version numbers
         
     | 
| 
       74 
99 
     | 
    
         
             
                        .gsub(/\s+/, ' ')          # Clean up multiple spaces
         
     | 
| 
       75 
     | 
    
         
            -
                        .gsub( 
     | 
| 
      
 100 
     | 
    
         
            +
                        .gsub('Aqa', 'AQA')        # Special case for AQA
         
     | 
| 
       76 
101 
     | 
    
         
             
                        .strip
         
     | 
| 
       77 
102 
     | 
    
         
             
                    end
         
     | 
| 
       78 
103 
     | 
    
         | 
| 
      
 104 
     | 
    
         
            +
                    # Determines if the model supports context caching
         
     | 
| 
      
 105 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 106 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports caching
         
     | 
| 
       79 
107 
     | 
    
         
             
                    def supports_caching?(model_id)
         
     | 
| 
       80 
108 
     | 
    
         
             
                      return false if model_id.match?(/flash-lite|gemini-1\.0/)
         
     | 
| 
       81 
109 
     | 
    
         | 
| 
       82 
110 
     | 
    
         
             
                      model_id.match?(/gemini-[12]\.[05]/)
         
     | 
| 
       83 
111 
     | 
    
         
             
                    end
         
     | 
| 
       84 
112 
     | 
    
         | 
| 
      
 113 
     | 
    
         
            +
                    # Determines if the model supports tuning
         
     | 
| 
      
 114 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 115 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports tuning
         
     | 
| 
       85 
116 
     | 
    
         
             
                    def supports_tuning?(model_id)
         
     | 
| 
       86 
117 
     | 
    
         
             
                      model_id.match?(/gemini-1\.5-flash/)
         
     | 
| 
       87 
118 
     | 
    
         
             
                    end
         
     | 
| 
       88 
119 
     | 
    
         | 
| 
      
 120 
     | 
    
         
            +
                    # Determines if the model supports audio inputs
         
     | 
| 
      
 121 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 122 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports audio inputs
         
     | 
| 
       89 
123 
     | 
    
         
             
                    def supports_audio?(model_id)
         
     | 
| 
       90 
124 
     | 
    
         
             
                      model_id.match?(/gemini-[12]\.[05]/)
         
     | 
| 
       91 
125 
     | 
    
         
             
                    end
         
     | 
| 
       92 
126 
     | 
    
         | 
| 
      
 127 
     | 
    
         
            +
                    # Returns the type of model (chat, embedding, image)
         
     | 
| 
      
 128 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 129 
     | 
    
         
            +
                    # @return [String] the model type
         
     | 
| 
       93 
130 
     | 
    
         
             
                    def model_type(model_id)
         
     | 
| 
       94 
131 
     | 
    
         
             
                      case model_id
         
     | 
| 
       95 
132 
     | 
    
         
             
                      when /text-embedding|embedding/ then 'embedding'
         
     | 
| 
         @@ -98,6 +135,9 @@ module RubyLLM 
     | 
|
| 
       98 
135 
     | 
    
         
             
                      end
         
     | 
| 
       99 
136 
     | 
    
         
             
                    end
         
     | 
| 
       100 
137 
     | 
    
         | 
| 
      
 138 
     | 
    
         
            +
                    # Returns the model family identifier
         
     | 
| 
      
 139 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 140 
     | 
    
         
            +
                    # @return [String] the model family identifier
         
     | 
| 
       101 
141 
     | 
    
         
             
                    def model_family(model_id) # rubocop:disable Metrics/CyclomaticComplexity,Metrics/MethodLength
         
     | 
| 
       102 
142 
     | 
    
         
             
                      case model_id
         
     | 
| 
       103 
143 
     | 
    
         
             
                      when /gemini-2\.0-flash-lite/ then 'gemini20_flash_lite'
         
     | 
| 
         @@ -113,7 +153,10 @@ module RubyLLM 
     | 
|
| 
       113 
153 
     | 
    
         
             
                      end
         
     | 
| 
       114 
154 
     | 
    
         
             
                    end
         
     | 
| 
       115 
155 
     | 
    
         | 
| 
       116 
     | 
    
         
            -
                     
     | 
| 
      
 156 
     | 
    
         
            +
                    # Returns the pricing family identifier for the model
         
     | 
| 
      
 157 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 158 
     | 
    
         
            +
                    # @return [Symbol] the pricing family identifier
         
     | 
| 
      
 159 
     | 
    
         
            +
                    def pricing_family(model_id) # rubocop:disable Metrics/CyclomaticComplexity,Metrics/MethodLength
         
     | 
| 
       117 
160 
     | 
    
         
             
                      case model_id
         
     | 
| 
       118 
161 
     | 
    
         
             
                      when /gemini-2\.0-flash-lite/ then :flash_lite_2 # rubocop:disable Naming/VariableNumber
         
     | 
| 
       119 
162 
     | 
    
         
             
                      when /gemini-2\.0-flash/ then :flash_2 # rubocop:disable Naming/VariableNumber
         
     | 
| 
         @@ -122,18 +165,26 @@ module RubyLLM 
     | 
|
| 
       122 
165 
     | 
    
         
             
                      when /gemini-1\.5-pro/ then :pro
         
     | 
| 
       123 
166 
     | 
    
         
             
                      when /gemini-1\.0-pro/ then :pro_1_0 # rubocop:disable Naming/VariableNumber
         
     | 
| 
       124 
167 
     | 
    
         
             
                      when /text-embedding|embedding/ then :embedding
         
     | 
| 
      
 168 
     | 
    
         
            +
                      when /aqa/ then :aqa
         
     | 
| 
       125 
169 
     | 
    
         
             
                      else :base
         
     | 
| 
       126 
170 
     | 
    
         
             
                      end
         
     | 
| 
       127 
171 
     | 
    
         
             
                    end
         
     | 
| 
       128 
172 
     | 
    
         | 
| 
      
 173 
     | 
    
         
            +
                    # Determines if the model supports long context
         
     | 
| 
      
 174 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 175 
     | 
    
         
            +
                    # @return [Boolean] true if the model supports long context
         
     | 
| 
       129 
176 
     | 
    
         
             
                    def long_context_model?(model_id)
         
     | 
| 
       130 
177 
     | 
    
         
             
                      model_id.match?(/gemini-1\.5-(?:pro|flash)/)
         
     | 
| 
       131 
178 
     | 
    
         
             
                    end
         
     | 
| 
       132 
179 
     | 
    
         | 
| 
      
 180 
     | 
    
         
            +
                    # Returns the context length for the model
         
     | 
| 
      
 181 
     | 
    
         
            +
                    # @param model_id [String] the model identifier
         
     | 
| 
      
 182 
     | 
    
         
            +
                    # @return [Integer] the context length in tokens
         
     | 
| 
       133 
183 
     | 
    
         
             
                    def context_length(model_id)
         
     | 
| 
       134 
184 
     | 
    
         
             
                      context_window_for(model_id)
         
     | 
| 
       135 
185 
     | 
    
         
             
                    end
         
     | 
| 
       136 
186 
     | 
    
         | 
| 
      
 187 
     | 
    
         
            +
                    # Pricing information for Gemini models (per 1M tokens in USD)
         
     | 
| 
       137 
188 
     | 
    
         
             
                    PRICES = {
         
     | 
| 
       138 
189 
     | 
    
         
             
                      flash_2: { # Gemini 2.0 Flash # rubocop:disable Naming/VariableNumber
         
     | 
| 
       139 
190 
     | 
    
         
             
                        input: 0.10,
         
     | 
| 
         @@ -152,19 +203,22 @@ module RubyLLM 
     | 
|
| 
       152 
203 
     | 
    
         
             
                        input: 0.075,
         
     | 
| 
       153 
204 
     | 
    
         
             
                        output: 0.30,
         
     | 
| 
       154 
205 
     | 
    
         
             
                        cache: 0.01875,
         
     | 
| 
       155 
     | 
    
         
            -
                        cache_storage: 1.00
         
     | 
| 
      
 206 
     | 
    
         
            +
                        cache_storage: 1.00,
         
     | 
| 
      
 207 
     | 
    
         
            +
                        grounding_search: 35.00 # per 1K requests
         
     | 
| 
       156 
208 
     | 
    
         
             
                      },
         
     | 
| 
       157 
209 
     | 
    
         
             
                      flash_8b: { # Gemini 1.5 Flash 8B
         
     | 
| 
       158 
210 
     | 
    
         
             
                        input: 0.0375,
         
     | 
| 
       159 
211 
     | 
    
         
             
                        output: 0.15,
         
     | 
| 
       160 
212 
     | 
    
         
             
                        cache: 0.01,
         
     | 
| 
       161 
     | 
    
         
            -
                        cache_storage: 0.25
         
     | 
| 
      
 213 
     | 
    
         
            +
                        cache_storage: 0.25,
         
     | 
| 
      
 214 
     | 
    
         
            +
                        grounding_search: 35.00 # per 1K requests
         
     | 
| 
       162 
215 
     | 
    
         
             
                      },
         
     | 
| 
       163 
216 
     | 
    
         
             
                      pro: { # Gemini 1.5 Pro
         
     | 
| 
       164 
217 
     | 
    
         
             
                        input: 1.25,
         
     | 
| 
       165 
218 
     | 
    
         
             
                        output: 5.0,
         
     | 
| 
       166 
219 
     | 
    
         
             
                        cache: 0.3125,
         
     | 
| 
       167 
     | 
    
         
            -
                        cache_storage: 4.50
         
     | 
| 
      
 220 
     | 
    
         
            +
                        cache_storage: 4.50,
         
     | 
| 
      
 221 
     | 
    
         
            +
                        grounding_search: 35.00 # per 1K requests
         
     | 
| 
       168 
222 
     | 
    
         
             
                      },
         
     | 
| 
       169 
223 
     | 
    
         
             
                      pro_1_0: { # Gemini 1.0 Pro # rubocop:disable Naming/VariableNumber
         
     | 
| 
       170 
224 
     | 
    
         
             
                        input: 0.50,
         
     | 
| 
         @@ -173,15 +227,23 @@ module RubyLLM 
     | 
|
| 
       173 
227 
     | 
    
         
             
                      embedding: { # Text Embedding models
         
     | 
| 
       174 
228 
     | 
    
         
             
                        input: 0.00,
         
     | 
| 
       175 
229 
     | 
    
         
             
                        output: 0.00
         
     | 
| 
      
 230 
     | 
    
         
            +
                      },
         
     | 
| 
      
 231 
     | 
    
         
            +
                      aqa: { # AQA model
         
     | 
| 
      
 232 
     | 
    
         
            +
                        input: 0.00,
         
     | 
| 
      
 233 
     | 
    
         
            +
                        output: 0.00
         
     | 
| 
       176 
234 
     | 
    
         
             
                      }
         
     | 
| 
       177 
235 
     | 
    
         
             
                    }.freeze
         
     | 
| 
       178 
236 
     | 
    
         | 
| 
      
 237 
     | 
    
         
            +
                    # Default input price for unknown models
         
     | 
| 
      
 238 
     | 
    
         
            +
                    # @return [Float] the default input price per million tokens
         
     | 
| 
       179 
239 
     | 
    
         
             
                    def default_input_price
         
     | 
| 
       180 
240 
     | 
    
         
             
                      0.075 # Default to Flash pricing
         
     | 
| 
       181 
241 
     | 
    
         
             
                    end
         
     | 
| 
       182 
242 
     | 
    
         | 
| 
      
 243 
     | 
    
         
            +
                    # Default output price for unknown models
         
     | 
| 
      
 244 
     | 
    
         
            +
                    # @return [Float] the default output price per million tokens
         
     | 
| 
       183 
245 
     | 
    
         
             
                    def default_output_price
         
     | 
| 
       184 
     | 
    
         
            -
                      0.30 
     | 
| 
      
 246 
     | 
    
         
            +
                      0.30 # Default to Flash pricing
         
     | 
| 
       185 
247 
     | 
    
         
             
                    end
         
     | 
| 
       186 
248 
     | 
    
         
             
                  end
         
     | 
| 
       187 
249 
     | 
    
         
             
                end
         
     |