rcrewai 0.2.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +4 -4
  2. data/.rubocop.yml +21 -0
  3. data/.rubocop_todo.yml +99 -0
  4. data/CHANGELOG.md +64 -1
  5. data/README.md +170 -2
  6. data/ROADMAP.md +84 -0
  7. data/Rakefile +53 -53
  8. data/bin/rcrewai +3 -3
  9. data/docs/mcp.md +109 -0
  10. data/docs/superpowers/plans/2026-05-11-llm-modernization.md +2753 -0
  11. data/docs/superpowers/specs/2026-05-11-llm-modernization-design.md +479 -0
  12. data/docs/upgrading-to-0.3.md +163 -0
  13. data/examples/async_execution_example.rb +82 -81
  14. data/examples/hierarchical_crew_example.rb +68 -72
  15. data/examples/human_in_the_loop_example.rb +73 -74
  16. data/examples/mcp_example.rb +48 -0
  17. data/examples/native_tools_example.rb +64 -0
  18. data/examples/streaming_example.rb +56 -0
  19. data/lib/rcrewai/agent.rb +181 -286
  20. data/lib/rcrewai/async_executor.rb +43 -43
  21. data/lib/rcrewai/cli.rb +11 -11
  22. data/lib/rcrewai/configuration.rb +34 -9
  23. data/lib/rcrewai/crew.rb +134 -39
  24. data/lib/rcrewai/events.rb +30 -0
  25. data/lib/rcrewai/flow/state.rb +47 -0
  26. data/lib/rcrewai/flow/state_store.rb +50 -0
  27. data/lib/rcrewai/flow.rb +243 -0
  28. data/lib/rcrewai/human_input.rb +104 -114
  29. data/lib/rcrewai/knowledge/base.rb +52 -0
  30. data/lib/rcrewai/knowledge/chunker.rb +31 -0
  31. data/lib/rcrewai/knowledge/embedder.rb +48 -0
  32. data/lib/rcrewai/knowledge/sources.rb +83 -0
  33. data/lib/rcrewai/knowledge/store.rb +58 -0
  34. data/lib/rcrewai/knowledge.rb +13 -0
  35. data/lib/rcrewai/legacy_react_runner.rb +172 -0
  36. data/lib/rcrewai/llm_client.rb +24 -1
  37. data/lib/rcrewai/llm_clients/anthropic.rb +174 -54
  38. data/lib/rcrewai/llm_clients/azure.rb +23 -128
  39. data/lib/rcrewai/llm_clients/base.rb +11 -7
  40. data/lib/rcrewai/llm_clients/google.rb +159 -95
  41. data/lib/rcrewai/llm_clients/ollama.rb +150 -106
  42. data/lib/rcrewai/llm_clients/openai.rb +140 -63
  43. data/lib/rcrewai/mcp/client.rb +101 -0
  44. data/lib/rcrewai/mcp/tool_adapter.rb +59 -0
  45. data/lib/rcrewai/mcp/transport/http.rb +53 -0
  46. data/lib/rcrewai/mcp/transport/stdio.rb +55 -0
  47. data/lib/rcrewai/mcp.rb +8 -0
  48. data/lib/rcrewai/memory.rb +45 -37
  49. data/lib/rcrewai/output_schema.rb +79 -0
  50. data/lib/rcrewai/planning.rb +65 -0
  51. data/lib/rcrewai/pricing.rb +34 -0
  52. data/lib/rcrewai/process.rb +86 -95
  53. data/lib/rcrewai/provider_schema.rb +38 -0
  54. data/lib/rcrewai/sse_parser.rb +55 -0
  55. data/lib/rcrewai/task.rb +145 -66
  56. data/lib/rcrewai/tool_runner.rb +132 -0
  57. data/lib/rcrewai/tool_schema.rb +97 -0
  58. data/lib/rcrewai/tools/base.rb +98 -37
  59. data/lib/rcrewai/tools/code_executor.rb +71 -74
  60. data/lib/rcrewai/tools/email_sender.rb +70 -78
  61. data/lib/rcrewai/tools/file_reader.rb +38 -30
  62. data/lib/rcrewai/tools/file_writer.rb +40 -38
  63. data/lib/rcrewai/tools/pdf_processor.rb +115 -130
  64. data/lib/rcrewai/tools/sql_database.rb +58 -55
  65. data/lib/rcrewai/tools/web_search.rb +26 -25
  66. data/lib/rcrewai/version.rb +2 -2
  67. data/lib/rcrewai.rb +20 -10
  68. data/rcrewai.gemspec +39 -39
  69. metadata +77 -47
