ruby_llm 0.1.0.pre → 0.1.0.pre2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/.github/workflows/gem-push.yml +9 -3
 - data/.github/workflows/test.yml +32 -0
 - data/.gitignore +58 -0
 - data/.overcommit.yml +26 -0
 - data/.rspec +3 -0
 - data/.rubocop.yml +3 -0
 - data/Gemfile +5 -0
 - data/README.md +68 -13
 - data/Rakefile +4 -2
 - data/bin/console +6 -3
 - data/lib/ruby_llm/active_record/acts_as.rb +31 -18
 - data/lib/ruby_llm/client.rb +32 -16
 - data/lib/ruby_llm/configuration.rb +5 -3
 - data/lib/ruby_llm/conversation.rb +3 -0
 - data/lib/ruby_llm/message.rb +6 -3
 - data/lib/ruby_llm/model_capabilities/anthropic.rb +81 -0
 - data/lib/ruby_llm/model_capabilities/base.rb +35 -0
 - data/lib/ruby_llm/model_capabilities/openai.rb +121 -0
 - data/lib/ruby_llm/model_info.rb +42 -0
 - data/lib/ruby_llm/providers/anthropic.rb +226 -0
 - data/lib/ruby_llm/providers/base.rb +21 -2
 - data/lib/ruby_llm/providers/openai.rb +161 -0
 - data/lib/ruby_llm/railtie.rb +3 -0
 - data/lib/ruby_llm/tool.rb +75 -0
 - data/lib/ruby_llm/version.rb +3 -1
 - data/lib/ruby_llm.rb +35 -3
 - data/ruby_llm.gemspec +42 -31
 - metadata +142 -17
 
| 
         @@ -0,0 +1,121 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module RubyLLM
         
     | 
| 
      
 4 
     | 
    
         
            +
              module ModelCapabilities
         
     | 
| 
      
 5 
     | 
    
         
            +
                class OpenAI < Base
         
     | 
| 
      
 6 
     | 
    
         
            +
                  # Looks up the context window size (in tokens) for an OpenAI model id.
                  # Patterns are tried top to bottom and the first match wins; ids that
                  # match nothing get the 4_096 fallback.
                  def determine_context_window(model_id)
                    windows = [
                      [/gpt-4o/, 128_000],
                      [/o1/, 128_000],
                      [/gpt-4-turbo/, 128_000],
                      [/gpt-4-0[0-9]{3}/, 8_192],
                      [/gpt-3.5-turbo-instruct/, 4_096],
                      [/gpt-3.5/, 16_385]
                    ]

                    # `===` mirrors the matching semantics of a case/when on regexps.
                    _, size = windows.find { |pattern, _| pattern === model_id }
                    size || 4_096
                  end
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
                  # Returns the maximum number of output tokens for an OpenAI model id.
                  # The table is scanned in insertion order and the first matching
                  # pattern decides; anything unmatched falls back to 4_096.
                  def determine_max_tokens(model_id)
                    limits = {
                      /o1-2024-12-17/ => 100_000,
                      /o1-mini-2024-09-12/ => 65_536,
                      /o1-preview-2024-09-12/ => 32_768,
                      /gpt-4o/ => 16_384,
                      /gpt-4-turbo/ => 16_384,
                      /gpt-4-0[0-9]{3}/ => 8_192,
                      /gpt-3.5-turbo/ => 4_096
                    }

                    limits.each do |pattern, limit|
                      return limit if pattern === model_id
                    end

                    4_096
                  end
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
                  # Returns the input price in USD per million tokens for an OpenAI
                  # model id. Branches are ordered from most to least specific because
                  # the first matching pattern wins.
                  #
                  # o1-preview ids and the bare "o1" alias are priced with the dated o1
                  # snapshots; previously they fell through to the GPT-3.5 default.
                  def get_input_price(model_id)
                    case model_id
                    when /o1-2024/, /o1-preview/, /\Ao1\z/
                      15.0    # $15.00 per million tokens
                    when /o1-mini/
                      3.0     # $3.00 per million tokens
                    when /gpt-4o-realtime-preview/
                      5.0     # $5.00 per million tokens
                    when /gpt-4o-mini-realtime-preview/
                      0.60    # $0.60 per million tokens
                    when /gpt-4o-mini/
                      0.15    # $0.15 per million tokens
                    when /gpt-4o/
                      2.50    # $2.50 per million tokens
                    when /gpt-4-turbo/
                      10.0    # $10.00 per million tokens
                    when /gpt-3\.5/
                      0.50    # $0.50 per million tokens
                    else
                      0.50    # Default to GPT-3.5 pricing
                    end
                  end
         
     | 
| 
      
 62 
     | 
    
         
            +
             
     | 
