ruby_llm 0.1.0.pre2 → 0.1.0.pre4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/ruby_llm/client.rb +4 -1
- data/lib/ruby_llm/message.rb +9 -5
- data/lib/ruby_llm/model_capabilities/openai.rb +1 -1
- data/lib/ruby_llm/providers/anthropic.rb +104 -21
- data/lib/ruby_llm/providers/base.rb +7 -0
- data/lib/ruby_llm/providers/openai.rb +68 -13
- data/lib/ruby_llm/version.rb +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 33195e0b579713ebc168acdf2d1b6e67e1ee4a0d5db68cacaf1648cd4d1e744e
+  data.tar.gz: ee20d7adaa4f8ef22853894885dc6dfab2e19aab1526fd5022cd7413653e4736
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: bef854e5f925aa246b40ef1f12e4a9cde1c085c97a996ecd4c266810b97f9ac324f6cd19dede1526215e92249cbb9aee7017a512fd309baf834aee06570be6c0
+  data.tar.gz: 3dc20b4b9093cca54bb1d5e033e2270a8e33146c3b3c8dcc85997ca6b01eb9e2b07e95d833938df5ceca561d4a2da75a45473dbc878ce744c93d7816fde82009
data/lib/ruby_llm/client.rb
CHANGED
@@ -14,7 +14,7 @@ module RubyLLM
       end

       provider = provider_for(model)
-      provider.chat(
+      response_messages = provider.chat(
         formatted_messages,
         model: model,
         temperature: temperature,
@@ -22,6 +22,9 @@ module RubyLLM
         tools: tools,
         &block
       )
+
+      # Always return an array of messages, even for single responses
+      response_messages.is_a?(Array) ? response_messages : [response_messages]
     end

     def list_models(provider = nil)
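The net effect: Client#chat now normalizes the provider's return value so callers always receive an array of Message objects. A minimal usage sketch, assuming a configured client whose chat signature mirrors the provider call above (the model id is illustrative):

  client = RubyLLM::Client.new

  # chat returns an Array of RubyLLM::Message even when the provider
  # produced a single response with no tool-call round trips.
  messages = client.chat(
    [{ role: 'user', content: 'Hello!' }],
    model: 'gpt-4' # illustrative model id
  )

  messages.each { |m| puts "#{m.role}: #{m.content}" }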
data/lib/ruby_llm/message.rb
CHANGED
@@ -1,17 +1,18 @@
 # frozen_string_literal: true

 module RubyLLM
-  # Represents a message in an LLM conversation
   class Message
     VALID_ROLES = %i[system user assistant tool].freeze

-    attr_reader :role, :content, :tool_calls, :tool_results
+    attr_reader :role, :content, :tool_calls, :tool_results, :token_usage, :model_id

-    def initialize(role:, content: nil, tool_calls: nil, tool_results: nil)
+    def initialize(role:, content: nil, tool_calls: nil, tool_results: nil, token_usage: nil, model_id: nil)
       @role = role.to_sym
       @content = content
       @tool_calls = tool_calls
       @tool_results = tool_results
+      @token_usage = token_usage
+      @model_id = model_id
       validate!
     end

@@ -20,7 +21,9 @@ module RubyLLM
         role: role,
         content: content,
         tool_calls: tool_calls,
-        tool_results: tool_results
+        tool_results: tool_results,
+        token_usage: token_usage,
+        model_id: model_id
       }.compact
     end

@@ -29,7 +32,8 @@ module RubyLLM
     def validate!
       return if VALID_ROLES.include?(role)

-      raise ArgumentError,
+      raise ArgumentError,
+            "Invalid role: #{role}. Must be one of: #{VALID_ROLES.join(', ')}"
     end
   end
 end
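A short sketch of the new fields in use (values illustrative). Note that to_h still runs through .compact, so messages without token data serialize exactly as before:

  msg = RubyLLM::Message.new(
    role: :assistant,
    content: 'Hi there!',
    token_usage: { input_tokens: 10, output_tokens: 4, total_tokens: 14 },
    model_id: 'claude-3-5-sonnet' # illustrative model id
  )

  msg.to_h
  # => { role: :assistant, content: 'Hi there!',
  #      token_usage: { input_tokens: 10, output_tokens: 4, total_tokens: 14 },
  #      model_id: 'claude-3-5-sonnet' }

  RubyLLM::Message.new(role: :user, content: 'Hi').to_h
  # => { role: :user, content: 'Hi' }   # nil fields dropped by .compact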
data/lib/ruby_llm/model_capabilities/openai.rb
CHANGED
@@ -102,7 +102,7 @@ module RubyLLM
       name = model_id.tr('-', ' ')

       # Capitalize each word
-      name = name.split(' ').map
+      name = name.split(' ').map(&:capitalize).join(' ')

       # Apply specific formatting rules
       name.gsub(/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3') # Convert dates to YYYYMMDD
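The earlier release truncated this line after .map, which returns an Enumerator rather than a formatted name, breaking the gsub call that follows. With the completed chain the transformation behaves like this (model id illustrative):

  name = 'gpt-4-turbo-2024-04-09'.tr('-', ' ')
  # => "gpt 4 turbo 2024 04 09"
  name = name.split(' ').map(&:capitalize).join(' ')
  # => "Gpt 4 Turbo 2024 04 09"
  name.gsub(/(\d{4}) (\d{2}) (\d{2})/, '\1\2\3')
  # => "Gpt 4 Turbo 20240409"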
data/lib/ruby_llm/providers/anthropic.rb
CHANGED
@@ -62,13 +62,24 @@ module RubyLLM
       private

       def tool_to_anthropic(tool)
+        # Get required fields and clean properties
+        required_fields = []
+        cleaned_properties = {}
+
+        tool.parameters.each do |name, props|
+          required_fields << name.to_s if props[:required]
+          cleaned_props = props.dup
+          cleaned_props.delete(:required)
+          cleaned_properties[name] = cleaned_props
+        end
+
         {
           name: tool.name,
           description: tool.description,
           input_schema: {
             type: 'object',
-            properties:
-            required:
+            properties: cleaned_properties,
+            required: required_fields
           }
         }
       end
@@ -94,7 +105,7 @@ module RubyLLM
         end
       end

-      def create_chat_completion(payload, tools = nil)
+      def create_chat_completion(payload, tools = nil, &block)
        response = @connection.post('/v1/messages') do |req|
          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
          req.headers['anthropic-version'] = '2023-06-01'
@@ -105,7 +116,10 @@ module RubyLLM
        puts 'Response from Anthropic:' if ENV['RUBY_LLM_DEBUG']
        puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']

-        handle_response(response, tools, payload)
+        # Check for API errors first
+        check_for_api_error(response)
+
+        handle_response(response, tools, payload, &block)
      rescue Faraday::Error => e
        handle_error(e)
      end
@@ -117,6 +131,7 @@ module RubyLLM
          req.body = payload
        end

+        messages = []
        response.body.each_line do |line|
          next if line.strip.empty?
          next if line == 'data: [DONE]'
@@ -124,32 +139,48 @@ module RubyLLM
          begin
            data = JSON.parse(line.sub(/^data: /, ''))

-
-
-
-
-
-
-
+            message = case data['type']
+                      when 'content_block_delta'
+                        Message.new(role: :assistant, content: data['delta']['text']) if data['delta']['text']
+                      when 'tool_call'
+                        handle_tool_calls(data['tool_calls'], tools) do |result|
+                          Message.new(role: :assistant, content: result)
+                        end
+                      end
+
+            if message
+              messages << message
+              yield message if block_given?
            end
          rescue JSON::ParserError
            next
          end
        end
+
+        messages
      rescue Faraday::Error => e
        handle_error(e)
      end

-      def handle_response(response, tools, payload)
+      def handle_response(response, tools, payload, &block)
        data = response.body
-        return Message.new(role: :assistant, content: '') if data['type'] == 'error'

-        # Extract text content and tool use from response
        content_parts = data['content'] || []
        text_content = content_parts.find { |c| c['type'] == 'text' }&.fetch('text', '')
        tool_use = content_parts.find { |c| c['type'] == 'tool_use' }

        if tool_use && tools
+          # Tool call handling code...
+          tool_message = Message.new(
+            role: :assistant,
+            content: text_content,
+            tool_calls: [{
+              name: tool_use['name'],
+              arguments: JSON.generate(tool_use['input'] || {})
+            }]
+          )
+          yield tool_message if block_given?
+
          tool = tools.find { |t| t.name == tool_use['name'] }
          result = if tool
                     begin
@@ -167,7 +198,13 @@ module RubyLLM
                     end
                   end

-
+          result_message = Message.new(
+            role: :tool,
+            content: result[:content],
+            tool_results: result
+          )
+          yield result_message if block_given?
+
          new_messages = payload[:messages] + [
            { role: 'assistant', content: data['content'] },
            {
@@ -183,13 +220,29 @@ module RubyLLM
            }
          ]

-
-
+          final_response = create_chat_completion(
+            payload.merge(messages: new_messages),
+            tools,
+            &block
+          )
+
+          [tool_message, result_message] + final_response
+        else
+          token_usage = if data['usage']
+                          {
+                            input_tokens: data['usage']['input_tokens'],
+                            output_tokens: data['usage']['output_tokens'],
+                            total_tokens: data['usage']['input_tokens'] + data['usage']['output_tokens']
+                          }
+                        end

-
-
-
-
+          [Message.new(
+            role: :assistant,
+            content: text_content,
+            token_usage: token_usage,
+            model_id: data['model']
+          )]
+        end
      end

      def handle_tool_calls(tool_calls, tools)
@@ -218,6 +271,36 @@ module RubyLLM
        end.compact
      end

+      def handle_api_error(error)
+        response_body = error.response[:body]
+        if response_body.is_a?(String)
+          begin
+            error_data = JSON.parse(response_body)
+            message = error_data.dig('error', 'message')
+            raise RubyLLM::Error, "API error: #{message}" if message
+          rescue JSON::ParserError
+            raise RubyLLM::Error, "API error: #{error.response[:status]}"
+          end
+        elsif response_body['error']
+          raise RubyLLM::Error, "API error: #{response_body['error']['message']}"
+        else
+          raise RubyLLM::Error, "API error: #{error.response[:status]}"
+        end
+      end
+
+      def handle_error(error)
+        case error
+        when Faraday::TimeoutError
+          raise RubyLLM::Error, 'Request timed out'
+        when Faraday::ConnectionFailed
+          raise RubyLLM::Error, 'Connection failed'
+        when Faraday::ClientError
+          handle_api_error(error)
+        else
+          raise error
+        end
+      end
+
      def api_base
        'https://api.anthropic.com'
      end
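For reference, a sketch of what the new tool_to_anthropic cleaning produces for a hypothetical tool whose parameters mark required fields inline (the fields below are invented for illustration):

  parameters = {
    location: { type: 'string', description: 'City name', required: true },
    unit:     { type: 'string', description: 'celsius or fahrenheit' }
  }

  # After the loop in tool_to_anthropic:
  #   required_fields    == ['location']
  #   cleaned_properties == {
  #     location: { type: 'string', description: 'City name' },
  #     unit:     { type: 'string', description: 'celsius or fahrenheit' }
  #   }
  # The :required flag is lifted out of each property and into the
  # top-level `required` array, matching Anthropic's input_schema shape.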
data/lib/ruby_llm/providers/base.rb
CHANGED
@@ -16,6 +16,13 @@ module RubyLLM

       protected

+      def check_for_api_error(response)
+        return unless response.body.is_a?(Hash) && response.body['type'] == 'error'
+
+        error_msg = response.body.dig('error', 'message') || 'Unknown API error'
+        raise RubyLLM::Error, "API error: #{error_msg}"
+      end
+
       def build_connection
         Faraday.new(url: api_base) do |f|
           f.options.timeout = RubyLLM.configuration.request_timeout
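check_for_api_error catches bodies that declare themselves errors via a top-level "type": "error" field, regardless of HTTP status. A sketch of a triggering body (field values illustrative):

  response_body = {
    'type'  => 'error',
    'error' => {
      'type'    => 'invalid_request_error', # illustrative
      'message' => 'max_tokens: field required'
    }
  }

  # check_for_api_error(response) would raise:
  #   RubyLLM::Error: API error: max_tokens: field required
  # A body missing error.message falls back to 'Unknown API error'.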
data/lib/ruby_llm/providers/openai.rb
CHANGED
@@ -22,7 +22,7 @@ module RubyLLM
        if stream && block_given?
          stream_chat_completion(payload, tools, &block)
        else
-          create_chat_completion(payload, tools)
+          create_chat_completion(payload, tools, &block)
        end
      rescue Faraday::TimeoutError
        raise RubyLLM::Error, 'Request timed out'
@@ -80,7 +80,7 @@ module RubyLLM
        }
      end

-      def create_chat_completion(payload, tools = nil)
+      def create_chat_completion(payload, tools = nil, &block)
        response = connection.post('/v1/chat/completions') do |req|
          req.headers['Authorization'] = "Bearer #{RubyLLM.configuration.openai_api_key}"
          req.headers['Content-Type'] = 'application/json'
@@ -90,36 +90,74 @@ module RubyLLM
        puts 'Response from OpenAI:' if ENV['RUBY_LLM_DEBUG']
        puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']

+        # Check for API errors
+        check_for_api_error(response)
+
+        # Check for HTTP errors
        if response.status >= 400
          error_msg = response.body['error']&.fetch('message', nil) || "HTTP #{response.status}"
          raise RubyLLM::Error, "API error: #{error_msg}"
        end

-        handle_response(response, tools, payload)
+        handle_response(response, tools, payload, &block)
      end

-      def handle_response(response, tools, payload)
+      def handle_response(response, tools, payload, &block)
        data = response.body
        message_data = data.dig('choices', 0, 'message')
-        return
+        return [] unless message_data

        if message_data['function_call'] && tools
+          # Create function call message
+          function_message = Message.new(
+            role: :assistant,
+            content: message_data['content'],
+            tool_calls: [message_data['function_call']]
+          )
+          yield function_message if block_given?
+
+          # Execute function and create result message
          result = handle_function_call(message_data['function_call'], tools)
-
+          result_message = Message.new(
+            role: :tool,
+            content: result,
+            tool_results: {
+              name: message_data['function_call']['name'],
+              content: result
+            }
+          )
+          yield result_message if block_given?

-          #
+          # Get final response with function results
          new_messages = payload[:messages] + [
            { role: 'assistant', content: message_data['content'], function_call: message_data['function_call'] },
            { role: 'function', name: message_data['function_call']['name'], content: result }
          ]

-
-
+          final_response = create_chat_completion(
+            payload.merge(messages: new_messages),
+            tools,
+            &block
+          )

-
-
-
-
+          # Return all messages in sequence
+          [function_message, result_message] + final_response
+        else
+          token_usage = if data['usage']
+                          {
+                            input_tokens: data['usage']['prompt_tokens'],
+                            output_tokens: data['usage']['completion_tokens'],
+                            total_tokens: data['usage']['total_tokens']
+                          }
+                        end
+
+          [Message.new(
+            role: :assistant,
+            content: message_data['content'],
+            token_usage: token_usage,
+            model_id: data['model']
+          )]
+        end
      end

      def handle_function_call(function_call, tools)
@@ -153,6 +191,23 @@ module RubyLLM
        end
      end

+      def handle_api_error(error)
+        response_body = error.response[:body]
+        if response_body.is_a?(String)
+          begin
+            error_data = JSON.parse(response_body)
+            message = error_data.dig('error', 'message')
+            raise RubyLLM::Error, "API error: #{message}" if message
+          rescue JSON::ParserError
+            raise RubyLLM::Error, "API error: #{error.response[:status]}"
+          end
+        elsif response_body['error']
+          raise RubyLLM::Error, "API error: #{response_body['error']['message']}"
+        else
+          raise RubyLLM::Error, "API error: #{error.response[:status]}"
+        end
+      end
+
      def api_base
        'https://api.openai.com'
      end
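Note the key mapping in the else branch: OpenAI reports usage as prompt_tokens/completion_tokens/total_tokens, and the provider renames these so Message#token_usage carries the same keys as the Anthropic provider. A sketch of the normalization (payload values illustrative):

  usage = { 'prompt_tokens' => 9, 'completion_tokens' => 12, 'total_tokens' => 21 }

  token_usage = {
    input_tokens:  usage['prompt_tokens'],
    output_tokens: usage['completion_tokens'],
    total_tokens:  usage['total_tokens']
  }
  # => { input_tokens: 9, output_tokens: 12, total_tokens: 21 }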
data/lib/ruby_llm/version.rb
CHANGED