@@ -1,18 +1,32 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'faraday'
4
+ require 'json'
3
5
  require_relative 'base'
6
+ require_relative '../events'
7
+ require_relative '../provider_schema'
8
+ require_relative '../pricing'
4
9
 
5
10
  module RCrewAI
6
11
  module LLMClients
7
12
  class Ollama < Base
8
13
  DEFAULT_URL = 'http://localhost:11434'
9
14
 
15
+ NATIVE_TOOL_MODELS = %w[
16
+ llama3.1 llama3.1:8b llama3.1:70b llama3.1:405b
17
+ llama3.2 llama3.2:1b llama3.2:3b
18
+ qwen2.5 qwen2.5:7b qwen2.5:14b qwen2.5:32b qwen2.5:72b
19
+ mistral-nemo mistral-large
20
+ command-r command-r-plus
21
+ firefunction-v2
22
+ ].freeze
23
+
10
24
  def initialize(config = RCrewAI.configuration)
11
25
  super
12
26
  @base_url = config.base_url || ollama_url || DEFAULT_URL
13
27
  end
14
28
 
15
- def chat(messages:, **options)
29
+ def chat(messages:, tools: nil, tool_choice: :auto, stream: nil, **options) # rubocop:disable Lint/UnusedMethodArgument
16
30
  payload = {
17
31
  model: config.model,
18
32
  messages: format_messages(messages),
@@ -24,147 +38,174 @@ module RCrewAI
24
38
  repeat_penalty: options[:repeat_penalty]
25
39
  }.compact
26
40
  }
27
-
28
- # Add stop sequences if provided
29
41
  payload[:options][:stop] = options[:stop] if options[:stop]
30
42
 
31
- url = "#{@base_url}/api/chat"
32
- log_request(:post, url, payload)
33
-
34
- response = http_client.post(url, payload, build_headers)
35
- log_response(response)
43
+ if tools && !tools.empty?
44
+ payload[:tools] = ProviderSchema.for_many(:ollama, tools)
45
+ end
36
46
 
37
- result = handle_response(response)
38
- format_response(result)
47
+ url = "#{@base_url}/api/chat"
48
+ if stream
49
+ payload[:stream] = true
50
+ stream_chat(url, payload, stream)
51
+ else
52
+ payload[:stream] = false
53
+ plain_chat(url, payload)
54
+ end
39
55
  end
40
56
 
41
- def complete(prompt:, **options)
42
- payload = {
43
- model: config.model,
44
- prompt: prompt,
45
- options: {
46
- temperature: options[:temperature] || config.temperature,
47
- num_predict: options[:max_tokens] || config.max_tokens,
48
- top_p: options[:top_p],
49
- top_k: options[:top_k],
50
- repeat_penalty: options[:repeat_penalty]
51
- }.compact
52
- }
53
-
54
- payload[:options][:stop] = options[:stop] if options[:stop]
55
-
56
- url = "#{@base_url}/api/generate"
57
- log_request(:post, url, payload)
58
-
59
- response = http_client.post(url, payload, build_headers)
60
- log_response(response)
57
+ def supports_native_tools?(model: config.model)
58
+ override = RCrewAI.configuration.respond_to?(:ollama_native_tools) ? RCrewAI.configuration.ollama_native_tools : nil
59
+ return override unless override.nil?
61
60
 
