ruby_llm 0.1.0.pre3 → 0.1.0.pre5

This diff compares the contents of publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Models
+    module_function
+
+    def provider_for(model)
+      Provider.for(model)
+    end
+
+    def all
+      @all ||= begin
+        data = JSON.parse(File.read(File.expand_path('models.json', __dir__)))
+        data['models'].map { |model| ModelInfo.new(model.transform_keys(&:to_sym)) }
+      end
+    rescue Errno::ENOENT
+      [] # Return empty array if file doesn't exist yet
+    end
+
+    def find(model_id)
+      all.find { |m| m.id == model_id } or raise Error, "Unknown model: #{model_id}"
+    end
+
+    def chat_models
+      all.select { |m| m.type == 'chat' }
+    end
+
+    def embedding_models
+      all.select { |m| m.type == 'embedding' }
+    end
+
+    def audio_models
+      all.select { |m| m.type == 'audio' }
+    end
+
+    def image_models
+      all.select { |m| m.type == 'image' }
+    end
+
+    def by_family(family)
+      all.select { |m| m.family == family }
+    end
+
+    def default_model
+      'gpt-4o-mini'
+    end
+
+    def refresh!
+      @all = nil
+    end
+  end
+end
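
For orientation, the new Models module is a memoized in-memory registry read from the bundled models.json. A minimal usage sketch (the model ID below is illustrative; what actually resolves depends on the shipped models.json):

    # Look up a model and route to its provider.
    model = RubyLLM::Models.find('claude-3-5-sonnet-20241022')
    model.provider                            # => "anthropic"

    RubyLLM::Models.chat_models.map(&:id)     # only models with type == 'chat'
    RubyLLM::Models.refresh!                  # drop the memoized list; reloaded on next access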
@@ -0,0 +1,99 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Provider
+    def self.included(base)
+      base.include(InstanceMethods)
+    end
+
+    module InstanceMethods
+      def complete(messages, tools: [], model: nil, &block)
+        # TODO: refactor
+        payload = build_payload(messages, tools, model: model, stream: block_given?)
+
+        content = String.new
+        model_id = nil
+        input_tokens = 0
+        output_tokens = 0
+        response = connection.post(completion_url, payload) do |req|
+          req.headers.merge! headers
+          if block_given?
+            req.options.on_data = handle_stream do |chunk|
+              model_id ||= chunk.model_id
+              content << (chunk.content || '')
+              input_tokens += chunk.input_tokens if chunk.input_tokens
+              output_tokens += chunk.output_tokens if chunk.output_tokens
+              block.call(chunk)
+            end
+          end
+        end
+
+        if block_given?
+          Message.new(
+            role: :assistant,
+            content: content,
+            model_id: model_id,
+            input_tokens: input_tokens.positive? ? input_tokens : nil,
+            output_tokens: output_tokens.positive? ? output_tokens : nil
+          )
+        else
+          parse_completion_response(response)
+        end
+      end
+
+      def list_models
+        response = connection.get(models_url) do |req|
+          req.headers.merge!(headers)
+        end
+
+        parse_list_models_response(response)
+      end
+
+      private
+
+      def connection
+        @connection ||= Faraday.new(api_base) do |f|
+          f.options.timeout = RubyLLM.config.request_timeout
+          f.request :json
+          f.response :json
+          f.adapter Faraday.default_adapter
+          f.use Faraday::Response::RaiseError
+          f.response :logger, RubyLLM.logger, { headers: false, bodies: true, errors: true, log_level: :debug }
+        end
+      end
+
+      def to_json_stream(&block)
+        parser = EventStreamParser::Parser.new
+        proc do |chunk, _bytes, _|
+          parser.feed(chunk) do |_type, data|
+            unless data == '[DONE]'
+              parsed_data = JSON.parse(data)
+              RubyLLM.logger.debug "chunk: #{parsed_data}"
+              block.call(parsed_data)
+            end
+          end
+        end
+      end
+    end
+
+    class << self
+      def register(name, provider_class)
+        providers[name.to_sym] = provider_class
+      end
+
+      def for(model)
+        model_info = Models.find(model)
+        provider_class = providers[model_info.provider.to_sym] or
+          raise Error, "No provider registered for #{model_info.provider}"
+
+        provider_class.new
+      end
+
+      private
+
+      def providers
+        @providers ||= {}
+      end
+    end
+  end
+end
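
The Provider mixin above defines the seam each concrete provider fills in: api_base, headers, completion_url, models_url, build_payload, parse_completion_response, and handle_stream, plus Provider.register for lookup by name. A hypothetical provider would be wired up roughly like this (Echo, its endpoints, and its response shape are invented purely for illustration):

    module RubyLLM
      module Providers
        # Hypothetical provider sketch showing the hooks Provider expects.
        class Echo
          include Provider

          private

          def api_base
            'https://api.echo.example'
          end

          def headers
            { 'Authorization' => 'Bearer ...' }
          end

          def completion_url
            '/v1/chat'
          end

          def models_url
            '/v1/models'
          end

          def build_payload(messages, _tools, model:, stream: false)
            { model: model, messages: messages.map(&:to_h), stream: stream }
          end

          def parse_completion_response(response)
            Message.new(role: :assistant, content: response.body['text'])
          end

          def handle_stream(&block)
            to_json_stream do |data|
              block.call(Chunk.new(role: :assistant, content: data['text']))
            end
          end
        end
      end
    end

    # Registered under the name used in models.json's provider field.
    RubyLLM::Provider.register :echo, RubyLLM::Providers::Echo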
@@ -1,51 +1,82 @@
 # frozen_string_literal: true
 
-require 'time'
-
 module RubyLLM
   module Providers
-    class Anthropic < Base
-      def chat(messages, model: nil, temperature: 0.7, stream: false, tools: nil, &block)
-        payload = {
-          model: model || 'claude-3-5-sonnet-20241022',
-          messages: format_messages(messages),
-          temperature: temperature,
-          stream: stream,
-          max_tokens: 4096
+    class Anthropic
+      include Provider
+
+      private
+
+      def api_base
+        'https://api.anthropic.com'
+      end
+
+      def headers
+        {
+          'x-api-key' => RubyLLM.config.anthropic_api_key,
+          'anthropic-version' => '2023-06-01'
         }
+      end
 
-        payload[:tools] = tools.map { |tool| tool_to_anthropic(tool) } if tools&.any?
+      def completion_url
+        '/v1/messages'
+      end
 
-        puts 'Sending payload to Anthropic:' if ENV['RUBY_LLM_DEBUG']
-        puts JSON.pretty_generate(payload) if ENV['RUBY_LLM_DEBUG']
+      def models_url
+        '/v1/models'
+      end
 
-        if stream && block_given?
-          stream_chat_completion(payload, tools, &block)
-        else
-          create_chat_completion(payload, tools)
+      def build_payload(messages, tools, model:, temperature: 0.7, stream: false)
+        {
+          model: model,
+          messages: format_messages(messages),
+          temperature: temperature,
+          stream: stream,
+          max_tokens: RubyLLM.models.find(model).max_tokens
+        }.tap do |payload|
+          payload[:tools] = tools.map { |t| function_for(t) } if tools.any?
         end
       end
 
-      def list_models
-        response = @connection.get('/v1/models') do |req|
-          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
-          req.headers['anthropic-version'] = '2023-06-01'
+      def parse_completion_response(response)
+        data = response.body
+        content_blocks = data['content'] || []
+
+        text_content = content_blocks.find { |c| c['type'] == 'text' }&.fetch('text', '')
+        tool_use = content_blocks.find { |c| c['type'] == 'tool_use' }
+
+        if tool_use
+          Message.new(
+            role: :assistant,
+            content: text_content,
+            tool_calls: [
+              {
+                name: tool_use['name'],
+                arguments: JSON.generate(tool_use['input'] || {})
+              }
+            ]
+          )
+        else
+          Message.new(
+            role: :assistant,
+            content: text_content,
+            input_tokens: data['usage']['input_tokens'],
+            output_tokens: data['usage']['output_tokens'],
+            model_id: data['model']
+          )
         end
+      end
 
-        raise RubyLLM::Error, "API error: #{parse_error_message(response)}" if response.status >= 400
-
-        capabilities = RubyLLM::ModelCapabilities::Anthropic.new
-        models_data = response.body['data'] || []
+      def parse_models_response(response)
+        capabilities = ModelCapabilities::Anthropic.new
 
-        models_data.map do |model|
+        (response.body['data'] || []).map do |model|
           ModelInfo.new(
             id: model['id'],
             created_at: Time.parse(model['created_at']),
             display_name: model['display_name'],
             provider: 'anthropic',
-            metadata: {
-              type: model['type']
-            },
+            metadata: { type: model['type'] },
             context_window: capabilities.determine_context_window(model['id']),
             max_tokens: capabilities.determine_max_tokens(model['id']),
             supports_vision: capabilities.supports_vision?(model['id']),
@@ -55,199 +86,72 @@ module RubyLLM
             output_price_per_million: capabilities.get_output_price(model['id'])
           )
         end
-      rescue Faraday::Error => e
-        handle_error(e)
       end
 
-      private
+      def handle_stream(&block)
+        to_json_stream do |data|
+          block.call(
+            Chunk.new(
+              role: :assistant,
+              model_id: data.dig('message', 'model'),
+              content: data.dig('delta', 'text'),
+              input_tokens: data.dig('message', 'usage', 'input_tokens'),
+              output_tokens: data.dig('message', 'usage', 'output_tokens') || data.dig('usage', 'output_tokens')
+            )
+          )
+        end
+      end
 
-      def tool_to_anthropic(tool)
+      def function_for(tool)
         {
           name: tool.name,
           description: tool.description,
           input_schema: {
             type: 'object',
-            properties: tool.parameters,
-            required: tool.parameters.select { |_, v| v[:required] }.keys
+            properties: clean_parameters(tool.parameters),
+            required: required_parameters(tool.parameters)
           }
         }
       end
 
       def format_messages(messages)
         messages.map do |msg|
-          message = { role: msg.role == :user ? 'user' : 'assistant' }
-
-          message[:content] = if msg.tool_results
-                                [
-                                  {
-                                    type: 'tool_result',
-                                    tool_use_id: msg.tool_results[:tool_use_id],
-                                    content: msg.tool_results[:content],
-                                    is_error: msg.tool_results[:is_error]
-                                  }.compact
-                                ]
-                              else
-                                msg.content
-                              end
-
-          message
-        end
-      end
-
-      def create_chat_completion(payload, tools = nil)
-        response = @connection.post('/v1/messages') do |req|
-          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
-          req.headers['anthropic-version'] = '2023-06-01'
-          req.headers['Content-Type'] = 'application/json'
-          req.body = payload
-        end
-
-        puts 'Response from Anthropic:' if ENV['RUBY_LLM_DEBUG']
-        puts JSON.pretty_generate(response.body) if ENV['RUBY_LLM_DEBUG']
-
-        handle_response(response, tools, payload)
-      rescue Faraday::Error => e
-        handle_error(e)
-      end
-
-      def stream_chat_completion(payload, tools = nil)
-        response = @connection.post('/v1/messages') do |req|
-          req.headers['x-api-key'] = RubyLLM.configuration.anthropic_api_key
-          req.headers['anthropic-version'] = '2023-06-01'
-          req.body = payload
-        end
-
-        response.body.each_line do |line|
-          next if line.strip.empty?
-          next if line == 'data: [DONE]'
-
-          begin
-            data = JSON.parse(line.sub(/^data: /, ''))
-
-            if data['type'] == 'content_block_delta'
-              content = data['delta']['text']
-              yield Message.new(role: :assistant, content: content) if content
-            elsif data['type'] == 'tool_call'
-              handle_tool_calls(data['tool_calls'], tools) do |result|
-                yield Message.new(role: :assistant, content: result)
-              end
-            end
-          rescue JSON::ParserError
-            next
-          end
-        end
-      rescue Faraday::Error => e
-        handle_error(e)
-      end
-
-      def handle_response(response, tools, payload)
-        data = response.body
-        return Message.new(role: :assistant, content: '') if data['type'] == 'error'
-
-        # Extract text content and tool use from response
-        content_parts = data['content'] || []
-        text_content = content_parts.find { |c| c['type'] == 'text' }&.fetch('text', '')
-        tool_use = content_parts.find { |c| c['type'] == 'tool_use' }
-
-        if tool_use && tools
-          tool = tools.find { |t| t.name == tool_use['name'] }
-          result = if tool
-                     begin
-                       tool_result = tool.call(tool_use['input'] || {})
-                       {
-                         tool_use_id: tool_use['id'],
-                         content: tool_result.to_s
-                       }
-                     rescue StandardError => e
-                       {
-                         tool_use_id: tool_use['id'],
-                         content: "Error executing tool #{tool.name}: #{e.message}",
-                         is_error: true
-                       }
-                     end
-                   end
-
-          # Create a new message with the tool result
-          new_messages = payload[:messages] + [
-            { role: 'assistant', content: data['content'] },
+          if msg.tool_results
             {
-              role: 'user',
+              role: convert_role(msg.role),
               content: [
                 {
                   type: 'tool_result',
-                  tool_use_id: result[:tool_use_id],
-                  content: result[:content],
-                  is_error: result[:is_error]
+                  tool_use_id: msg.tool_results[:tool_use_id],
+                  content: msg.tool_results[:content],
+                  is_error: msg.tool_results[:is_error]
                 }.compact
               ]
             }
-          ]
-
-          return create_chat_completion(payload.merge(messages: new_messages), tools)
-        end
-
-        # Extract token usage from response
-        token_usage = if data['usage']
-                        {
-                          input_tokens: data['usage']['input_tokens'],
-                          output_tokens: data['usage']['output_tokens'],
-                          total_tokens: data['usage']['input_tokens'] + data['usage']['output_tokens']
-                        }
-                      end
-
-        Message.new(
-          role: :assistant,
-          content: text_content,
-          token_usage: token_usage,
-          model_id: data['model']
-        )
-      end
-
-      def handle_tool_calls(tool_calls, tools)
-        return [] unless tool_calls && tools
-
-        tool_calls.map do |tool_call|
-          tool = tools.find { |t| t.name == tool_call['name'] }
-          next unless tool
-
-          begin
-            args = JSON.parse(tool_call['arguments'])
-            result = tool.call(args)
-            puts "Tool result: #{result}" if ENV['RUBY_LLM_DEBUG']
-            {
-              tool_use_id: tool_call['id'],
-              content: result.to_s
-            }
-          rescue JSON::ParserError, ArgumentError => e
-            puts "Error executing tool: #{e.message}" if ENV['RUBY_LLM_DEBUG']
+          else
             {
-              tool_use_id: tool_call['id'],
-              content: "Error executing tool #{tool.name}: #{e.message}",
-              is_error: true
+              role: convert_role(msg.role),
+              content: msg.content
             }
           end
-        end.compact
+        end
       end
 
-      def handle_api_error(error)
-        response_body = error.response[:body]
-        if response_body.is_a?(String)
-          begin
-            error_data = JSON.parse(response_body)
-            message = error_data.dig('error', 'message')
-            raise RubyLLM::Error, "API error: #{message}" if message
-          rescue JSON::ParserError
-            raise RubyLLM::Error, "API error: #{error.response[:status]}"
-          end
-        elsif response_body.dig('error', 'type') == 'invalid_request_error'
-          raise RubyLLM::Error, "API error: #{response_body['error']['message']}"
-        else
-          raise RubyLLM::Error, "API error: #{error.response[:status]}"
+      def convert_role(role)
+        case role
+        when :user then 'user'
+        else 'assistant'
         end
       end
 
-      def api_base
-        'https://api.anthropic.com'
+      def clean_parameters(parameters)
+        parameters.transform_values do |props|
+          props.except(:required)
+        end
+      end
+
+      def required_parameters(parameters)
+        parameters.select { |_, props| props[:required] }.keys
       end
     end
   end
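
One subtle change worth illustrating: function_for now strips the gem's internal :required flag out of each parameter hash (clean_parameters) and collects the flagged keys into Anthropic's top-level required array (required_parameters). Given a tool declared with parameters shaped like this (the weather example is hypothetical; the Tool class itself is not part of this diff):

    parameters = {
      location: { type: 'string', description: 'City name', required: true },
      unit: { type: 'string', description: 'celsius or fahrenheit' }
    }

    # clean_parameters(parameters)
    # => { location: { type: 'string', description: 'City name' },
    #      unit: { type: 'string', description: 'celsius or fahrenheit' } }

    # required_parameters(parameters)
    # => [:location]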