ruby_llm 0.1.0.pre → 0.1.0.pre3

This diff shows the changes between two publicly released versions of the package, exactly as they appear in the public registry. It is provided for informational purposes only.
@@ -0,0 +1,121 @@
+ # frozen_string_literal: true
+
+ module RubyLLM
+   module ModelCapabilities
+     class OpenAI < Base
+       def determine_context_window(model_id)
+         case model_id
+         when /gpt-4o/, /o1/, /gpt-4-turbo/
+           128_000
+         when /gpt-4-0[0-9]{3}/
+           8_192
+         when /gpt-3.5-turbo-instruct/
+           4_096
+         when /gpt-3.5/
+           16_385
+         else
+           4_096
+         end
+       end
+
+       def determine_max_tokens(model_id)
+         case model_id
+         when /o1-2024-12-17/
+           100_000
+         when /o1-mini-2024-09-12/
+           65_536
+         when /o1-preview-2024-09-12/
+           32_768
+         when /gpt-4o/, /gpt-4-turbo/
+           16_384
+         when /gpt-4-0[0-9]{3}/
+           8_192
+         when /gpt-3.5-turbo/
+           4_096
+         else
+           4_096
+         end
+       end
+
+       def get_input_price(model_id)
+         case model_id
+         when /o1-2024/
+           15.0 # $15.00 per million tokens
+         when /o1-mini/
+           3.0 # $3.00 per million tokens
+         when /gpt-4o-realtime-preview/
+           5.0 # $5.00 per million tokens
+         when /gpt-4o-mini-realtime-preview/
+           0.60 # $0.60 per million tokens
+         when /gpt-4o-mini/
+           0.15 # $0.15 per million tokens
+         when /gpt-4o/
+           2.50 # $2.50 per million tokens
+         when /gpt-4-turbo/
+           10.0 # $10.00 per million tokens
+         when /gpt-3.5/
+           0.50 # $0.50 per million tokens
+         else
+           0.50 # Default to GPT-3.5 pricing
+         end
+       end
+
+       def get_output_price(model_id)
+         case model_id
+         when /o1-2024/
+           60.0 # $60.00 per million tokens
+         when /o1-mini/
+           12.0 # $12.00 per million tokens
+         when /gpt-4o-realtime-preview/
+           20.0 # $20.00 per million tokens
+         when /gpt-4o-mini-realtime-preview/
+           2.40 # $2.40 per million tokens
+         when /gpt-4o-mini/
+           0.60 # $0.60 per million tokens
+         when /gpt-4o/
+           10.0 # $10.00 per million tokens
+         when /gpt-4-turbo/
+           30.0 # $30.00 per million tokens
+         when /gpt-3.5/
+           1.50 # $1.50 per million tokens
+         else
+           1.50 # Default to GPT-3.5 pricing
+         end
+       end
+
+       def supports_functions?(model_id)
+         !model_id.include?('instruct')
+       end
+
+       def supports_vision?(model_id)
+         model_id.include?('vision') || model_id.match?(/gpt-4-(?!0314|0613)/)
+       end
+
+       def supports_json_mode?(model_id)
+         model_id.match?(/gpt-4-\d{4}-preview/) ||
+           model_id.include?('turbo') ||
+           model_id.match?(/gpt-3.5-turbo-(?!0301|0613)/)
+       end
+
+       def format_display_name(model_id)
+         # First replace hyphens with spaces
+         name = model_id.tr('-', ' ')
+
+         # Capitalize each word
+         name = name.split(' ').map { |word| word.capitalize }.join(' ')
+
+         # Apply specific formatting rules
+         name.gsub(/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3') # Convert dates to YYYYMMDD
+             .gsub(/^Gpt /, 'GPT-')
+             .gsub(/^O1 /, 'O1-')
+             .gsub(/^Chatgpt /, 'ChatGPT-')
+             .gsub(/^Tts /, 'TTS-')
+             .gsub(/^Dall E /, 'DALL-E-')
+             .gsub(/3\.5 /, '3.5-')
+             .gsub(/4 /, '4-')
+             .gsub(/4o (?=Mini|Preview|Turbo)/, '4o-')
+             .gsub(/\bHd\b/, 'HD')
+       end
+     end
+   end
+ end
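
The capabilities class above is pure metadata lookup: every method pattern-matches on the model id string. A minimal usage sketch (assuming the class can be instantiated directly; the return values follow from the case branches above):

    caps = RubyLLM::ModelCapabilities::OpenAI.new
    caps.determine_context_window('gpt-4o-mini')        #=> 128_000 (first match: /gpt-4o/)
    caps.get_input_price('gpt-4o-mini')                 #=> 0.15 (dollars per million input tokens)
    caps.supports_vision?('gpt-4-0613')                 #=> false (excluded by the negative lookahead)
    caps.format_display_name('gpt-4o-mini-2024-07-18')  #=> "GPT-4o-Mini 20240718"

Note that `when` order matters here: the more specific /gpt-4o-mini/ branch must precede /gpt-4o/, as it does in both pricing methods.
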
@@ -0,0 +1,42 @@
+ # frozen_string_literal: true
+
+ module RubyLLM
+   class ModelInfo
+     attr_reader :id, :created_at, :display_name, :provider, :metadata,
+                 :context_window, :max_tokens, :supports_vision, :supports_functions,
+                 :supports_json_mode, :input_price_per_million, :output_price_per_million
+
+     def initialize(id:, created_at:, display_name:, provider:, context_window:, max_tokens:, supports_vision:,
+                    supports_functions:, supports_json_mode:, input_price_per_million:, output_price_per_million:, metadata: {})
+       @id = id
+       @created_at = created_at
+       @display_name = display_name
+       @provider = provider
+       @metadata = metadata
+       @context_window = context_window
+       @max_tokens = max_tokens
+       @supports_vision = supports_vision
+       @supports_functions = supports_functions
+       @supports_json_mode = supports_json_mode
+       @input_price_per_million = input_price_per_million
+       @output_price_per_million = output_price_per_million
+     end
+
+     def to_h
+       {
+         id: id,
+         created_at: created_at,
+         display_name: display_name,
+         provider: provider,
+         metadata: metadata,
+         context_window: context_window,
+         max_tokens: max_tokens,
+         supports_vision: supports_vision,
+         supports_functions: supports_functions,
+         supports_json_mode: supports_json_mode,
+         input_price_per_million: input_price_per_million,
+         output_price_per_million: output_price_per_million
+       }
+     end
+   end
+ end
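
ModelInfo is a plain value object; each provider builds one per model and `to_h` serializes it. A sketch with illustrative values (not real registry data):

    info = RubyLLM::ModelInfo.new(
      id: 'gpt-4o-mini',
      created_at: Time.now, # illustrative only
      display_name: 'GPT-4o-Mini',
      provider: 'openai',
      context_window: 128_000,
      max_tokens: 16_384,
      supports_vision: true,
      supports_functions: true,
      supports_json_mode: true,
      input_price_per_million: 0.15,
      output_price_per_million: 0.6
    )
    info.to_h[:provider] #=> "openai"
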
@@ -0,0 +1,254 @@
+ # frozen_string_literal: true
+
+ require 'time'
+
+ module RubyLLM
+   module Providers
+     class Anthropic < Base
+       def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
+         payload = {
+           model: model || 'claude-3-5-sonnet-20241022',
+           messages: format_messages(messages),
+           temperature: temperature,
+           stream: stream,
+           max_tokens: 4096
+         }
+
+         payload[:tools] = tools.map { |tool| tool_to_anthropic(tool) } if tools&.any?
+
+         puts 'Sending payload to Anthropic:' if ENV['RUBY_LLM_DEBUG']
+         puts JSON.pretty_generate(payload) if ENV['RUBY_LLM_DEBUG']
+
+         if stream && block_given?
+           stream_chat_completion(payload, tools, &block)
+         else
+           create_chat_completion(payload, tools)
+         end
+       end
+
+       def list_models
+         response = @connection.get('/v1/models') do |req|
+           req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
+           req.headers['anthropic-version'] = '2023-06-01'
+         end
+
+         raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400
+
+         capabilities = RubyLLM::ModelCapabilities::Anthropic.new
+         models_data = response.body['data'] || []
+
+         models_data.map do |model|
+           ModelInfo.new(
+             id: model['id'],
+             created_at: Time.parse(model['created_at']),
+             display_name: model['display_name'],
+             provider: 'anthropic',
+             metadata: {
+               type: model['type']
+             },
+             context_window: capabilities.determine_context_window(model['id']),
+             max_tokens: capabilities.determine_max_tokens(model['id']),
+             supports_vision: capabilities.supports_vision?(model['id']),
+             supports_functions: capabilities.supports_functions?(model['id']),
+             supports_json_mode: capabilities.supports_json_mode?(model['id']),
+             input_price_per_million: capabilities.get_input_price(model['id']),
+             output_price_per_million: capabilities.get_output_price(model['id'])
+           )
+         end
+       rescue Faraday::Error => e
+         handle_error(e)
+       end
+
+       private
+
+       def tool_to_anthropic(tool)
+         {
+           name: tool.name,
+           description: tool.description,
+           input_schema: {
+             type: 'object',
+             properties: tool.parameters,
+             required: tool.parameters.select { |_, v| v[:required] }.keys
+           }
+         }
+       end
+
+       def format_messages(messages)
+         messages.map do |msg|
+           message = { role: msg.role == :user ? 'user' : 'assistant' }
+
+           message[:content] = if msg.tool_results
+                                 [
+                                   {
+                                     type: 'tool_result',
+                                     tool_use_id: msg.tool_results[:tool_use_id],
+                                     content: msg.tool_results[:content],
+                                     is_error: msg.tool_results[:is_error]
+                                   }.compact
+                                 ]
+                               else
+                                 msg.content
+                               end
+
+           message
+         end
+       end
+
+       def create_chat_completion(payload, tools = nil)
+         response = @connection.post('/v1/messages') do |req|
+           req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
+           req.headers['anthropic-version'] = '2023-06-01'
+           req.headers['Content-Type'] = 'application/json'
+           req.body = payload
+         end
+
+         puts 'Response from Anthropic:' if ENV['RUBY_LLM_DEBUG']
+         puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']
+
+         handle_response(response, tools, payload)
+       rescue Faraday::Error => e
+         handle_error(e)
+       end
+
+       def stream_chat_completion(payload, tools = nil)
+         response = @connection.post('/v1/messages') do |req|
+           req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
+           req.headers['anthropic-version'] = '2023-06-01'
+           req.body = payload
+         end
+
+         response.body.each_line do |line|
+           next if line.strip.empty?
+           next if line == 'data: [DONE]'
+
+           begin
+             data = JSON.parse(line.sub(/^data: /, ''))
+
+             if data['type'] == 'content_block_delta'
+               content = data['delta']['text']
+               yield Message.new(role: :assistant, content: content) if content
+             elsif data['type'] == 'tool_call'
+               handle_tool_calls(data['tool_calls'], tools) do |result|
+                 yield Message.new(role: :assistant, content: result)
+               end
+             end
+           rescue JSON::ParserError
+             next
+           end
+         end
+       rescue Faraday::Error => e
+         handle_error(e)
+       end
+
+       def handle_response(response, tools, payload)
+         data = response.body
+         return Message.new(role: :assistant, content: '') if data['type'] == 'error'
+
+         # Extract text content and tool use from response
+         content_parts = data['content'] || []
+         text_content = content_parts.find { |c| c['type'] == 'text' }&.fetch('text', '')
+         tool_use = content_parts.find { |c| c['type'] == 'tool_use' }
+
+         if tool_use && tools
+           tool = tools.find { |t| t.name == tool_use['name'] }
+           result = if tool
+                      begin
+                        tool_result = tool.call(tool_use['input'] || {})
+                        {
+                          tool_use_id: tool_use['id'],
+                          content: tool_result.to_s
+                        }
+                      rescue StandardError => e
+                        {
+                          tool_use_id: tool_use['id'],
+                          content: "Error executing tool #{tool.name}: #{e.message}",
+                          is_error: true
+                        }
+                      end
+                    end
+
+           # Create a new message with the tool result
+           new_messages = payload[:messages] + [
+             { role: 'assistant', content: data['content'] },
+             {
+               role: 'user',
+               content: [
+                 {
+                   type: 'tool_result',
+                   tool_use_id: result[:tool_use_id],
+                   content: result[:content],
+                   is_error: result[:is_error]
+                 }.compact
+               ]
+             }
+           ]
+
+           return create_chat_completion(payload.merge(messages: new_messages), tools)
+         end
+
+         # Extract token usage from response
+         token_usage = if data['usage']
+                         {
+                           input_tokens: data['usage']['input_tokens'],
+                           output_tokens: data['usage']['output_tokens'],
+                           total_tokens: data['usage']['input_tokens'] + data['usage']['output_tokens']
+                         }
+                       end
+
+         Message.new(
+           role: :assistant,
+           content: text_content,
+           token_usage: token_usage,
+           model_id: data['model']
+         )
+       end
+
+       def handle_tool_calls(tool_calls, tools)
+         return [] unless tool_calls && tools
+
+         tool_calls.map do |tool_call|
+           tool = tools.find { |t| t.name == tool_call['name'] }
+           next unless tool
+
+           begin
+             args = JSON.parse(tool_call['arguments'])
+             result = tool.call(args)
+             puts "Tool result: #{result}" if ENV['RUBY_LLM_DEBUG']
+             {
+               tool_use_id: tool_call['id'],
+               content: result.to_s
+             }
+           rescue JSON::ParserError, ArgumentError => e
+             puts "Error executing tool: #{e.message}" if ENV['RUBY_LLM_DEBUG']
+             {
+               tool_use_id: tool_call['id'],
+               content: "Error executing tool #{tool.name}: #{e.message}",
+               is_error: true
+             }
+           end
+         end.compact
+       end
+
+       def handle_api_error(error)
+         response_body = error.response[:body]
+         if response_body.is_a?(String)
+           begin
+             error_data = JSON.parse(response_body)
+             message = error_data.dig('error', 'message')
+             raise RubyLLM::Error, "API error: #{message}" if message
+           rescue JSON::ParserError
+             raise RubyLLM::Error, "API error: #{error.response[:status]}"
+           end
+         elsif response_body.dig('error', 'type') == 'invalid_request_error'
+           raise RubyLLM::Error, "API error: #{response_body['error']['message']}"
+         else
+           raise RubyLLM::Error, "API error: #{error.response[:status]}"
+         end
+       end
+
+       def api_base
+         'https://api.anthropic.com'
+       end
+     end
+   end
+ end
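
The Anthropic provider's chat flow is: format messages, POST to /v1/messages, then either return a single Message or yield streamed deltas. A hedged usage sketch (the Message constructor keywords are inferred from how the provider builds replies above, and an anthropic_api_key is assumed to be configured):

    provider = RubyLLM::Providers::Anthropic.new
    messages = [RubyLLM::Message.new(role: :user, content: 'Hello!')] # keyword args inferred

    # Blocking: returns one assistant Message with token_usage and model_id attached
    reply = provider.chat(messages, model: 'claude-3-5-sonnet-20241022')
    puts reply.content

    # Streaming: each content_block_delta event is yielded as its own Message
    provider.chat(messages, stream: true) { |chunk| print chunk.content }

Set ENV['RUBY_LLM_DEBUG'] to dump the request and response payloads, as the guards above show.
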
@@ -1,6 +1,11 @@
+ # frozen_string_literal: true
+
  module RubyLLM
    module Providers
+     # Base provider class for LLM interactions
      class Base
+       attr_reader :connection
+
        def initialize
          @connection = build_connection
        end
@@ -23,9 +28,9 @@ module RubyLLM
        def handle_error(error)
          case error
          when Faraday::TimeoutError
-           raise RubyLLM::Error, "Request timed out"
+           raise RubyLLM::Error, 'Request timed out'
          when Faraday::ConnectionFailed
-           raise RubyLLM::Error, "Connection failed"
+           raise RubyLLM::Error, 'Connection failed'
          when Faraday::ClientError
            handle_api_error(error)
          else
@@ -36,6 +41,20 @@ module RubyLLM
        def handle_api_error(error)
          raise RubyLLM::Error, "API error: #{error.response[:status]}"
        end
+
+       def parse_error_message(response)
+         return "HTTP #{response.status}" unless response.body
+
+         if response.body.is_a?(String)
+           begin
+             JSON.parse(response.body).dig('error', 'message')
+           rescue StandardError
+             "HTTP #{response.status}"
+           end
+         else
+           response.body.dig('error', 'message') || "HTTP #{response.status}"
+         end
+       end
      end
    end
  end
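
parse_error_message, added to Base above, normalizes both raw-string and already-parsed JSON error bodies. A behavior sketch against a stand-in response object (the Struct below only mimics Faraday's #status/#body interface; `base` is any provider instance, and send may be needed if the helper is private):

    FakeResponse = Struct.new(:status, :body)

    base.parse_error_message(FakeResponse.new(401, '{"error":{"message":"invalid x-api-key"}}'))
    #=> "invalid x-api-key"
    base.parse_error_message(FakeResponse.new(500, nil))
    #=> "HTTP 500"

One caveat visible in the code: in the string branch, a body that parses cleanly but has no error.message returns nil rather than falling back to the HTTP status.
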
@@ -0,0 +1,189 @@
+ # frozen_string_literal: true
+
+ module RubyLLM
+   module Providers
+     class OpenAI < Base
+       def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
+         payload = {
+           model: model || RubyLLM.configuration.default_model,
+           messages: messages.map(&:to_h),
+           temperature: temperature,
+           stream: stream
+         }
+
+         if tools&.any?
+           payload[:functions] = tools.map { |tool| tool_to_function(tool) }
+           payload[:function_call] = 'auto'
+         end
+
+         puts 'Sending payload to OpenAI:' if ENV['RUBY_LLM_DEBUG']
+         puts JSON.pretty_generate(payload) if ENV['RUBY_LLM_DEBUG']
+
+         if stream && block_given?
+           stream_chat_completion(payload, tools, &block)
+         else
+           create_chat_completion(payload, tools)
+         end
+       rescue Faraday::TimeoutError
+         raise RubyLLM::Error, 'Request timed out'
+       rescue Faraday::ConnectionFailed
+         raise RubyLLM::Error, 'Connection failed'
+       rescue Faraday::ClientError => e
+         raise RubyLLM::Error, 'Client error' unless e.response
+
+         error_msg = e.response[:body]['error']&.fetch('message', nil) || "HTTP #{e.response[:status]}"
+         raise RubyLLM::Error, "API error: #{error_msg}"
+       end
+
+       def list_models
+         response = @connection.get('/v1/models') do |req|
+           req.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
+         end
+
+         raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400
+
+         capabilities = RubyLLM::ModelCapabilities::OpenAI.new
+         (response.body['data'] || []).map do |model|
+           ModelInfo.new(
+             id: model['id'],
+             created_at: Time.at(model['created']),
+             display_name: capabilities.format_display_name(model['id']),
+             provider: 'openai',
+             metadata: {
+               object: model['object'],
+               owned_by: model['owned_by']
+             },
+             context_window: capabilities.determine_context_window(model['id']),
+             max_tokens: capabilities.determine_max_tokens(model['id']),
+             supports_vision: capabilities.supports_vision?(model['id']),
+             supports_functions: capabilities.supports_functions?(model['id']),
+             supports_json_mode: capabilities.supports_json_mode?(model['id']),
+             input_price_per_million: capabilities.get_input_price(model['id']),
+             output_price_per_million: capabilities.get_output_price(model['id'])
+           )
+         end
+       rescue Faraday::Error => e
+         handle_error(e)
+       end
+
+       private
+
+       def tool_to_function(tool)
+         {
+           name: tool.name,
+           description: tool.description,
+           parameters: {
+             type: 'object',
+             properties: tool.parameters.transform_values { |v| v.reject { |k, _| k == :required } },
+             required: tool.parameters.select { |_, v| v[:required] }.keys
+           }
+         }
+       end
+
+       def create_chat_completion(payload, tools = nil)
+         response = connection.post('/v1/chat/completions') do |req|
+           req.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
+           req.headers['Content-Type'] = 'application/json'
+           req.body = payload
+         end
+
+         puts 'Response from OpenAI:' if ENV['RUBY_LLM_DEBUG']
+         puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']
+
+         if response.status >= 400
+           error_msg = response.body['error']&.fetch('message', nil) || "HTTP #{response.status}"
+           raise RubyLLM::Error, "API error: #{error_msg}"
+         end
+
+         handle_response(response, tools, payload)
+       end
+
+       def handle_response(response, tools, payload)
+         data = response.body
+         message_data = data.dig('choices', 0, 'message')
+         return Message.new(role: :assistant, content: '') unless message_data
+
+         if message_data['function_call'] && tools
+           result = handle_function_call(message_data['function_call'], tools)
+           puts "Function result: #{result}" if ENV['RUBY_LLM_DEBUG']
+
+           # Create a new chat completion with the function results
+           new_messages = payload[:messages] + [
+             { role: 'assistant', content: message_data['content'], function_call: message_data['function_call'] },
+             { role: 'function', name: message_data['function_call']['name'], content: result }
+           ]
+
+           return create_chat_completion(payload.merge(messages: new_messages), tools)
+         end
+
+         # Extract token usage from response
+         token_usage = if data['usage']
+                         {
+                           input_tokens: data['usage']['prompt_tokens'],
+                           output_tokens: data['usage']['completion_tokens'],
+                           total_tokens: data['usage']['total_tokens']
+                         }
+                       end
+
+         Message.new(
+           role: :assistant,
+           content: message_data['content'],
+           token_usage: token_usage,
+           model_id: data['model']
+         )
+       end
+
+       def handle_function_call(function_call, tools)
+         return unless function_call && tools
+
+         tool = tools.find { |t| t.name == function_call['name'] }
+         return unless tool
+
+         begin
+           args = JSON.parse(function_call['arguments'])
+           tool.call(args)
+         rescue JSON::ParserError, ArgumentError => e
+           "Error executing function #{tool.name}: #{e.message}"
+         end
+       end
+
+       def handle_error(error)
+         case error
+         when Faraday::TimeoutError
+           raise RubyLLM::Error, 'Request timed out'
+         when Faraday::ConnectionFailed
+           raise RubyLLM::Error, 'Connection failed'
+         when Faraday::ClientError
+           raise RubyLLM::Error, 'Client error' unless error.response
+
+           error_msg = error.response[:body]['error']&.fetch('message', nil) || "HTTP #{error.response[:status]}"
+           raise RubyLLM::Error, "API error: #{error_msg}"
+
+         else
+           raise error
+         end
+       end
+
+       def handle_api_error(error)
+         response_body = error.response[:body]
+         if response_body.is_a?(String)
+           begin
+             error_data = JSON.parse(response_body)
+             message = error_data.dig('error', 'message')
+             raise RubyLLM::Error, "API error: #{message}" if message
+           rescue JSON::ParserError
+             raise RubyLLM::Error, "API error: #{error.response[:status]}"
+           end
+         elsif response_body['error']
+           raise RubyLLM::Error, "API error: #{response_body['error']['message']}"
+         else
+           raise RubyLLM::Error, "API error: #{error.response[:status]}"
+         end
+       end
+
+       def api_base
+         'https://api.openai.com'
+       end
+     end
+   end
+ end
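
This provider targets OpenAI's legacy functions/function_call API rather than the newer tools array. Judging from tool_to_function and handle_function_call, anything that responds to #name, #description, #parameters, and #call works as a tool. A duck-typed sketch (WeatherTool is hypothetical, and an openai_api_key is assumed to be configured):

    # Hypothetical tool; the :required flags inside each parameter spec are
    # what tool_to_function strips out and collects into the schema's
    # `required` list.
    WeatherTool = Struct.new(:name, :description, :parameters) do
      def call(args)
        "Sunny in #{args['city']}" # stand-in for a real lookup
      end
    end

    tool = WeatherTool.new('get_weather', 'Look up current weather',
                           { city: { type: 'string', required: true } })

    provider = RubyLLM::Providers::OpenAI.new
    messages = [RubyLLM::Message.new(role: :user, content: "Weather in Paris?")]
    reply = provider.chat(messages, tools: [tool])

When the model returns a function_call, handle_response executes the tool, appends a role: 'function' message with the result, and recurses into create_chat_completion until the model produces plain content.
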
@@ -1,4 +1,7 @@
+ # frozen_string_literal: true
+
  module RubyLLM
+   # Rails integration for RubyLLM
    class Railtie < Rails::Railtie
      initializer 'ruby_llm.initialize' do
        ActiveSupport.on_load(:active_record) do