62
- result = handle_response(response)
63
- format_completion_response(result)
61
+ base = model.to_s.split(':').first
62
+ NATIVE_TOOL_MODELS.any? { |m| m == model || m.split(':').first == base }
64
63
  end
65
64
 
66
65
  def models
67
66
  url = "#{@base_url}/api/tags"
68
67
  response = http_client.get(url, {}, build_headers)
69
68
  result = handle_response(response)
70
-
71
- if result['models']
72
- result['models'].map { |model| model['name'] }
73
- else
74
- []
75
- end
76
- rescue => e
69
+ Array(result['models']).map { |m| m['name'] }
70
+ rescue StandardError => e
77
71
  logger.warn "Failed to fetch Ollama models: #{e.message}"
78
72
  []
79
73
  end
80
74
 
81
75
  def pull_model(model_name)
82
- payload = { name: model_name }
83
76
  url = "#{@base_url}/api/pull"
84
-
85
- response = http_client.post(url, payload, build_headers)
77
+ response = http_client.post(url, { name: model_name }, build_headers)
86
78
  handle_response(response)
87
79
  end
88
80
 
89
- def model_info(model_name = nil)
90
- model_name ||= config.model
91
- payload = { name: model_name }
92
- url = "#{@base_url}/api/show"
93
-
81
+ private
82
+
83
+ def plain_chat(url, payload)
84
+ log_request(:post, url, payload)
94
85
  response = http_client.post(url, payload, build_headers)
95
- handle_response(response)
96
- rescue => e
97
- logger.warn "Failed to get model info for #{model_name}: #{e.message}"
98
- nil
86
+ log_response(response)
87
+ body = handle_response(response)
88
+ normalize_non_streaming(body)
99
89
  end
100
90
 
101
- private
91
+ def stream_chat(url, payload, sink)
92
+ log_request(:post, url, payload)
102
93
 
103
- def format_messages(messages)
104
- messages.map do |msg|
105
- if msg.is_a?(Hash)
106
- {
107
- role: msg[:role],
108
- content: msg[:content]
109
- }
110
- else
111
- { role: 'user', content: msg.to_s }
94
+ assembled_text = +''
95
+ tool_calls = []
96
+ finish_reason = nil
97
+ prompt_tokens = nil
98
+ completion_tokens = nil
99
+ buffer = String.new(encoding: Encoding::UTF_8)
100
+
101
+ process_line = lambda do |line|
102
+ line = line.strip
103
+ return if line.empty?
104
+
105
+ data = JSON.parse(line)
106
+ if (msg = data['message'])
107
+ if msg['content']
108
+ assembled_text << msg['content']
109
+ sink.call(Events::TextDelta.new(type: :text_delta, timestamp: Time.now,
110
+ agent: nil, iteration: nil,
111
+ text: msg['content']))
112
+ end
113
+ Array(msg['tool_calls']).each do |tc|
114
+ fn = tc['function'] || {}
115
+ tool_calls << {
116
+ id: tc['id'],
117
+ name: fn['name'],
118
+ arguments: fn['arguments'].is_a?(String) ? JSON.parse(fn['arguments']) : (fn['arguments'] || {})
119
+ }
120
+ end
121
+ end
122
+ if data['done']
123
+ finish_reason = tool_calls.any? ? :tool_calls : :stop
124
+ prompt_tokens = data['prompt_eval_count']
125
+ completion_tokens = data['eval_count']
112
126
  end
113
127
  end
114
- end
115
-
116
- def format_response(response)
117
- message = response['message']
118
- return nil unless message
119
128
 