| 
      
 63 
     | 
    
         
            +
                  # Returns the output price in USD per million tokens for an OpenAI
                  # model id. Branches are ordered from most to least specific because
                  # the first matching pattern wins.
                  #
                  # o1-preview ids and the bare "o1" alias are priced with the dated o1
                  # snapshots; previously they fell through to the GPT-3.5 default.
                  def get_output_price(model_id)
                    case model_id
                    when /o1-2024/, /o1-preview/, /\Ao1\z/
                      60.0    # $60.00 per million tokens
                    when /o1-mini/
                      12.0    # $12.00 per million tokens
                    when /gpt-4o-realtime-preview/
                      20.0    # $20.00 per million tokens
                    when /gpt-4o-mini-realtime-preview/
                      2.40    # $2.40 per million tokens
                    when /gpt-4o-mini/
                      0.60    # $0.60 per million tokens
                    when /gpt-4o/
                      10.0    # $10.00 per million tokens
                    when /gpt-4-turbo/
                      30.0    # $30.00 per million tokens
                    when /gpt-3\.5/
                      1.50    # $1.50 per million tokens
                    else
                      1.50    # Default to GPT-3.5 pricing
                    end
                  end
         
     | 
| 
      
 85 
     | 
    
         
            +
             
     | 
| 
      
 86 
     | 
    
         
            +
                  # Function calling is reported for every model id except those
                  # containing "instruct".
                  def supports_functions?(model_id)
                    return false if model_id.include?('instruct')

                    true
                  end
         
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
                  # Reports whether a model accepts image input, judged only from its id.
                  #
                  # True for ids containing "vision", for "gpt-4-" ids other than the
                  # 0314/0613 snapshots, and for the GPT-4o family. The "gpt-4-" pattern
                  # needs a hyphen right after "gpt-4", so "gpt-4o" ids were previously
                  # never recognised. GPT-4o audio/realtime preview ids remain excluded,
                  # which keeps their earlier result unchanged.
                  def supports_vision?(model_id)
                    model_id.include?('vision') ||
                      model_id.match?(/gpt-4-(?!0314|0613)/) ||
                      model_id.match?(/gpt-4o(?!.*(?:audio|realtime))/)
                  end
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
      
 94 
     | 
    
         
            +
                  # Reports whether a model supports JSON mode, judged only from its id.
                  #
                  # True for "gpt-4-NNNN-preview" ids and any id containing "turbo",
                  # except the gpt-3.5-turbo 0301 and 0613 snapshots. That exclusion has
                  # to be checked first: the generic "turbo" test also matches those
                  # snapshots, which previously made the exclusion unreachable.
                  def supports_json_mode?(model_id)
                    return false if model_id.match?(/gpt-3\.5-turbo-(?:0301|0613)/)

                    model_id.match?(/gpt-4-\d{4}-preview/) || model_id.include?('turbo')
                  end
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
      
 100 
     | 
    
         
            +
                  # Builds a human-readable display name from a model id: hyphens become
                  # spaces, each word is capitalized, and an ordered list of rewrite
                  # rules is then applied (each rule sees the output of the previous one).
                  def format_display_name(model_id)
                    capitalized = model_id.tr('-', ' ').split(' ').map(&:capitalize).join(' ')

                    rewrite_rules = [
                      [/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3'], # Convert dates to YYYYMMDD
                      [/^Gpt /, 'GPT-'],
                      [/^O1 /, 'O1-'],
                      [/^Chatgpt /, 'ChatGPT-'],
                      [/^Tts /, 'TTS-'],
                      [/^Dall E /, 'DALL-E-'],
                      [/3\.5 /, '3.5-'],
                      [/4 /, '4-'],
                      [/4o (?=Mini|Preview|Turbo)/, '4o-'],
                      [/\bHd\b/, 'HD']
                    ]

                    rewrite_rules.reduce(capitalized) do |text, (pattern, replacement)|
                      text.gsub(pattern, replacement)
                    end
                  end
         
     | 
| 
      
 119 
     | 
    
         
            +
                end
         
     | 
| 
      
 120 
     | 
    
         
            +
              end
         
     | 
| 
      
 121 
     | 
    
         
            +
            end
         
     | 
| 
         @@ -0,0 +1,42 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module RubyLLM
         
     | 
| 
      
 4 
     | 
    
         
            +
              # Read-only record describing a single model: identity, provider,
              # capability flags, token limits and per-million-token prices.
              class ModelInfo
                attr_reader :id, :created_at, :display_name, :provider, :metadata,
                            :context_window, :max_tokens, :supports_vision, :supports_functions,
                            :supports_json_mode, :input_price_per_million, :output_price_per_million

                # Every attribute is a required keyword except +metadata+, which
                # defaults to an empty hash.
                def initialize(id:, created_at:, display_name:, provider:, context_window:, max_tokens:, supports_vision:,
                               supports_functions:, supports_json_mode:, input_price_per_million:, output_price_per_million:, metadata: {})
                  @id, @created_at, @display_name = id, created_at, display_name
                  @provider, @metadata = provider, metadata
                  @context_window, @max_tokens = context_window, max_tokens
                  @supports_vision, @supports_functions, @supports_json_mode =
                    supports_vision, supports_functions, supports_json_mode
                  @input_price_per_million = input_price_per_million
                  @output_price_per_million = output_price_per_million
                end

                # Returns all attributes as a symbol-keyed hash, read through the
                # public readers in a fixed key order.
                def to_h
                  keys = %i[
                    id created_at display_name provider metadata
                    context_window max_tokens
                    supports_vision supports_functions supports_json_mode
                    input_price_per_million output_price_per_million
                  ]

                  keys.each_with_object({}) do |key, hash|
                    hash[key] = public_send(key)
                  end
                end
              end
         
     | 
| 
      
 42 
     | 
    
         
            +
            end
         
     | 
| 
         @@ -0,0 +1,226 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require 'time'
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module RubyLLM
         
     | 
| 
      
 6 
     | 
    
         
            +
              module Providers
         
     | 
| 
      
 7 
     | 
    
         
            +
                class Anthropic < Base
         
     | 
| 
      
 8 
     | 
    
         
            +
                  # Builds the request body for a chat call and dispatches it.
                  #
                  # The body carries the model (defaulting to
                  # 'claude-3-5-sonnet-20241022'), the formatted messages, temperature,
                  # the stream flag and a fixed max_tokens of 4096; converted tool
                  # definitions are attached only when +tools+ is non-empty. The
                  # streaming path is taken only when +stream+ is truthy AND a block
                  # was given; otherwise a regular completion is requested.
                  def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
                    request_body = {
                      model: model || 'claude-3-5-sonnet-20241022',
                      messages: format_messages(messages),
                      temperature: temperature,
                      stream: stream,
                      max_tokens: 4096
                    }

                    if tools&.any?
                      request_body[:tools] = tools.map { |definition| tool_to_anthropic(definition) }
                    end

                    # Dump the outgoing payload when RUBY_LLM_DEBUG is set.
                    if ENV['RUBY_LLM_DEBUG']
                      puts 'Sending payload to Anthropic:'
                      puts JSON.pretty_generate(request_body)
                    end

                    return create_chat_completion(request_body, tools) unless stream && block_given?

                    stream_chat_completion(request_body, tools, &block)
                  end
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
                  # Fetches '/v1/models' and wraps every entry of the response's 'data'
                  # array (or an empty list when it is absent) in a ModelInfo, filling
                  # limits, capability flags and prices from
                  # RubyLLM::ModelCapabilities::Anthropic.
                  #
                  # Raises RubyLLM::Error when the response status is 400 or above;
                  # Faraday errors are passed to handle_error.
                  def list_models
                    response = @connection.get('/v1/models') do |request|
                      request.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
                      request.headers['anthropic-version'] = '2023-06-01'
                    end

                    if response.status >= 400
                      raise RubyLLM::Error, "API error: #{parse_error_message(response)}"
                    end

                    capabilities = RubyLLM::ModelCapabilities::Anthropic.new

                    (response.body['data'] || []).map do |entry|
                      model_id = entry['id']

                      ModelInfo.new(
                        id: model_id,
                        created_at: Time.parse(entry['created_at']),
                        display_name: entry['display_name'],
                        provider: 'anthropic',
                        metadata: { type: entry['type'] },
                        context_window: capabilities.determine_context_window(model_id),
                        max_tokens: capabilities.determine_max_tokens(model_id),
                        supports_vision: capabilities.supports_vision?(model_id),
                        supports_functions: capabilities.supports_functions?(model_id),
                        supports_json_mode: capabilities.supports_json_mode?(model_id),
                        input_price_per_million: capabilities.get_input_price(model_id),
                        output_price_per_million: capabilities.get_output_price(model_id)
                      )
                    end
                  rescue Faraday::Error => e
                    handle_error(e)
                  end
         
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
                  private
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
                  # Converts a tool into the hash sent under the payload's :tools key:
                  # its name, description and an object-typed input schema. The schema's
                  # properties are the tool's parameters as-is, and :required lists the
                  # keys of the parameters whose spec has a truthy :required entry.
                  def tool_to_anthropic(tool)
                    tool_name = tool.name
                    summary = tool.description
                    properties = tool.parameters
                    mandatory = tool.parameters.select { |_key, spec| spec[:required] }.keys

                    {
                      name: tool_name,
                      description: summary,
                      input_schema: { type: 'object', properties: properties, required: mandatory }
                    }
                  end
         
     | 
| 
      
 75 
     | 
    
         
            +
             
     | 
| 
      
 76 
     | 
    
         
            +
                  # Maps messages to the hashes placed under the payload's :messages key.
                  # The role is 'user' for :user messages and 'assistant' for everything
                  # else. A message with tool results gets a one-element content array
                  # holding a 'tool_result' block (nil fields dropped); any other
                  # message passes its content through unchanged.
                  def format_messages(messages)
                    messages.map do |entry|
                      role = entry.role == :user ? 'user' : 'assistant'

                      content =
                        if entry.tool_results
                          result_block = {
                            type: 'tool_result',
                            tool_use_id: entry.tool_results[:tool_use_id],
                            content: entry.tool_results[:content],
                            is_error: entry.tool_results[:is_error]
                          }
                          [result_block.compact]
                        else
                          entry.content
                        end

                      { role: role, content: content }
                    end
                  end
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
                  # POSTs the payload to '/v1/messages' with the API key, API version and
                  # JSON content-type headers, then passes the response (together with
                  # the tools and the original payload) to handle_response. Faraday
                  # errors are passed to handle_error.
                  def create_chat_completion(payload, tools = nil)
                    reply = @connection.post('/v1/messages') do |request|
                      request.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
                      request.headers['anthropic-version'] = '2023-06-01'
                      request.headers['Content-Type'] = 'application/json'
                      request.body = payload
                    end

                    # Dump the response body when RUBY_LLM_DEBUG is set.
                    if ENV['RUBY_LLM_DEBUG']
                      puts 'Response from Anthropic:'
                      puts JSON.pretty_generate(reply.body)
                    end

                    handle_response(reply, tools, payload)
                  rescue Faraday::Error => e
                    handle_error(e)
                  end
         
     | 
| 
      
 112 
     | 
    
         
            +
             
     | 
| 
      
 113 
     | 
    
         
            +
                  # Posts +payload+ to the Anthropic '/v1/messages' endpoint and walks the
                  # response body line by line, yielding an assistant Message for every
                  # text delta found.
                  #
                  # payload - request body, assigned to the request unchanged
                  # tools   - optional tool list, forwarded to handle_tool_calls
                  #
                  # Yields Message objects (role :assistant). Lines that are blank, equal to
                  # 'data: [DONE]', or not valid JSON are skipped. Any Faraday::Error is
                  # delegated to handle_error.
                  def stream_chat_completion(payload, tools = nil)
                    response = @connection.post('/v1/messages') do |req|
                      req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
                      req.headers['anthropic-version'] = '2023-06-01'
                      # Same content type the non-streaming request sends.
                      req.headers['Content-Type'] = 'application/json'
                      req.body = payload
                    end

                    response.body.each_line do |raw_line|
                      # each_line keeps the trailing newline, so compare the stripped text;
                      # otherwise the [DONE] check can never match.
                      line = raw_line.strip
                      next if line.empty? || line == 'data: [DONE]'

                      begin
                        data = JSON.parse(line.sub(/\Adata: /, ''))
                      rescue JSON::ParserError
                        # Not a JSON data line; nothing to emit for it.
                        next
                      end
                      next unless data.is_a?(Hash)

                      case data['type']
                      when 'content_block_delta'
                        # dig tolerates a missing 'delta' hash instead of raising.
                        content = data.dig('delta', 'text')
                        yield Message.new(role: :assistant, content: content) if content
                      when 'tool_call'
                        # handle_tool_calls returns its results rather than yielding them,
                        # so iterate the returned array to surface each tool's output.
                        handle_tool_calls(data['tool_calls'], tools).each do |result|
                          yield Message.new(role: :assistant, content: result[:content])
                        end
                      end
                    end
                  rescue Faraday::Error => e
                    handle_error(e)
                  end
         
     | 
| 
      
 142 
     | 
    
         
            +
             
     | 
| 
      
 143 
     | 
    
         
            +
                  # Turns a Messages API reply into a Message, running a requested tool
                  # first when there is one.
                  #
                  # response - object whose #body is the parsed reply hash
                  # tools    - optional list of tools (each responds to #name and #call)
                  # payload  - the request payload that produced +response+
                  #
                  # Returns an assistant Message with empty content when the reply has
                  # type 'error'. When the reply contains a 'tool_use' part and tools were
                  # given, the tool result is appended to the conversation and the return
                  # value is that of a follow-up create_chat_completion call. Otherwise
                  # returns an assistant Message carrying the first text part.
                  def handle_response(response, tools, payload)
                    data = response.body
                    return Message.new(role: :assistant, content: '') if data['type'] == 'error'

                    content_parts = data['content'] || []
                    tool_use = content_parts.find { |part| part['type'] == 'tool_use' }

                    if tool_use && tools
                      result = execute_tool_use(tool_use, tools)

                      # Replay the assistant turn, then answer it with the tool result.
                      new_messages = payload[:messages] + [
                        { role: 'assistant', content: data['content'] },
                        { role: 'user', content: [{ type: 'tool_result' }.merge(result)] }
                      ]

                      return create_chat_completion(payload.merge(messages: new_messages), tools)
                    end

                    text_content = content_parts.find { |part| part['type'] == 'text' }&.fetch('text', '')
                    Message.new(role: :assistant, content: text_content)
                  end

                  # Runs the tool named by a 'tool_use' part and returns the tool_result
                  # fields (:tool_use_id, :content, plus is_error: true on failure).
                  # An unknown tool name is reported as an error result instead of
                  # leaving the caller with nil.
                  def execute_tool_use(tool_use, tools)
                    tool = tools.find { |candidate| candidate.name == tool_use['name'] }
                    unless tool
                      return {
                        tool_use_id: tool_use['id'],
                        content: "Error executing tool #{tool_use['name']}: tool not found",
                        is_error: true
                      }
                    end

                    begin
                      { tool_use_id: tool_use['id'], content: tool.call(tool_use['input'] || {}).to_s }
                    rescue StandardError => e
                      {
                        tool_use_id: tool_use['id'],
                        content: "Error executing tool #{tool.name}: #{e.message}",
                        is_error: true
                      }
                    end
                  end
         
     | 
| 
      
 194 
     | 
    
         
            +
             
     | 
| 
      
 195 
     | 
    
         
            +
                  # Executes every tool call that names a known tool and returns the
                  # outcomes.
                  #
                  # tool_calls - array of hashes with 'id', 'name' and 'arguments' keys
                  # tools      - list of tools (each responds to #name and #call)
                  #
                  # Returns an array of hashes with :tool_use_id and :content, plus
                  # is_error: true when running the tool failed. Calls naming an unknown
                  # tool are left out. Returns [] when either argument is nil.
                  def handle_tool_calls(tool_calls, tools)
                    return [] unless tool_calls && tools

                    outcomes = tool_calls.map do |tool_call|
                      tool = tools.find { |t| t.name == tool_call['name'] }
                      next unless tool

                      begin
                        raw_args = tool_call['arguments']
                        # Only parse when given a String; a nil or already-parsed value
                        # would otherwise make JSON.parse raise TypeError.
                        args = raw_args.is_a?(String) ? JSON.parse(raw_args) : (raw_args || {})
                        result = tool.call(args)
                        puts "Tool result: #{result}" if ENV['RUBY_LLM_DEBUG']
                        { tool_use_id: tool_call['id'], content: result.to_s }
                      rescue StandardError => e
                        # Any failure becomes an error result, matching how handle_response
                        # reports tool failures.
                        puts "Error executing tool: #{e.message}" if ENV['RUBY_LLM_DEBUG']
                        {
                          tool_use_id: tool_call['id'],
                          content: "Error executing tool #{tool.name}: #{e.message}",
                          is_error: true
                        }
                      end
                    end

                    outcomes.compact
                  end
         
     | 
| 
      
 220 
     | 
    
         
            +
             
     | 
| 
      
 221 
     | 
    
         
            +
                  # Root URL of the Anthropic API. The request paths used by this
                  # provider (such as '/v1/messages') are written relative to it.
                  def api_base
                    'https://api.anthropic.com'
                  end
         
     | 
| 
      
 224 
     | 
    
         
            +
                end
         
     | 
| 
      
 225 
     | 
    
         
            +
              end
         
     | 
| 
      
 226 
     | 
    
         
            +
            end
         
     | 
| 
         @@ -1,6 +1,11 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
       1 
3 
     | 
    
         
             
            module RubyLLM
         
     | 
| 
       2 
4 
     | 
    
         
             
              module Providers
         
     | 
| 
      
 5 
     | 
    
         
            +
                # Base provider class for LLM interactions
         
     | 
| 
       3 
6 
     | 
    
         
             
                class Base
         
     | 
| 
      
 7 
     | 
    
         
            +
                  attr_reader :connection
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
       4 
9 
     | 
    
         
             
                  def initialize
         
     | 
| 
       5 
10 
     | 
    
         
             
                    @connection = build_connection
         
     | 
| 
       6 
11 
     | 
    
         
             
                  end
         
     | 
| 
         @@ -23,9 +28,9 @@ module RubyLLM 
     | 
|
| 
       23 
28 
     | 
    
         
             
                  def handle_error(error)
         
     | 
| 
       24 
29 
     | 
    
         
             
                    case error
         
     | 
| 
       25 
30 
     | 
    
         
             
                    when Faraday::TimeoutError
         
     | 
| 
       26 
     | 
    
         
            -
                      raise RubyLLM::Error,  
     | 
| 
      
 31 
     | 
    
         
            +
                      raise RubyLLM::Error, 'Request timed out'
         
     | 
| 
       27 
32 
     | 
    
         
             
                    when Faraday::ConnectionFailed
         
     | 
| 
       28 
     | 
    
         
            -
                      raise RubyLLM::Error,  
     | 
| 
      
 33 
     | 
    
         
            +
                      raise RubyLLM::Error, 'Connection failed'
         
     | 
| 
       29 
34 
     | 
    
         
             
                    when Faraday::ClientError
         
     | 
| 
       30 
35 
     | 
    
         
             
                      handle_api_error(error)
         
     | 
| 
       31 
36 
     | 
    
         
             
                    else
         
     | 
| 
         @@ -36,6 +41,20 @@ module RubyLLM 
     | 
|
| 
       36 
41 
     | 
    
         
             
                  # Raises RubyLLM::Error for a failed API call, naming the HTTP status.
                  #
                  # error - an exception whose #response is a hash with a :status key,
                  #         or nil when no response was received
                  #
                  # Always raises RubyLLM::Error.
                  def handle_api_error(error)
                    # The response may be absent; report that instead of failing with
                    # NoMethodError while building the message.
                    status = error.response && error.response[:status]
                    raise RubyLLM::Error, "API error: #{status || 'no response'}"
                  end
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
                  # Extracts a human-readable error message from an API response.
                  #
                  # response - object responding to #status and #body; the body may be nil,
                  #            a JSON String, or an already-parsed Hash
                  #
                  # Returns the value at error.message in the body when present, otherwise
                  # "HTTP <status>". Never returns nil.
                  def parse_error_message(response)
                    fallback = "HTTP #{response.status}"
                    body = response.body
                    return fallback unless body

                    body = JSON.parse(body) if body.is_a?(String)
                    error = body['error'] if body.is_a?(Hash)
                    message = error['message'] if error.is_a?(Hash)
                    message || fallback
                  rescue StandardError
                    # Best effort: an unparseable body must not hide the HTTP status.
                    fallback
                  end
         
     | 
| 
       39 
58 
     | 
    
         
             
                end
         
     | 
| 
       40 
59 
     | 
    
         
             
              end
         
     | 
| 
       41 
60 
     | 
    
         
             
            end
         
     | 
| 
         @@ -0,0 +1,161 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module RubyLLM
         
     | 
| 
      
 4 
     | 
    
         
            +
              module Providers
         
     | 
| 
      
 5 
     | 
    
         
            +
                class OpenAI < Base
         
     | 
| 
      
 6 
     | 
    
         
            +
                  # Sends a chat request built from +messages+.
                  #
                  # messages    - objects responding to #to_h, one per conversation turn
                  # model       - model id; falls back to the configured default_model
                  # temperature - sampling temperature placed in the payload
                  # stream      - when true AND a block is given, the streaming path is used
                  # tools       - optional tools, sent as 'functions' with function_call 'auto'
                  # block       - receives the streamed chunks
                  #
                  # Returns the result of stream_chat_completion or create_chat_completion.
                  # Raises RubyLLM::Error on timeouts, connection failures and client errors.
                  def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
                    # Only request a streamed reply when a block can consume it; otherwise
                    # the non-streaming path below would receive a streamed body.
                    streaming = stream && block_given?

                    payload = {
                      model: model || RubyLLM.configuration.default_model,
                      messages: messages.map(&:to_h),
                      temperature: temperature,
                      stream: streaming
                    }

                    if tools&.any?
                      payload[:functions] = tools.map { |tool| tool_to_function(tool) }
                      payload[:function_call] = 'auto'
                    end

                    if ENV['RUBY_LLM_DEBUG']
                      puts 'Sending payload to OpenAI:'
                      puts JSON.pretty_generate(payload)
                    end

                    if streaming
                      stream_chat_completion(payload, tools, &block)
                    else
                      create_chat_completion(payload, tools)
                    end
                  rescue Faraday::TimeoutError
                    raise RubyLLM::Error, 'Request timed out'
                  rescue Faraday::ConnectionFailed
                    raise RubyLLM::Error, 'Connection failed'
                  rescue Faraday::ClientError => e
                    raise RubyLLM::Error, 'Client error' unless e.response

                    # The body is not guaranteed to be a parsed hash; only read the API
                    # message when the structure is really there, else report the status.
                    body = e.response[:body]
                    api_error = body['error'] if body.is_a?(Hash)
                    error_msg = api_error['message'] if api_error.is_a?(Hash)
                    error_msg ||= "HTTP #{e.response[:status]}"
                    raise RubyLLM::Error, "API error: #{error_msg}"
                  end
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
      
 38 
     | 
    
         
            +
                  # Fetches '/v1/models' and wraps each returned entry in a ModelInfo.
                  #
                  # Descriptive fields (display name, context window, token limit, feature
                  # flags, prices) come from RubyLLM::ModelCapabilities::OpenAI, keyed by
                  # the model id.
                  #
                  # Returns an array of ModelInfo. Raises RubyLLM::Error when the response
                  # status is 400 or above; Faraday errors are delegated to handle_error.
                  def list_models
                    reply = @connection.get('/v1/models') do |request|
                      request.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
                    end

                    raise RubyLLM::Error, "API error: #{parse_error_message(reply)}" if reply.status >= 400

                    caps = RubyLLM::ModelCapabilities::OpenAI.new
                    entries = reply.body['data'] || []

                    entries.map do |entry|
                      model_id = entry['id']
                      details = { object: entry['object'], owned_by: entry['owned_by'] }

                      ModelInfo.new(
                        id: model_id,
                        created_at: Time.at(entry['created']),
                        display_name: caps.format_display_name(model_id),
                        provider: 'openai',
                        metadata: details,
                        context_window: caps.determine_context_window(model_id),
                        max_tokens: caps.determine_max_tokens(model_id),
                        supports_vision: caps.supports_vision?(model_id),
                        supports_functions: caps.supports_functions?(model_id),
                        supports_json_mode: caps.supports_json_mode?(model_id),
                        input_price_per_million: caps.get_input_price(model_id),
                        output_price_per_million: caps.get_output_price(model_id)
                      )
                    end
                  rescue Faraday::Error => e
                    handle_error(e)
                  end
         
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
                  private
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
                  # Builds the function description hash sent for +tool+.
                  #
                  # tool - responds to #name, #description and #parameters, where
                  #        parameters maps each parameter name to a hash of attributes
                  #
                  # The :required attribute is removed from every parameter's attributes
                  # and gathered into the schema's top-level list of required names.
                  def tool_to_function(tool)
                    declared = tool.parameters

                    mandatory = []
                    declared.each { |param, spec| mandatory << param if spec[:required] }

                    schema_properties = declared.map do |param, spec|
                      [param, spec.reject { |attribute, _| attribute == :required }]
                    end.to_h

                    {
                      name: tool.name,
                      description: tool.description,
                      parameters: {
                        type: 'object',
                        properties: schema_properties,
                        required: mandatory
                      }
                    }
                  end
         
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                  # Posts +payload+ to '/v1/chat/completions' and returns whatever
                  # handle_response builds from the reply.
                  #
                  # payload - request body, assigned to the request unchanged
                  # tools   - optional tool list, forwarded to handle_response
                  #
                  # Raises RubyLLM::Error when the response status is 400 or above.
                  def create_chat_completion(payload, tools = nil)
                    response = connection.post('/v1/chat/completions') do |req|
                      req.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
                      req.headers['Content-Type'] = 'application/json'
                      req.body = payload
                    end

                    if ENV['RUBY_LLM_DEBUG']
                      puts 'Response from OpenAI:'
                      puts JSON.pretty_generate(response.body)
                    end

                    # Use the shared parser (as list_models does) so an unparsed String body
                    # cannot break the error message with a NoMethodError.
                    raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400

                    handle_response(response, tools, payload)
                  end
         
     | 
| 
      
 100 
     | 
    
         
            +
             
     | 
| 
      
 101 
     | 
    
         
            +
                  # Converts a chat completion reply into a Message, resolving a requested
                  # function call first when there is one.
                  #
                  # response - object whose #body is the parsed reply
                  # tools    - optional tool list; function calls are only run when given
                  # payload  - the request payload that produced +response+
                  #
                  # Returns an assistant Message with empty content when the reply holds no
                  # first-choice message. When that message carries a 'function_call' and
                  # tools were given, the call's result is appended to the conversation and
                  # the return value is that of a follow-up create_chat_completion call.
                  # Otherwise returns an assistant Message with the message's content.
                  def handle_response(response, tools, payload)
                    reply_message = response.body.dig('choices', 0, 'message')
                    return Message.new(role: :assistant, content: '') unless reply_message

                    requested_call = reply_message['function_call']
                    unless requested_call && tools
                      return Message.new(role: :assistant, content: reply_message['content'])
                    end

                    outcome = handle_function_call(requested_call, tools)
                    puts "Function result: #{outcome}" if ENV['RUBY_LLM_DEBUG']

                    # Replay the assistant's function request, then supply its result.
                    assistant_turn = {
                      role: 'assistant',
                      content: reply_message['content'],
                      function_call: requested_call
                    }
                    function_turn = { role: 'function', name: requested_call['name'], content: outcome }
                    extended = payload[:messages] + [assistant_turn, function_turn]

                    create_chat_completion(payload.merge(messages: extended), tools)
                  end
         
     | 
| 
      
 124 
     | 
    
         
            +
             
     | 
| 
      
 125 
     | 
    
         
            +
                  # Executes the tool named by a model-issued function call.
                  #
                  # @param function_call [Hash, nil] hash with string keys `'name'` and
                  #   `'arguments'`; arguments may be a JSON string, an already parsed Hash,
                  #   or nil/blank when the function takes no arguments
                  # @param tools [Array, nil] objects responding to `name` and `call(args)`
                  # @return [Object, String, nil] the tool's return value; an error string
                  #   when the arguments are malformed or the tool raises ArgumentError;
                  #   nil when there is no function call, no tools, or no tool with that name
                  def handle_function_call(function_call, tools)
                    return unless function_call && tools

                    tool = tools.find { |t| t.name == function_call['name'] }
                    return unless tool

                    begin
                      raw_args = function_call['arguments']
                      args =
                        case raw_args
                        when Hash then raw_args
                        when nil then {}
                        when String
                          # A blank payload means "no arguments", not a parse failure.
                          raw_args.strip.empty? ? {} : JSON.parse(raw_args)
                        else
                          raise ArgumentError, "unsupported arguments type #{raw_args.class}"
                        end
                      # Valid JSON that is not an object (e.g. an array) cannot be mapped
                      # onto named parameters.
                      raise ArgumentError, 'arguments must be a JSON object' unless args.is_a?(Hash)

                      tool.call(args)
                    rescue JSON::ParserError, ArgumentError => e
                      "Error executing function #{tool.name}: #{e.message}"
                    end
                  end
         
     | 
| 
      
 138 
     | 
    
         
            +
             
     | 
| 
      
 139 
     | 
    
         
            +
                  # Translates transport-level Faraday errors into RubyLLM::Error.
                  #
                  # Timeouts and connection failures get fixed messages. Client errors use
                  # the API's own error message when the response body provides one and
                  # fall back to the HTTP status otherwise. Any other error is re-raised
                  # unchanged.
                  #
                  # @param error [Exception] the error raised while performing the request
                  # @raise [RubyLLM::Error] for timeout, connection and client errors
                  # @raise [Exception] the original error for anything else
                  def handle_error(error)
                    case error
                    when Faraday::TimeoutError
                      raise RubyLLM::Error, 'Request timed out'
                    when Faraday::ConnectionFailed
                      raise RubyLLM::Error, 'Connection failed'
                    when Faraday::ClientError
                      raise RubyLLM::Error, 'Client error' unless error.response

                      # The body is only a Hash when it was parsed as JSON; it can also be
                      # nil or a raw String (e.g. an HTML error page), and the `error`
                      # entry may be a plain String rather than a nested object.
                      body = error.response[:body]
                      api_error = body.is_a?(Hash) ? body['error'] : nil
                      detail = api_error.is_a?(Hash) ? api_error['message'] : api_error
                      error_msg = detail.to_s.empty? ? "HTTP #{error.response[:status]}" : detail
                      raise RubyLLM::Error, "API error: #{error_msg}"
                    else
                      raise error
                    end
                  end
         
     | 
| 
      
 155 
     | 
    
         
            +
             
     | 
| 
      
 156 
     | 
    
         
            +
                  # Root URL that requests to the OpenAI API are sent to.
                  #
                  # @return [String] scheme and host, without a trailing slash or path
                  def api_base
                    'https://api.openai.com'
                  end
         
     | 
| 
      
 159 
     | 
    
         
            +
                end
         
     | 
| 
      
 160 
     | 
    
         
            +
              end
         
     | 
| 
      
 161 
     | 
    
         
            +
            end
         
     | 
    
        data/lib/ruby_llm/railtie.rb
    CHANGED
    
    
| 
         @@ -0,0 +1,75 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module RubyLLM
              # Represents a tool/function that can be called by an LLM.
              #
              # A tool has a name, a description, a parameter schema (a Hash keyed by
              # parameter name whose values may carry `:type`, `:description` and
              # `:required`) and a handler block that receives the validated arguments.
              class Tool
                attr_reader :name, :description, :parameters, :handler

                # Builds a tool from a Method or UnboundMethod.
                #
                # Every named positional and keyword parameter becomes a string-typed
                # tool parameter; required positional/keyword parameters are marked
                # required. When the tool is called, arguments are matched to the
                # method's parameters by name.
                #
                # @param method_object [Method, UnboundMethod] the method to expose
                # @param description [String, nil] tool description; a generic one is
                #   generated from the method name when omitted
                # @param parameter_descriptions [Hash{Symbol => String}] per-parameter
                #   descriptions; defaults to the parameter name with underscores
                #   replaced by spaces
                # @return [Tool]
                def self.from_method(method_object, description: nil, parameter_descriptions: {})
                  method_params = {}
                  method_object.parameters.each do |param_type, param_name|
                    next unless %i[req opt key keyreq].include?(param_type)

                    method_params[param_name] = {
                      type: 'string',
                      description: parameter_descriptions[param_name] || param_name.to_s.tr('_', ' '),
                      required: %i[req keyreq].include?(param_type)
                    }
                  end

                  new(
                    name: method_object.name.to_s,
                    description: description || "Executes the #{method_object.name} operation",
                    parameters: method_params
                  ) do |args|
                    invoke_method(method_object, args)
                  end
                end

                # Calls +method_object+ with +args+ mapped onto its signature.
                #
                # A bound Method is invoked on its own receiver, so instance state is
                # kept and class/singleton methods work. An UnboundMethod is bound to a
                # fresh `owner.new`. Positional parameters are filled by name in
                # declaration order; every remaining argument is passed as a keyword.
                def self.invoke_method(method_object, args)
                  callable = if method_object.respond_to?(:receiver)
                               method_object
                             else
                               method_object.bind(method_object.owner.new)
                             end
                  return callable.call(args) unless args.is_a?(Hash)

                  keywords = args.transform_keys(&:to_sym)
                  positional = []
                  skipped_optional = nil
                  method_object.parameters.each do |kind, param|
                    next unless %i[req opt].include?(kind)

                    if keywords.key?(param)
                      # A later optional cannot be supplied positionally once an
                      # earlier optional has been left out.
                      if kind == :opt && skipped_optional
                        raise ArgumentError, "Cannot pass #{param} without #{skipped_optional}"
                      end

                      positional << keywords.delete(param)
                    elsif kind == :opt
                      skipped_optional ||= param
                    else
                      raise ArgumentError, "Missing required parameter: #{param}"
                    end
                  end

                  # Avoid splatting an empty hash so methods without keyword
                  # parameters never receive a stray trailing argument.
                  keywords.empty? ? callable.call(*positional) : callable.call(*positional, **keywords)
                end
                private_class_method :invoke_method

                # @param name [String] tool name exposed to the model
                # @param description [String] what the tool does
                # @param parameters [Hash] parameter schema keyed by parameter name
                # @yield [args] handler invoked with the validated, symbol-keyed arguments
                # @raise [ArgumentError] when an attribute has the wrong type or no block
                #   is given
                def initialize(name:, description:, parameters: {}, &block)
                  @name = name
                  @description = description
                  @parameters = parameters
                  @handler = block

                  validate!
                end

                # Validates +args+ and runs the handler with them.
                #
                # @param args [Hash, nil] arguments keyed by String or Symbol; nil is
                #   treated as no arguments
                # @return [Object] whatever the handler returns
                # @raise [ArgumentError] when +args+ is not a hash or a required
                #   parameter is missing
                def call(args)
                  validated_args = validate_args!(args)
                  handler.call(validated_args)
                end

                private

                # Checks the attribute types given to the constructor.
                def validate!
                  raise ArgumentError, 'Name must be a string' unless name.is_a?(String)
                  raise ArgumentError, 'Description must be a string' unless description.is_a?(String)
                  raise ArgumentError, 'Parameters must be a hash' unless parameters.is_a?(Hash)
                  raise ArgumentError, 'Block must be provided' unless handler.respond_to?(:call)
                end

                # Symbolizes argument keys and ensures every required parameter is present.
                #
                # @return [Hash{Symbol => Object}] the arguments with symbol keys
                def validate_args!(args)
                  args = {} if args.nil?
                  raise ArgumentError, 'Arguments must be a hash' unless args.is_a?(Hash)

                  args = args.transform_keys(&:to_sym)
                  required_params = parameters.select { |_, v| v[:required] }.keys

                  required_params.each do |param|
                    raise ArgumentError, "Missing required parameter: #{param}" unless args.key?(param.to_sym)
                  end

                  args
                end
              end
            end
         
     | 
    
        data/lib/ruby_llm/version.rb
    CHANGED