ruby_llm 0.1.0.pre4 → 0.1.0.pre6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Models
+    module_function
+
+    def provider_for(model)
+      Provider.for(model)
+    end
+
+    def all
+      @all ||= begin
+        data = JSON.parse(File.read(File.expand_path('models.json', __dir__)))
+        data['models'].map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
+      end
+    rescue Errno::ENOENT
+      [] # Return empty array if file doesn't exist yet
+    end
+
+    def find(model_id)
+      all.find { |m| m.id == model_id } or raise Error, "Unknown model: #{model_id}"
+    end
+
+    def chat_models
+      all.select { |m| m.type == 'chat' }
+    end
+
+    def embedding_models
+      all.select { |m| m.type == 'embedding' }
+    end
+
+    def audio_models
+      all.select { |m| m.type == 'audio' }
+    end
+
+    def image_models
+      all.select { |m| m.type == 'image' }
+    end
+
+    def by_family(family)
+      all.select { |m| m.family == family }
+    end
+
+    def default_model
+      'gpt-4o-mini'
+    end
+
+    def refresh!
+      @all = nil
+    end
+  end
+end
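
The RubyLLM::Models module added above is a registry: it loads model metadata from a bundled models.json, memoizes the parsed list, and exposes filters over it. A minimal usage sketch, assuming the gem ships a populated models.json (the lookups shown are illustrative):

    require 'ruby_llm'

    # Models.find raises RubyLLM::Error for unknown ids.
    model = RubyLLM::Models.find('gpt-4o-mini')

    # Filter the memoized registry by type or family.
    chat_ids = RubyLLM::Models.chat_models.map(&:id)
    siblings = RubyLLM::Models.by_family(model.family)

    # Drop the cache so the next call re-reads models.json.
    RubyLLM::Models.refresh!
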
@@ -0,0 +1,99 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Provider
+    def self.included(base)
+      base.include(InstanceMethods)
+    end
+
+    module InstanceMethods
+      def complete(messages, tools: [], model: nil, &block)
+        # TODO: refactor
+        payload = build_payload(messages, tools, model: model, stream: block_given?)
+
+        content = String.new
+        model_id = nil
+        input_tokens = 0
+        output_tokens = 0
+        response = connection.post(completion_url, payload) do |req|
+          req.headers.merge! headers
+          if block_given?
+            req.options.on_data = handle_stream do |chunk|
+              model_id ||= chunk.model_id
+              content << (chunk.content || '')
+              input_tokens += chunk.input_tokens if chunk.input_tokens
+              output_tokens += chunk.output_tokens if chunk.output_tokens
+              block.call(chunk)
+            end
+          end
+        end
+
+        if block_given?
+          Message.new(
+            role: :assistant,
+            content: content,
+            model_id: model_id,
+            input_tokens: input_tokens.positive? ? input_tokens : nil,
+            output_tokens: output_tokens.positive? ? output_tokens : nil
+          )
+        else
+          parse_completion_response(response)
+        end
+      end
+
+      def list_models
+        response = connection.get(models_url) do |req|
+          req.headers.merge!(headers)
+        end
+
+        parse_list_models_response(response)
+      end
+
+      private
+
+      def connection
+        @connection ||= Faraday.new(api_base) do |f|
+          f.options.timeout = RubyLLM.config.request_timeout
+          f.request :json
+          f.response :json
+          f.adapter Faraday.default_adapter
+          f.use Faraday::Response::RaiseError
+          f.response :logger, RubyLLM.logger, { headers: false, bodies: true, errors: true, log_level: :debug }
+        end
+      end
+
+      def to_json_stream(&block)
+        parser = EventStreamParser::Parser.new
+        proc do |chunk, _bytes, _|
+          parser.feed(chunk) do |_type, data|
+            unless data == '[DONE]'
+              parsed_data = JSON.parse(data)
+              RubyLLM.logger.debug "chunk: #{parsed_data}"
+              block.call(parsed_data)
+            end
+          end
+        end
+      end
+    end
+
+    class << self
+      def register(name, provider_class)
+        providers[name.to_sym] = provider_class
+      end
+
+      def for(model)
+        model_info = Models.find(model)
+        provider_class = providers[model_info.provider.to_sym] or
+          raise Error, "No provider registered for #{model_info.provider}"
+
+        provider_class.new
+      end
+
+      private
+
+      def providers
+        @providers ||= {}
+      end
+    end
+  end
+end
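
The new Provider module centralizes the HTTP plumbing that each provider previously duplicated: complete builds a payload, POSTs it over a shared Faraday connection, and either parses the full response or folds streamed chunks into a single Message, while the class-level registry maps provider names to classes. A sketch of the contract, using a hypothetical EchoProvider (not part of the gem) to show the hooks a concrete provider must supply:

    require 'ruby_llm'

    # Hypothetical provider illustrating the methods Provider expects.
    class EchoProvider
      include RubyLLM::Provider

      private

      def api_base
        'https://example.invalid'
      end

      def headers
        {}
      end

      def completion_url
        '/v1/chat'
      end

      def models_url
        '/v1/models'
      end

      # Called by InstanceMethods#complete before the POST.
      def build_payload(messages, tools, model:, stream: false)
        { model: model, messages: messages, stream: stream }
      end

      # Called by InstanceMethods#complete for non-streaming requests.
      def parse_completion_response(response)
        RubyLLM::Message.new(role: :assistant, content: response.body.to_s)
      end
    end

    # Provider.for(model_id) dispatches to whatever class registered the
    # provider name recorded for that model in models.json.
    RubyLLM::Provider.register(:echo, EchoProvider)
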
@@ -1,51 +1,82 @@
 # frozen_string_literal: true
 
-require 'time'
-
 module RubyLLM
   module Providers
-    class Anthropic < Base
-      def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
-        payload = {
-          model: model || 'claude-3-5-sonnet-20241022',
+    class Anthropic
+      include Provider
+
+      private
+
+      def api_base
+        'https://api.anthropic.com'
+      end
+
+      def headers
+        {
+          'x-api-key' => RubyLLM.config.anthropic_api_key,
+          'anthropic-version' => '2023-06-01'
+        }
+      end
+
+      def completion_url
+        '/v1/messages'
+      end
+
+      def models_url
+        '/v1/models'
+      end
+
+      def build_payload(messages, tools, model:, temperature: 0.7, stream: false)
+        {
+          model: model,
           messages: format_messages(messages),
           temperature: temperature,
           stream: stream,
-          max_tokens: 4096
-        }
+          max_tokens: RubyLLM.models.find(model).max_tokens
+        }.tap do |payload|
+          payload[:tools] = tools.map { |t| function_for(t) } if tools.any?
+        end
+      end
 
-        payload[:tools] = tools.map { |tool| tool_to_anthropic(tool) } if tools&.any?
+      def parse_completion_response(response)
+        data = response.body
+        content_blocks = data['content'] || []
 
-        puts 'Sending payload to Anthropic:' if ENV['RUBY_LLM_DEBUG']
-        puts JSON.pretty_generate(payload) if ENV['RUBY_LLM_DEBUG']
+        text_content = content_blocks.find { |c| c['type'] == 'text' }&.fetch('text', '')
+        tool_use = content_blocks.find { |c| c['type'] == 'tool_use' }
 
-        if stream && block_given?
-          stream_chat_completion(payload, tools, &block)
+        if tool_use
+          Message.new(
+            role: :assistant,
+            content: text_content,
+            tool_calls: [
+              {
+                name: tool_use['name'],
+                arguments: JSON.generate(tool_use['input'] || {})
+              }
+            ]
+          )
         else
-          create_chat_completion(payload, tools)
+          Message.new(
+            role: :assistant,
+            content: text_content,
+            input_tokens: data['usage']['input_tokens'],
+            output_tokens: data['usage']['output_tokens'],
+            model_id: data['model']
+          )
         end
       end
 
-      def list_models
-        response = @connection.get('/v1/models') do |req|
-          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
-          req.headers['anthropic-version'] = '2023-06-01'
-        end
-
-        raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400
-
-        capabilities = RubyLLM::ModelCapabilities::Anthropic.new
-        models_data = response.body['data'] || []
+      def parse_models_response(response)
+        capabilities = ModelCapabilities::Anthropic.new
 
-        models_data.map do |model|
+        (response.body['data'] || []).map do |model|
           ModelInfo.new(
             id: model['id'],
             created_at: Time.parse(model['created_at']),
             display_name: model['display_name'],
             provider: 'anthropic',
-            metadata: {
-              type: model['type']
-            },
+            metadata: { type: model['type'] },
             context_window: capabilities.determine_context_window(model['id']),
             max_tokens: capabilities.determine_max_tokens(model['id']),
             supports_vision: capabilities.supports_vision?(model['id']),
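
In the rewrite above, the Anthropic class keeps only provider-specific knowledge: endpoints, headers, payload shape, and response parsing; the HTTP loop now lives in Provider. One behavioral change worth noting is that max_tokens is no longer pinned to 4096 but read from the model registry, so payload construction now depends on models.json knowing the model. Roughly:

    # pre4: same limit for every request, regardless of model
    payload[:max_tokens] = 4096

    # pre6: per-model limit from the bundled registry (id illustrative)
    payload[:max_tokens] = RubyLLM.models.find('claude-3-5-sonnet-20241022').max_tokens
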
@@ -55,254 +86,72 @@ module RubyLLM
             output_price_per_million: capabilities.get_output_price(model['id'])
           )
         end
-      rescue Faraday::Error => e
-        handle_error(e)
       end
 
-      private
-
-      def tool_to_anthropic(tool)
-        # Get required fields and clean properties
-        required_fields = []
-        cleaned_properties = {}
-
-        tool.parameters.each do |name, props|
-          required_fields << name.to_s if props[:required]
-          cleaned_props = props.dup
-          cleaned_props.delete(:required)
-          cleaned_properties[name] = cleaned_props
+      def handle_stream(&block)
+        to_json_stream do |data|
+          block.call(
+            Chunk.new(
+              role: :assistant,
+              model_id: data.dig('message', 'model'),
+              content: data.dig('delta', 'text'),
+              input_tokens: data.dig('message', 'usage', 'input_tokens'),
+              output_tokens: data.dig('message', 'usage', 'output_tokens') || data.dig('usage', 'output_tokens')
+            )
+          )
         end
+      end
 
+      def function_for(tool)
         {
           name: tool.name,
           description: tool.description,
           input_schema: {
             type: 'object',
-            properties: cleaned_properties,
-            required: required_fields
+            properties: clean_parameters(tool.parameters),
+            required: required_parameters(tool.parameters)
           }
         }
       end
 
       def format_messages(messages)
         messages.map do |msg|
-          message = { role: msg.role == :user ? 'user' : 'assistant' }
-
-          message[:content] = if msg.tool_results
-                                [
-                                  {
-                                    type: 'tool_result',
-                                    tool_use_id: msg.tool_results[:tool_use_id],
-                                    content: msg.tool_results[:content],
-                                    is_error: msg.tool_results[:is_error]
-                                  }.compact
-                                ]
-                              else
-                                msg.content
-                              end
-
-          message
-        end
-      end
-
-      def create_chat_completion(payload, tools = nil, &block)
-        response = @connection.post('/v1/messages') do |req|
-          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
-          req.headers['anthropic-version'] = '2023-06-01'
-          req.headers['Content-Type'] = 'application/json'
-          req.body = payload
-        end
-
-        puts 'Response from Anthropic:' if ENV['RUBY_LLM_DEBUG']
-        puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']
-
-        # Check for API errors first
-        check_for_api_error(response)
-
-        handle_response(response, tools, payload, &block)
-      rescue Faraday::Error => e
-        handle_error(e)
-      end
-
-      def stream_chat_completion(payload, tools = nil)
-        response = @connection.post('/v1/messages') do |req|
-          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
-          req.headers['anthropic-version'] = '2023-06-01'
-          req.body = payload
-        end
-
-        messages = []
-        response.body.each_line do |line|
-          next if line.strip.empty?
-          next if line == 'data: [DONE]'
-
-          begin
-            data = JSON.parse(line.sub(/^data: /, ''))
-
-            message = case data['type']
-                      when 'content_block_delta'
-                        Message.new(role: :assistant, content: data['delta']['text']) if data['delta']['text']
-                      when 'tool_call'
-                        handle_tool_calls(data['tool_calls'], tools) do |result|
-                          Message.new(role: :assistant, content: result)
-                        end
-                      end
-
-            if message
-              messages << message
-              yield message if block_given?
-            end
-          rescue JSON::ParserError
-            next
-          end
-        end
-
-        messages
-      rescue Faraday::Error => e
-        handle_error(e)
-      end
-
-      def handle_response(response, tools, payload, &block)
-        data = response.body
-
-        content_parts = data['content'] || []
-        text_content = content_parts.find { |c| c['type'] == 'text' }&.fetch('text', '')
-        tool_use = content_parts.find { |c| c['type'] == 'tool_use' }
-
-        if tool_use && tools
-          # Tool call handling code...
-          tool_message = Message.new(
-            role: :assistant,
-            content: text_content,
-            tool_calls: [{
-              name: tool_use['name'],
-              arguments: JSON.generate(tool_use['input'] || {})
-            }]
-          )
-          yield tool_message if block_given?
-
-          tool = tools.find { |t| t.name == tool_use['name'] }
-          result = if tool
-                     begin
-                       tool_result = tool.call(tool_use['input'] || {})
-                       {
-                         tool_use_id: tool_use['id'],
-                         content: tool_result.to_s
-                       }
-                     rescue StandardError => e
-                       {
-                         tool_use_id: tool_use['id'],
-                         content: "Error executing tool #{tool.name}: #{e.message}",
-                         is_error: true
-                       }
-                     end
-                   end
-
-          result_message = Message.new(
-            role: :tool,
-            content: result[:content],
-            tool_results: result
-          )
-          yield result_message if block_given?
-
-          new_messages = payload[:messages] + [
-            { role: 'assistant', content: data['content'] },
+          if msg.tool_results
             {
-              role: 'user',
+              role: convert_role(msg.role),
               content: [
                 {
                   type: 'tool_result',
-                  tool_use_id: result[:tool_use_id],
-                  content: result[:content],
-                  is_error: result[:is_error]
+                  tool_use_id: msg.tool_results[:tool_use_id],
+                  content: msg.tool_results[:content],
+                  is_error: msg.tool_results[:is_error]
                 }.compact
               ]
             }
-          ]
-
-          final_response = create_chat_completion(
-            payload.merge(messages: new_messages),
-            tools,
-            &block
-          )
-
-          [tool_message, result_message] + final_response
-        else
-          token_usage = if data['usage']
-                          {
-                            input_tokens: data['usage']['input_tokens'],
-                            output_tokens: data['usage']['output_tokens'],
-                            total_tokens: data['usage']['input_tokens'] + data['usage']['output_tokens']
-                          }
-                        end
-
-          [Message.new(
-            role: :assistant,
-            content: text_content,
-            token_usage: token_usage,
-            model_id: data['model']
-          )]
-        end
-      end
-
-      def handle_tool_calls(tool_calls, tools)
-        return [] unless tool_calls && tools
-
-        tool_calls.map do |tool_call|
-          tool = tools.find { |t| t.name == tool_call['name'] }
-          next unless tool
-
-          begin
-            args = JSON.parse(tool_call['arguments'])
-            result = tool.call(args)
-            puts "Tool result: #{result}" if ENV['RUBY_LLM_DEBUG']
-            {
-              tool_use_id: tool_call['id'],
-              content: result.to_s
-            }
-          rescue JSON::ParserError, ArgumentError => e
-            puts "Error executing tool: #{e.message}" if ENV['RUBY_LLM_DEBUG']
+          else
             {
-              tool_use_id: tool_call['id'],
-              content: "Error executing tool #{tool.name}: #{e.message}",
-              is_error: true
+              role: convert_role(msg.role),
+              content: msg.content
             }
           end
-        end.compact
+        end
       end
 
-      def handle_api_error(error)
-        response_body = error.response[:body]
-        if response_body.is_a?(String)
-          begin
-            error_data = JSON.parse(response_body)
-            message = error_data.dig('error', 'message')
-            raise RubyLLM::Error, "API error: #{message}" if message
-          rescue JSON::ParserError
-            raise RubyLLM::Error, "API error: #{error.response[:status]}"
-          end
-        elsif response_body['error']
-          raise RubyLLM::Error, "API error: #{response_body['error']['message']}"
-        else
-          raise RubyLLM::Error, "API error: #{error.response[:status]}"
+      def convert_role(role)
+        case role
+        when :user then 'user'
+        else 'assistant'
         end
       end
 
-      def handle_error(error)
-        case error
-        when Faraday::TimeoutError
-          raise RubyLLM::Error, 'Request timed out'
-        when Faraday::ConnectionFailed
-          raise RubyLLM::Error, 'Connection failed'
-        when Faraday::ClientError
-          handle_api_error(error)
-        else
-          raise error
+      def clean_parameters(parameters)
+        parameters.transform_values do |props|
+          props.except(:required)
         end
       end
 
-      def api_base
-        'https://api.anthropic.com'
+      def required_parameters(parameters)
+        parameters.select { |_, props| props[:required] }.keys
       end
     end
   end
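
End to end, the pre6 Anthropic provider no longer performs its own error handling, debug printing, or SSE parsing; Provider#complete drives the request and delegates to the hooks above. A usage sketch, assuming the Anthropic class registers itself under :anthropic (registration is not shown in this diff) and that the gem exposes a configure block for the RubyLLM.config reads seen above:

    require 'ruby_llm'

    # Assumed configuration surface; the diff only shows RubyLLM.config reads.
    RubyLLM.configure do |config|
      config.anthropic_api_key = ENV['ANTHROPIC_API_KEY']
    end

    model    = 'claude-3-5-sonnet-20241022'
    provider = RubyLLM::Provider.for(model)
    messages = [RubyLLM::Message.new(role: :user, content: 'Hello!')]

    # Blocking: POSTs to /v1/messages and returns a parsed Message.
    reply = provider.complete(messages, model: model)
    puts reply.content

    # Streaming: handle_stream yields Chunk objects as SSE data arrives.
    provider.complete(messages, model: model) do |chunk|
      print chunk.content
    end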