120
- # Ollama doesn't provide detailed usage stats by default
121
- usage = {
122
- 'prompt_tokens' => response['prompt_eval_count'],
123
- 'completion_tokens' => response['eval_count'],
124
- 'total_tokens' => (response['prompt_eval_count'] || 0) + (response['eval_count'] || 0)
125
- }.compact
129
+ streaming_post(url, payload) do |chunk|
130
+ chunk = chunk.dup.force_encoding(Encoding::UTF_8) unless chunk.encoding == Encoding::UTF_8
131
+ buffer << chunk
132
+ while (idx = buffer.index("\n"))
133
+ line = buffer.slice!(0, idx + 1)
134
+ process_line.call(line)
135
+ end
136
+ end
137
+ process_line.call(buffer) unless buffer.empty?
138
+
139
+ if prompt_tokens || completion_tokens
140
+ sink.call(Events::Usage.new(
141
+ type: :usage, timestamp: Time.now, agent: nil, iteration: nil,
142
+ prompt_tokens: prompt_tokens, completion_tokens: completion_tokens,
143
+ total_tokens: (prompt_tokens || 0) + (completion_tokens || 0),
144
+ cost_usd: nil
145
+ ))
146
+ end
126
147
 
127
148
  {
128
- content: message['content'],
129
- role: message['role'] || 'assistant',
130
- finish_reason: response['done'] ? 'stop' : nil,
131
- usage: usage,
132
- model: response['model'] || config.model,
149
+ content: assembled_text.empty? ? nil : assembled_text,
150
+ tool_calls: tool_calls,
151
+ usage: {
152
+ prompt_tokens: prompt_tokens,
153
+ completion_tokens: completion_tokens,
154
+ total_tokens: (prompt_tokens || 0) + (completion_tokens || 0)
155
+ },
156
+ finish_reason: finish_reason || :stop,
157
+ model: config.model,
133
158
  provider: :ollama
134
159
  }
135
160
  end
136
161
 
137
- def format_completion_response(response)
162
+ def streaming_post(url, payload, &on_chunk)
163
+ conn = Faraday.new do |f|
164
+ f.request :json
165
+ f.options.timeout = config.timeout
166
+ f.adapter Faraday.default_adapter
167
+ end
168
+ conn.post(url) do |req|
169
+ req.headers = build_headers
170
+ req.body = payload.to_json
171
+ req.options.on_data = proc { |chunk, _| on_chunk.call(chunk) }
172
+ end
173
+ end
174
+
175
+ def normalize_non_streaming(body)
176
+ msg = body['message'] || {}
177
+ text = msg['content']
178
+ tool_calls = Array(msg['tool_calls']).map do |tc|
179
+ fn = tc['function'] || {}
180
+ args = fn['arguments']
181
+ args = JSON.parse(args) if args.is_a?(String)
182
+ { id: tc['id'], name: fn['name'], arguments: args || {} }
183
+ end
184
+ prompt_tokens = body['prompt_eval_count']
185
+ completion_tokens = body['eval_count']
186
+
138
187
  {
139
- content: response['response'],
140
- finish_reason: response['done'] ? 'stop' : nil,
188
+ content: text && !text.empty? ? text : nil,
189
+ tool_calls: tool_calls,
141
190
  usage: {
142
- 'prompt_tokens' => response['prompt_eval_count'],
143
- 'completion_tokens' => response['eval_count'],
144
- 'total_tokens' => (response['prompt_eval_count'] || 0) + (response['eval_count'] || 0)
145
- }.compact,
146
- model: response['model'] || config.model,
191
+ prompt_tokens: prompt_tokens,
192
+ completion_tokens: completion_tokens,
193
+ total_tokens: (prompt_tokens || 0) + (completion_tokens || 0)
194
+ },
195
+ finish_reason: tool_calls.any? ? :tool_calls : :stop,
196
+ model: body['model'] || config.model,
147
197
  provider: :ollama
148
198
  }
149
199
  end
150
200
 
151
- def validate_config!
152
- # Ollama doesn't require an API key
153
- raise ConfigurationError, "Model is required" unless config.model
154
-
155
- # Test connection to Ollama server
156
- test_connection
157
- end
158
-
159
- def test_connection
160
- url = "#{@base_url}/api/tags"
161
- response = http_client.get(url, {}, build_headers)
162
-
163
- unless (200..299).include?(response.status)
164
- raise ConfigurationError, "Cannot connect to Ollama server at #{@base_url}"
201
+ def format_messages(messages)
202
+ messages.map do |msg|
203
+ if msg.is_a?(Hash)
204
+ { role: msg[:role], content: msg[:content] }
205
+ else
206
+ { role: 'user', content: msg.to_s }
207
+ end
165
208
  end
