ruby_llm 0.1.0.pre → 0.1.0.pre2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.github/workflows/gem-push.yml +9 -3
- data/.github/workflows/test.yml +32 -0
- data/.gitignore +58 -0
- data/.overcommit.yml +26 -0
- data/.rspec +3 -0
- data/.rubocop.yml +3 -0
- data/Gemfile +5 -0
- data/README.md +68 -13
- data/Rakefile +4 -2
- data/bin/console +6 -3
- data/lib/ruby_llm/active_record/acts_as.rb +31 -18
- data/lib/ruby_llm/client.rb +32 -16
- data/lib/ruby_llm/configuration.rb +5 -3
- data/lib/ruby_llm/conversation.rb +3 -0
- data/lib/ruby_llm/message.rb +6 -3
- data/lib/ruby_llm/model_capabilities/anthropic.rb +81 -0
- data/lib/ruby_llm/model_capabilities/base.rb +35 -0
- data/lib/ruby_llm/model_capabilities/openai.rb +121 -0
- data/lib/ruby_llm/model_info.rb +42 -0
- data/lib/ruby_llm/providers/anthropic.rb +226 -0
- data/lib/ruby_llm/providers/base.rb +21 -2
- data/lib/ruby_llm/providers/openai.rb +161 -0
- data/lib/ruby_llm/railtie.rb +3 -0
- data/lib/ruby_llm/tool.rb +75 -0
- data/lib/ruby_llm/version.rb +3 -1
- data/lib/ruby_llm.rb +35 -3
- data/ruby_llm.gemspec +42 -31
- metadata +142 -17
data/lib/ruby_llm/model_capabilities/openai.rb
CHANGED
@@ -0,0 +1,121 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module ModelCapabilities
+    class OpenAI < Base
+      def determine_context_window(model_id)
+        case model_id
+        when /gpt-4o/, /o1/, /gpt-4-turbo/
+          128_000
+        when /gpt-4-0[0-9]{3}/
+          8_192
+        when /gpt-3.5-turbo-instruct/
+          4_096
+        when /gpt-3.5/
+          16_385
+        else
+          4_096
+        end
+      end
+
+      def determine_max_tokens(model_id)
+        case model_id
+        when /o1-2024-12-17/
+          100_000
+        when /o1-mini-2024-09-12/
+          65_536
+        when /o1-preview-2024-09-12/
+          32_768
+        when /gpt-4o/, /gpt-4-turbo/
+          16_384
+        when /gpt-4-0[0-9]{3}/
+          8_192
+        when /gpt-3.5-turbo/
+          4_096
+        else
+          4_096
+        end
+      end
+
+      def get_input_price(model_id)
+        case model_id
+        when /o1-2024/
+          15.0 # $15.00 per million tokens
+        when /o1-mini/
+          3.0 # $3.00 per million tokens
+        when /gpt-4o-realtime-preview/
+          5.0 # $5.00 per million tokens
+        when /gpt-4o-mini-realtime-preview/
+          0.60 # $0.60 per million tokens
+        when /gpt-4o-mini/
+          0.15 # $0.15 per million tokens
+        when /gpt-4o/
+          2.50 # $2.50 per million tokens
+        when /gpt-4-turbo/
+          10.0 # $10.00 per million tokens
+        when /gpt-3.5/
+          0.50 # $0.50 per million tokens
+        else
+          0.50 # Default to GPT-3.5 pricing
+        end
+      end
+
+      def get_output_price(model_id)
+        case model_id
+        when /o1-2024/
+          60.0 # $60.00 per million tokens
+        when /o1-mini/
+          12.0 # $12.00 per million tokens
+        when /gpt-4o-realtime-preview/
+          20.0 # $20.00 per million tokens
+        when /gpt-4o-mini-realtime-preview/
+          2.40 # $2.40 per million tokens
+        when /gpt-4o-mini/
+          0.60 # $0.60 per million tokens
+        when /gpt-4o/
+          10.0 # $10.00 per million tokens
+        when /gpt-4-turbo/
+          30.0 # $30.00 per million tokens
+        when /gpt-3.5/
+          1.50 # $1.50 per million tokens
+        else
+          1.50 # Default to GPT-3.5 pricing
+        end
+      end
+
+      def supports_functions?(model_id)
+        !model_id.include?('instruct')
+      end
+
+      def supports_vision?(model_id)
+        model_id.include?('vision') || model_id.match?(/gpt-4-(?!0314|0613)/)
+      end
+
+      def supports_json_mode?(model_id)
+        model_id.match?(/gpt-4-\d{4}-preview/) ||
+          model_id.include?('turbo') ||
+          model_id.match?(/gpt-3.5-turbo-(?!0301|0613)/)
+      end
+
+      def format_display_name(model_id)
+        # First replace hyphens with spaces
+        name = model_id.tr('-', ' ')
+
+        # Capitalize each word
+        name = name.split(' ').map { |word| word.capitalize }.join(' ')
+
+        # Apply specific formatting rules
+        name.gsub(/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3') # Convert dates to YYYYMMDD
+            .gsub(/^Gpt /, 'GPT-')
+            .gsub(/^O1 /, 'O1-')
+            .gsub(/^Chatgpt /, 'ChatGPT-')
+            .gsub(/^Tts /, 'TTS-')
+            .gsub(/^Dall E /, 'DALL-E-')
+            .gsub(/3\.5 /, '3.5-')
+            .gsub(/4 /, '4-')
+            .gsub(/4o (?=Mini|Preview|Turbo)/, '4o-')
+            .gsub(/\bHd\b/, 'HD')
+      end
+    end
+  end
+end
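These heuristics turn a bare model ID into capability and pricing metadata without any extra API calls. A quick sketch of what the rules above yield (the model IDs are illustrative; the return values follow directly from the case branches):

  caps = RubyLLM::ModelCapabilities::OpenAI.new
  caps.determine_context_window('gpt-4o-2024-08-06') # => 128_000 (matches /gpt-4o/)
  caps.get_input_price('gpt-4o-mini')                # => 0.15 ($ per million input tokens)
  caps.supports_functions?('gpt-3.5-turbo-instruct') # => false ('instruct' models excluded)
  caps.format_display_name('gpt-4o-mini')            # => "GPT-4o-Mini"

Note that the `when` order matters in the pricing methods: /gpt-4o-mini/ must be tested before /gpt-4o/, otherwise the cheaper mini pricing would never be reached.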
data/lib/ruby_llm/model_info.rb
CHANGED
@@ -0,0 +1,42 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  class ModelInfo
+    attr_reader :id, :created_at, :display_name, :provider, :metadata,
+                :context_window, :max_tokens, :supports_vision, :supports_functions,
+                :supports_json_mode, :input_price_per_million, :output_price_per_million
+
+    def initialize(id:, created_at:, display_name:, provider:, context_window:, max_tokens:, supports_vision:,
+                   supports_functions:, supports_json_mode:, input_price_per_million:, output_price_per_million:, metadata: {})
+      @id = id
+      @created_at = created_at
+      @display_name = display_name
+      @provider = provider
+      @metadata = metadata
+      @context_window = context_window
+      @max_tokens = max_tokens
+      @supports_vision = supports_vision
+      @supports_functions = supports_functions
+      @supports_json_mode = supports_json_mode
+      @input_price_per_million = input_price_per_million
+      @output_price_per_million = output_price_per_million
+    end
+
+    def to_h
+      {
+        id: id,
+        created_at: created_at,
+        display_name: display_name,
+        provider: provider,
+        metadata: metadata,
+        context_window: context_window,
+        max_tokens: max_tokens,
+        supports_vision: supports_vision,
+        supports_functions: supports_functions,
+        supports_json_mode: supports_json_mode,
+        input_price_per_million: input_price_per_million,
+        output_price_per_million: output_price_per_million
+      }
+    end
+  end
+end
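ModelInfo is a plain value object; each provider's list_models builds one per model from its capability heuristics. A minimal hand-built instance (field values are illustrative, not canonical):

  info = RubyLLM::ModelInfo.new(
    id: 'gpt-4o-mini',
    created_at: Time.now,
    display_name: 'GPT-4o-Mini',
    provider: 'openai',
    context_window: 128_000,
    max_tokens: 16_384,
    supports_vision: true,
    supports_functions: true,
    supports_json_mode: true,
    input_price_per_million: 0.15,
    output_price_per_million: 0.60
  )
  info.to_h[:input_price_per_million] # => 0.15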
data/lib/ruby_llm/providers/anthropic.rb
CHANGED
@@ -0,0 +1,226 @@
+# frozen_string_literal: true
+
+require 'time'
+
+module RubyLLM
+  module Providers
+    class Anthropic < Base
+      def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
+        payload = {
+          model: model || 'claude-3-5-sonnet-20241022',
+          messages: format_messages(messages),
+          temperature: temperature,
+          stream: stream,
+          max_tokens: 4096
+        }
+
+        payload[:tools] = tools.map { |tool| tool_to_anthropic(tool) } if tools&.any?
+
+        puts 'Sending payload to Anthropic:' if ENV['RUBY_LLM_DEBUG']
+        puts JSON.pretty_generate(payload) if ENV['RUBY_LLM_DEBUG']
+
+        if stream && block_given?
+          stream_chat_completion(payload, tools, &block)
+        else
+          create_chat_completion(payload, tools)
+        end
+      end
+
+      def list_models
+        response = @connection.get('/v1/models') do |req|
+          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
+          req.headers['anthropic-version'] = '2023-06-01'
+        end
+
+        raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400
+
+        capabilities = RubyLLM::ModelCapabilities::Anthropic.new
+        models_data = response.body['data'] || []
+
+        models_data.map do |model|
+          ModelInfo.new(
+            id: model['id'],
+            created_at: Time.parse(model['created_at']),
+            display_name: model['display_name'],
+            provider: 'anthropic',
+            metadata: {
+              type: model['type']
+            },
+            context_window: capabilities.determine_context_window(model['id']),
+            max_tokens: capabilities.determine_max_tokens(model['id']),
+            supports_vision: capabilities.supports_vision?(model['id']),
+            supports_functions: capabilities.supports_functions?(model['id']),
+            supports_json_mode: capabilities.supports_json_mode?(model['id']),
+            input_price_per_million: capabilities.get_input_price(model['id']),
+            output_price_per_million: capabilities.get_output_price(model['id'])
+          )
+        end
+      rescue Faraday::Error => e
+        handle_error(e)
+      end
+
+      private
+
+      def tool_to_anthropic(tool)
+        {
+          name: tool.name,
+          description: tool.description,
+          input_schema: {
+            type: 'object',
+            properties: tool.parameters,
+            required: tool.parameters.select { |_, v| v[:required] }.keys
+          }
+        }
+      end
+
+      def format_messages(messages)
+        messages.map do |msg|
+          message = { role: msg.role == :user ? 'user' : 'assistant' }
+
+          message[:content] = if msg.tool_results
+                                [
+                                  {
+                                    type: 'tool_result',
+                                    tool_use_id: msg.tool_results[:tool_use_id],
+                                    content: msg.tool_results[:content],
+                                    is_error: msg.tool_results[:is_error]
+                                  }.compact
+                                ]
+                              else
+                                msg.content
+                              end
+
+          message
+        end
+      end
+
+      def create_chat_completion(payload, tools = nil)
+        response = @connection.post('/v1/messages') do |req|
+          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
+          req.headers['anthropic-version'] = '2023-06-01'
+          req.headers['Content-Type'] = 'application/json'
+          req.body = payload
+        end
+
+        puts 'Response from Anthropic:' if ENV['RUBY_LLM_DEBUG']
+        puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']
+
+        handle_response(response, tools, payload)
+      rescue Faraday::Error => e
+        handle_error(e)
+      end
+
+      def stream_chat_completion(payload, tools = nil)
+        response = @connection.post('/v1/messages') do |req|
+          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
+          req.headers['anthropic-version'] = '2023-06-01'
+          req.body = payload
+        end
+
+        response.body.each_line do |line|
+          next if line.strip.empty?
+          next if line == 'data: [DONE]'
+
+          begin
+            data = JSON.parse(line.sub(/^data: /, ''))
+
+            if data['type'] == 'content_block_delta'
+              content = data['delta']['text']
+              yield Message.new(role: :assistant, content: content) if content
+            elsif data['type'] == 'tool_call'
+              handle_tool_calls(data['tool_calls'], tools) do |result|
+                yield Message.new(role: :assistant, content: result)
+              end
+            end
+          rescue JSON::ParserError
+            next
+          end
+        end
+      rescue Faraday::Error => e
+        handle_error(e)
+      end
+
+      def handle_response(response, tools, payload)
+        data = response.body
+        return Message.new(role: :assistant, content: '') if data['type'] == 'error'
+
+        # Extract text content and tool use from response
+        content_parts = data['content'] || []
+        text_content = content_parts.find { |c| c['type'] == 'text' }&.fetch('text', '')
+        tool_use = content_parts.find { |c| c['type'] == 'tool_use' }
+
+        if tool_use && tools
+          tool = tools.find { |t| t.name == tool_use['name'] }
+          result = if tool
+                     begin
+                       tool_result = tool.call(tool_use['input'] || {})
+                       {
+                         tool_use_id: tool_use['id'],
+                         content: tool_result.to_s
+                       }
+                     rescue StandardError => e
+                       {
+                         tool_use_id: tool_use['id'],
+                         content: "Error executing tool #{tool.name}: #{e.message}",
+                         is_error: true
+                       }
+                     end
+                   end
+
+          # Create a new message with the tool result
+          new_messages = payload[:messages] + [
+            { role: 'assistant', content: data['content'] },
+            {
+              role: 'user',
+              content: [
+                {
+                  type: 'tool_result',
+                  tool_use_id: result[:tool_use_id],
+                  content: result[:content],
+                  is_error: result[:is_error]
+                }.compact
+              ]
+            }
+          ]
+
+          return create_chat_completion(payload.merge(messages: new_messages), tools)
+        end
+
+        Message.new(
+          role: :assistant,
+          content: text_content
+        )
+      end
+
+      def handle_tool_calls(tool_calls, tools)
+        return [] unless tool_calls && tools
+
+        tool_calls.map do |tool_call|
+          tool = tools.find { |t| t.name == tool_call['name'] }
+          next unless tool
+
+          begin
+            args = JSON.parse(tool_call['arguments'])
+            result = tool.call(args)
+            puts "Tool result: #{result}" if ENV['RUBY_LLM_DEBUG']
+            {
+              tool_use_id: tool_call['id'],
+              content: result.to_s
+            }
+          rescue JSON::ParserError, ArgumentError => e
+            puts "Error executing tool: #{e.message}" if ENV['RUBY_LLM_DEBUG']
+            {
+              tool_use_id: tool_call['id'],
+              content: "Error executing tool #{tool.name}: #{e.message}",
+              is_error: true
+            }
+          end
+        end.compact
+      end
+
+      def api_base
+        'https://api.anthropic.com'
+      end
+    end
+  end
+end
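The Anthropic provider wires these pieces together: format_messages adapts the gem's Message objects to the Messages API shape, and tool_use responses are executed and fed back through a recursive create_chat_completion. A minimal usage sketch, assuming the configuration object (not shown in this hunk) exposes a writable anthropic_api_key:

  RubyLLM.configuration.anthropic_api_key = ENV['ANTHROPIC_API_KEY'] # assumed writable

  provider = RubyLLM::Providers::Anthropic.new
  reply = provider.chat(
    [RubyLLM::Message.new(role: :user, content: 'Say hello in one word.')],
    model: 'claude-3-5-sonnet-20241022'
  )
  puts reply.content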
data/lib/ruby_llm/providers/base.rb
CHANGED
@@ -1,6 +1,11 @@
+# frozen_string_literal: true
+
 module RubyLLM
   module Providers
+    # Base provider class for LLM interactions
     class Base
+      attr_reader :connection
+
       def initialize
         @connection = build_connection
       end
@@ -23,9 +28,9 @@ module RubyLLM
       def handle_error(error)
         case error
         when Faraday::TimeoutError
-          raise RubyLLM::Error,
+          raise RubyLLM::Error, 'Request timed out'
         when Faraday::ConnectionFailed
-          raise RubyLLM::Error,
+          raise RubyLLM::Error, 'Connection failed'
         when Faraday::ClientError
           handle_api_error(error)
         else
@@ -36,6 +41,20 @@ module RubyLLM
       def handle_api_error(error)
         raise RubyLLM::Error, "API error: #{error.response[:status]}"
       end
+
+      def parse_error_message(response)
+        return "HTTP #{response.status}" unless response.body
+
+        if response.body.is_a?(String)
+          begin
+            JSON.parse(response.body).dig('error', 'message')
+          rescue StandardError
+            "HTTP #{response.status}"
+          end
+        else
+          response.body.dig('error', 'message') || "HTTP #{response.status}"
+        end
+      end
     end
   end
 end
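parse_error_message is deliberately defensive about body shape, since the response body may arrive as a parsed hash or a raw string. A sketch of both paths (the Struct stands in for a Faraday response; send is used in case the method is private, and this assumes Base#build_connection succeeds without credentials):

  FakeResponse = Struct.new(:status, :body)
  provider = RubyLLM::Providers::OpenAI.new # any Base subclass

  provider.send(:parse_error_message,
                FakeResponse.new(401, { 'error' => { 'message' => 'bad key' } }))
  # => "bad key"
  provider.send(:parse_error_message, FakeResponse.new(500, 'not json'))
  # => "HTTP 500"

One subtlety: on the String branch, a body that parses as JSON but lacks error.message returns nil, whereas the Hash branch falls back to "HTTP #{status}".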
data/lib/ruby_llm/providers/openai.rb
CHANGED
@@ -0,0 +1,161 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    class OpenAI < Base
+      def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
+        payload = {
+          model: model || RubyLLM.configuration.default_model,
+          messages: messages.map(&:to_h),
+          temperature: temperature,
+          stream: stream
+        }
+
+        if tools&.any?
+          payload[:functions] = tools.map { |tool| tool_to_function(tool) }
+          payload[:function_call] = 'auto'
+        end
+
+        puts 'Sending payload to OpenAI:' if ENV['RUBY_LLM_DEBUG']
+        puts JSON.pretty_generate(payload) if ENV['RUBY_LLM_DEBUG']
+
+        if stream && block_given?
+          stream_chat_completion(payload, tools, &block)
+        else
+          create_chat_completion(payload, tools)
+        end
+      rescue Faraday::TimeoutError
+        raise RubyLLM::Error, 'Request timed out'
+      rescue Faraday::ConnectionFailed
+        raise RubyLLM::Error, 'Connection failed'
+      rescue Faraday::ClientError => e
+        raise RubyLLM::Error, 'Client error' unless e.response
+
+        error_msg = e.response[:body]['error']&.fetch('message', nil) || "HTTP #{e.response[:status]}"
+        raise RubyLLM::Error, "API error: #{error_msg}"
+      end
+
+      def list_models
+        response = @connection.get('/v1/models') do |req|
+          req.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
+        end
+
+        raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400
+
+        capabilities = RubyLLM::ModelCapabilities::OpenAI.new
+        (response.body['data'] || []).map do |model|
+          ModelInfo.new(
+            id: model['id'],
+            created_at: Time.at(model['created']),
+            display_name: capabilities.format_display_name(model['id']),
+            provider: 'openai',
+            metadata: {
+              object: model['object'],
+              owned_by: model['owned_by']
+            },
+            context_window: capabilities.determine_context_window(model['id']),
+            max_tokens: capabilities.determine_max_tokens(model['id']),
+            supports_vision: capabilities.supports_vision?(model['id']),
+            supports_functions: capabilities.supports_functions?(model['id']),
+            supports_json_mode: capabilities.supports_json_mode?(model['id']),
+            input_price_per_million: capabilities.get_input_price(model['id']),
+            output_price_per_million: capabilities.get_output_price(model['id'])
+          )
+        end
+      rescue Faraday::Error => e
+        handle_error(e)
+      end
+
+      private
+
+      def tool_to_function(tool)
+        {
+          name: tool.name,
+          description: tool.description,
+          parameters: {
+            type: 'object',
+            properties: tool.parameters.transform_values { |v| v.reject { |k, _| k == :required } },
+            required: tool.parameters.select { |_, v| v[:required] }.keys
+          }
+        }
+      end
+
+      def create_chat_completion(payload, tools = nil)
+        response = connection.post('/v1/chat/completions') do |req|
+          req.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
+          req.headers['Content-Type'] = 'application/json'
+          req.body = payload
+        end
+
+        puts 'Response from OpenAI:' if ENV['RUBY_LLM_DEBUG']
+        puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']
+
+        if response.status >= 400
+          error_msg = response.body['error']&.fetch('message', nil) || "HTTP #{response.status}"
+          raise RubyLLM::Error, "API error: #{error_msg}"
+        end
+
+        handle_response(response, tools, payload)
+      end
+
+      def handle_response(response, tools, payload)
+        data = response.body
+        message_data = data.dig('choices', 0, 'message')
+        return Message.new(role: :assistant, content: '') unless message_data
+
+        if message_data['function_call'] && tools
+          result = handle_function_call(message_data['function_call'], tools)
+          puts "Function result: #{result}" if ENV['RUBY_LLM_DEBUG']
+
+          # Create a new chat completion with the function results
+          new_messages = payload[:messages] + [
+            { role: 'assistant', content: message_data['content'], function_call: message_data['function_call'] },
+            { role: 'function', name: message_data['function_call']['name'], content: result }
+          ]
+
+          return create_chat_completion(payload.merge(messages: new_messages), tools)
+        end
+
+        Message.new(
+          role: :assistant,
+          content: message_data['content']
+        )
+      end
+
+      def handle_function_call(function_call, tools)
+        return unless function_call && tools
+
+        tool = tools.find { |t| t.name == function_call['name'] }
+        return unless tool
+
+        begin
+          args = JSON.parse(function_call['arguments'])
+          tool.call(args)
+        rescue JSON::ParserError, ArgumentError => e
+          "Error executing function #{tool.name}: #{e.message}"
+        end
+      end
+
+      def handle_error(error)
+        case error
+        when Faraday::TimeoutError
+          raise RubyLLM::Error, 'Request timed out'
+        when Faraday::ConnectionFailed
+          raise RubyLLM::Error, 'Connection failed'
+        when Faraday::ClientError
+          raise RubyLLM::Error, 'Client error' unless error.response
+
+          error_msg = error.response[:body]['error']&.fetch('message', nil) || "HTTP #{error.response[:status]}"
+          raise RubyLLM::Error, "API error: #{error_msg}"
+        else
+          raise error
+        end
+      end
+
+      def api_base
+        'https://api.openai.com'
+      end
+    end
+  end
+end
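Note that this provider targets the legacy functions / function_call API rather than the newer tools parameter; when the model returns a function_call, the tool result is appended as a role: 'function' message and the completion is retried. A minimal sketch, assuming openai_api_key is writable on the configuration and Message#to_h produces the expected role/content hash:

  RubyLLM.configuration.openai_api_key = ENV['OPENAI_API_KEY'] # assumed writable

  weather = RubyLLM::Tool.new(
    name: 'get_weather',
    description: 'Get the current weather for a city',
    parameters: { city: { type: 'string', description: 'city name', required: true } }
  ) { |args| "It is sunny in #{args[:city]}." }

  provider = RubyLLM::Providers::OpenAI.new
  reply = provider.chat(
    [RubyLLM::Message.new(role: :user, content: "What's the weather in Berlin?")],
    tools: [weather]
  )
  puts reply.content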
data/lib/ruby_llm/tool.rb
CHANGED
@@ -0,0 +1,75 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  # Represents a tool/function that can be called by an LLM
+  class Tool
+    attr_reader :name, :description, :parameters, :handler
+
+    def self.from_method(method_object, description: nil, parameter_descriptions: {})
+      method_params = {}
+      method_object.parameters.each do |param_type, param_name|
+        next unless %i[req opt key keyreq].include?(param_type)
+
+        method_params[param_name] = {
+          type: 'string',
+          description: parameter_descriptions[param_name] || param_name.to_s.tr('_', ' '),
+          required: %i[req keyreq].include?(param_type)
+        }
+      end
+
+      new(
+        name: method_object.name.to_s,
+        description: description || "Executes the #{method_object.name} operation",
+        parameters: method_params
+      ) do |args|
+        # Create an instance if it's an instance method
+        instance = if method_object.owner.instance_methods.include?(method_object.name)
+                     method_object.owner.new
+                   else
+                     method_object.owner
+                   end
+
+        # Call the method with the arguments
+        if args.is_a?(Hash)
+          instance.method(method_object.name).call(**args)
+        else
+          instance.method(method_object.name).call(args)
+        end
+      end
+    end
+
+    def initialize(name:, description:, parameters: {}, &block)
+      @name = name
+      @description = description
+      @parameters = parameters
+      @handler = block
+
+      validate!
+    end
+
+    def call(args)
+      validated_args = validate_args!(args)
+      handler.call(validated_args)
+    end
+
+    private
+
+    def validate!
+      raise ArgumentError, 'Name must be a string' unless name.is_a?(String)
+      raise ArgumentError, 'Description must be a string' unless description.is_a?(String)
+      raise ArgumentError, 'Parameters must be a hash' unless parameters.is_a?(Hash)
+      raise ArgumentError, 'Block must be provided' unless handler.respond_to?(:call)
+    end
+
+    def validate_args!(args)
+      args = args.transform_keys(&:to_sym)
+      required_params = parameters.select { |_, v| v[:required] }.keys
+
+      required_params.each do |param|
+        raise ArgumentError, "Missing required parameter: #{param}" unless args.key?(param.to_sym)
+      end
+
+      args
+    end
+  end
+end
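Tool.from_method builds the same structure from an existing Ruby method: every parameter is exposed as type 'string', and required is inferred from whether the Ruby parameter is mandatory (req/keyreq). A sketch under those rules, using an illustrative Calculator class:

  class Calculator
    def add(a:, b:)
      Integer(a) + Integer(b)
    end
  end

  tool = RubyLLM::Tool.from_method(
    Calculator.instance_method(:add),
    description: 'Adds two integers',
    parameter_descriptions: { a: 'first addend', b: 'second addend' }
  )

  tool.parameters[:a][:required]    # => true (keyreq parameter)
  tool.call('a' => '2', 'b' => '3') # => 5 (validate_args! symbolizes the keys)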
data/lib/ruby_llm/version.rb
CHANGED