166
- rescue Faraday::ConnectionFailed
167
- raise ConfigurationError, "Cannot connect to Ollama server at #{@base_url}. Is Ollama running?"
168
209
  end
169
210
 
170
211
  def ollama_url
@@ -172,22 +213,25 @@ module RCrewAI
172
213
  end
173
214
 
174
215
  def build_headers
175
- # Ollama doesn't require special headers
176
216
  {
177
217
  'Content-Type' => 'application/json',
178
218
  'User-Agent' => "rcrewai/#{RCrewAI::VERSION}"
179
219
  }
180
220
  end
181
221
 
222
+ def validate_config!
223
+ raise ConfigurationError, 'Model is required' unless config.model
224
+ end
225
+
182
226
  def handle_response(response)
183
227
  case response.status
184
228
  when 200..299
185
229
  response.body
186
230
  when 400
187
- error_details = response.body['error'] || response.body
188
- raise APIError, "Bad request: #{error_details}"
231
+ details = response.body.is_a?(Hash) ? response.body['error'] : response.body
232
+ raise APIError, "Bad request: #{details}"
189
233
  when 404
190
- raise ModelNotFoundError, "Model '#{config.model}' not found. Try running: ollama pull #{config.model}"
234
+ raise ModelNotFoundError, "Model '#{config.model}' not found. Try: ollama pull #{config.model}"
191
235
  when 500..599
192
236
  raise APIError, "Ollama server error: #{response.status}"
193
237
  else
@@ -196,4 +240,4 @@ module RCrewAI
196
240
  end
197
241
  end
198
242
  end
199
- end
243
+ end
@@ -1,6 +1,12 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'faraday'
4
+ require 'json'
3
5
  require_relative 'base'
6
+ require_relative '../events'
7
+ require_relative '../sse_parser'
8
+ require_relative '../provider_schema'
9
+ require_relative '../pricing'
4
10
 
5
11
  module RCrewAI
6
12
  module LLMClients
@@ -12,113 +18,184 @@ module RCrewAI
12
18
  @base_url = BASE_URL
13
19
  end
14
20
 
15
- def chat(messages:, **options)
21
+ def chat(messages:, tools: nil, tool_choice: :auto, stream: nil, **options)
16
22
  payload = {
17
23
  model: config.model,
18
- messages: format_messages(messages),
24
+ messages: messages,
19
25
  temperature: options[:temperature] || config.temperature,
20
26
  max_tokens: options[:max_tokens] || config.max_tokens
21
- }
27
+ }.compact
22
28
 
23
- # Add additional OpenAI-specific options
24
29
  payload[:top_p] = options[:top_p] if options[:top_p]
25
30
  payload[:frequency_penalty] = options[:frequency_penalty] if options[:frequency_penalty]
26
31
  payload[:presence_penalty] = options[:presence_penalty] if options[:presence_penalty]
27
32
  payload[:stop] = options[:stop] if options[:stop]
28
33
 
29
- url = "#{@base_url}/chat/completions"
30
- log_request(:post, url, payload)
31
-
32
- response = http_client.post(url, payload, build_headers.merge(authorization_header))
33
- log_response(response)
34
-
35
- result = handle_response(response)
36
- format_response(result)
37
- end
34
+ if tools && !tools.empty?
35
+ payload[:tools] = ProviderSchema.for_many(:openai, tools)
36
+ payload[:tool_choice] = tool_choice if tool_choice != :auto
37
+ end
38
38
 
39
- def complete(prompt:, **options)
40
- # For older models that use completions endpoint
41
- if config.model.include?('davinci') || config.model.include?('curie') ||
42
- config.model.include?('babbage') || config.model.include?('ada')
43
- completion_request(prompt, **options)
39
+ if stream
40
+ payload[:stream] = true
41
+ payload[:stream_options] = { include_usage: true }
42
+ stream_chat(payload, stream)
44
43
  else
45
- # Use chat endpoint for newer models
46
- super
44
+ plain_chat(payload)
47
45
  end
48
46
  end
49
47
 
48
+ def supports_native_tools?(model: config.model) # rubocop:disable Lint/UnusedMethodArgument
49
+ true
50
+ end
51
+
50
52
  def models
51
53
  url = "#{@base_url}/models"
52
- response = http_client.get(url, {}, build_headers.merge(authorization_header))
54
+ response = http_client.get(url, {}, build_headers.merge(auth_header))
53
55
  result = handle_response(response)
54
56
  result['data'].map { |model| model['id'] }
55
57
  end
56
58
 
57
59
  private
58
60
 
59
- def authorization_header
60
- { 'Authorization' => "Bearer #{config.api_key}" }
61
+ def chat_url
62
+ "#{@base_url}/chat/completions"
61
63
  end
62
64
 
63
- def completion_request(prompt, **options)
64
- payload = {
65
- model: config.model,
66
- prompt: prompt,
67
- temperature: options[:temperature] || config.temperature,
68
- max_tokens: options[:max_tokens] || config.max_tokens
69
- }
65
+ def plain_chat(payload)
66
+ url = chat_url
67
+ log_request(:post, url, payload)
68
+ response = http_client.post(url, payload, build_headers.merge(auth_header))
69
+ log_response(response)
70
+ body = handle_response(response)
71
+ normalize_non_streaming(body)
72
+ end
70
73
 
71
- url = "#{@base_url}/completions"
74
+ def stream_chat(payload, sink) # rubocop:disable Metrics/AbcSize
75
+ url = chat_url
72
76
  log_request(:post, url, payload)
73
77
 
74
- response = http_client.post(url, payload, build_headers.merge(authorization_header))
75
- log_response(response)
78
+ assembled_text = +''
79
+ tool_calls_by_index = {}
80
+ final_usage = nil
81
+ finish_reason = nil
82
+
83
+ parser = SSEParser.new do |sse|
84
+ data_str = sse[:data]
85
+ next if data_str == '[DONE]'
86
+
87
+ data = JSON.parse(data_str)
88
+ choice = data.dig('choices', 0) || {}
89
+ delta = choice['delta'] || {}
90
+
91
+ if delta['content']
92
+ assembled_text << delta['content']
93
+ sink.call(Events::TextDelta.new(
94
+ type: :text_delta, timestamp: Time.now, agent: nil, iteration: nil,
95
+ text: delta['content']
96
+ ))
97
+ end
76
98
 
77
- result = handle_response(response)
78
- format_completion_response(result)
79
- end
99
+ Array(delta['tool_calls']).each do |tc|
100
+ idx = tc['index']
101
+ tool_calls_by_index[idx] ||= { id: nil, name: nil, arguments: +'' }
102
+ tool_calls_by_index[idx][:id] ||= tc['id']
103
+ tool_calls_by_index[idx][:name] ||= tc.dig('function', 'name')
104
+ tool_calls_by_index[idx][:arguments] << (tc.dig('function', 'arguments') || '')
105
+ end
106
+
107
+ finish_reason ||= choice['finish_reason']&.to_sym
80
108
 
81
- def format_messages(messages)
82
- messages.map do |msg|
83
- if msg.is_a?(Hash)
84
- msg
85
- else
86
- { role: 'user', content: msg.to_s }
109
+ if data['usage']
110
+ final_usage = {
111
+ prompt_tokens: data['usage']['prompt_tokens'],
112
+ completion_tokens: data['usage']['completion_tokens'],
113
+ total_tokens: data['usage']['total_tokens']
114
+ }
87
115
  end
88
116
  end
89
- end
90
117
 
91
- def format_response(response)
92
- choice = response.dig('choices', 0)
93
- return nil unless choice
118
+ streaming_post(url, payload) { |chunk| parser.feed(chunk) }
119
+
120
+ tool_calls = tool_calls_by_index.values.map do |tc|
121
+ {
122
+ id: tc[:id],
123
+ name: tc[:name],
124
+ arguments: tc[:arguments].empty? ? {} : JSON.parse(tc[:arguments])
125
+ }
126
+ end
127
+
128
+ if final_usage
129
+ sink.call(Events::Usage.new(
130
+ type: :usage, timestamp: Time.now, agent: nil, iteration: nil,
131
+ prompt_tokens: final_usage[:prompt_tokens],
132
+ completion_tokens: final_usage[:completion_tokens],
133
+ total_tokens: final_usage[:total_tokens],
134
+ cost_usd: Pricing.cost_for(config.model,
135
+ prompt_tokens: final_usage[:prompt_tokens],
136
+ completion_tokens: final_usage[:completion_tokens])
137
+ ))
138
+ end
94
139
 
95
140
  {
96
- content: choice.dig('message', 'content'),
97
- role: choice.dig('message', 'role'),
98
- finish_reason: choice['finish_reason'],
99
- usage: response['usage'],
100
- model: response['model'],
101
- provider: :openai
141
+ content: assembled_text.empty? ? nil : assembled_text,
142
+ tool_calls: tool_calls,
143
+ usage: final_usage || {},
144
+ finish_reason: finish_reason || :stop,
145
+ model: config.model,
146
+ provider: provider_name
102
147
  }
103
148
  end
104
149
 
105
- def format_completion_response(response)
106
- choice = response.dig('choices', 0)
107
- return nil unless choice
150
+ def provider_name
151
+ :openai
152
+ end
108
153
 
154
+ def streaming_post(url, payload, &on_chunk)
155
+ conn = Faraday.new do |f|
156
+ f.request :json
157
+ f.options.timeout = config.timeout
158
+ f.adapter Faraday.default_adapter
159
+ end
160
+ conn.post(url) do |req|
161
+ req.headers = build_headers.merge(auth_header)
162
+ req.body = payload.to_json
163
+ req.options.on_data = proc { |chunk, _| on_chunk.call(chunk) }
164
+ end
165
+ end
166
+
167
+ def normalize_non_streaming(body)
168
+ choice = body.dig('choices', 0) || {}
169
+ msg = choice['message'] || {}
170
+ tool_calls = Array(msg['tool_calls']).map do |tc|
171
+ {
172
+ id: tc['id'],
173
+ name: tc.dig('function', 'name'),
174
+ arguments: JSON.parse(tc.dig('function', 'arguments') || '{}')
175
+ }
176
+ end
109
177
  {
110
- content: choice['text'],
111
- finish_reason: choice['finish_reason'],
112
- usage: response['usage'],
113
- model: response['model'],
114
- provider: :openai
178
+ content: msg['content'],
179
+ tool_calls: tool_calls,
180
+ usage: {
181
+ prompt_tokens: body.dig('usage', 'prompt_tokens'),
182
+ completion_tokens: body.dig('usage', 'completion_tokens'),
183
+ total_tokens: body.dig('usage', 'total_tokens')
184
+ },
185
+ finish_reason: (choice['finish_reason'] || 'stop').to_sym,
186
+ model: body['model'] || config.model,
187
+ provider: provider_name
115
188
  }
116
189
  end
117
190
 
191
+ def auth_header
192
+ { 'Authorization' => "Bearer #{config.openai_api_key || config.api_key}" }
193
+ end
194
+
118
195
  def validate_config!
119
- raise ConfigurationError, "OpenAI API key is required" unless config.openai_api_key || config.api_key
120
- raise ConfigurationError, "Model is required" unless config.model
196
+ raise ConfigurationError, 'OpenAI API key is required' unless config.openai_api_key || config.api_key
197
+ raise ConfigurationError, 'Model is required' unless config.model
121
198
  end
122
199
  end
123
200
  end
124
- end
201
+ end