rcrewai 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +108 -0
- data/LICENSE +21 -0
- data/README.md +328 -0
- data/Rakefile +130 -0
- data/bin/rcrewai +7 -0
- data/docs/_config.yml +59 -0
- data/docs/_layouts/api.html +16 -0
- data/docs/_layouts/default.html +78 -0
- data/docs/_layouts/example.html +24 -0
- data/docs/_layouts/tutorial.html +33 -0
- data/docs/api/configuration.md +327 -0
- data/docs/api/crew.md +345 -0
- data/docs/api/index.md +41 -0
- data/docs/api/tools.md +412 -0
- data/docs/assets/css/style.css +416 -0
- data/docs/examples/human-in-the-loop.md +382 -0
- data/docs/examples/index.md +78 -0
- data/docs/examples/production-ready-crew.md +485 -0
- data/docs/examples/simple-research-crew.md +297 -0
- data/docs/index.md +353 -0
- data/docs/tutorials/getting-started.md +341 -0
- data/examples/async_execution_example.rb +294 -0
- data/examples/hierarchical_crew_example.rb +193 -0
- data/examples/human_in_the_loop_example.rb +233 -0
- data/lib/rcrewai/agent.rb +636 -0
- data/lib/rcrewai/async_executor.rb +248 -0
- data/lib/rcrewai/cli.rb +39 -0
- data/lib/rcrewai/configuration.rb +100 -0
- data/lib/rcrewai/crew.rb +292 -0
- data/lib/rcrewai/human_input.rb +520 -0
- data/lib/rcrewai/llm_client.rb +41 -0
- data/lib/rcrewai/llm_clients/anthropic.rb +127 -0
- data/lib/rcrewai/llm_clients/azure.rb +158 -0
- data/lib/rcrewai/llm_clients/base.rb +82 -0
- data/lib/rcrewai/llm_clients/google.rb +158 -0
- data/lib/rcrewai/llm_clients/ollama.rb +199 -0
- data/lib/rcrewai/llm_clients/openai.rb +124 -0
- data/lib/rcrewai/memory.rb +194 -0
- data/lib/rcrewai/process.rb +421 -0
- data/lib/rcrewai/task.rb +376 -0
- data/lib/rcrewai/tools/base.rb +82 -0
- data/lib/rcrewai/tools/code_executor.rb +333 -0
- data/lib/rcrewai/tools/email_sender.rb +210 -0
- data/lib/rcrewai/tools/file_reader.rb +111 -0
- data/lib/rcrewai/tools/file_writer.rb +115 -0
- data/lib/rcrewai/tools/pdf_processor.rb +342 -0
- data/lib/rcrewai/tools/sql_database.rb +226 -0
- data/lib/rcrewai/tools/web_search.rb +131 -0
- data/lib/rcrewai/version.rb +5 -0
- data/lib/rcrewai.rb +36 -0
- data/rcrewai.gemspec +54 -0
- metadata +365 -0

data/lib/rcrewai/llm_clients/azure.rb
@@ -0,0 +1,158 @@
# frozen_string_literal: true

require_relative 'base'

module RCrewAI
  module LLMClients
    class Azure < Base
      def initialize(config = RCrewAI.configuration)
        super
        @base_url = config.base_url || build_azure_url
        @api_version = config.api_version || '2024-02-01'
        @deployment_name = config.deployment_name || config.model
      end

      def chat(messages:, **options)
        payload = {
          messages: format_messages(messages),
          temperature: options[:temperature] || config.temperature,
          max_tokens: options[:max_tokens] || config.max_tokens
        }

        # Add additional OpenAI-compatible options
        payload[:top_p] = options[:top_p] if options[:top_p]
        payload[:frequency_penalty] = options[:frequency_penalty] if options[:frequency_penalty]
        payload[:presence_penalty] = options[:presence_penalty] if options[:presence_penalty]
        payload[:stop] = options[:stop] if options[:stop]

        url = "#{@base_url}/openai/deployments/#{@deployment_name}/chat/completions?api-version=#{@api_version}"
        log_request(:post, url, payload)

        response = http_client.post(url, payload, build_headers.merge(authorization_header))
        log_response(response)

        result = handle_response(response)
        format_response(result)
      end

      def complete(prompt:, **options)
        # For older models that use completions endpoint
        payload = {
          prompt: prompt,
          temperature: options[:temperature] || config.temperature,
          max_tokens: options[:max_tokens] || config.max_tokens
        }

        url = "#{@base_url}/openai/deployments/#{@deployment_name}/completions?api-version=#{@api_version}"
        log_request(:post, url, payload)

        response = http_client.post(url, payload, build_headers.merge(authorization_header))
        log_response(response)

        result = handle_response(response)
        format_completion_response(result)
      end

      def models
        # Azure OpenAI uses deployments instead of models
        url = "#{@base_url}/openai/deployments?api-version=#{@api_version}"
        response = http_client.get(url, {}, build_headers.merge(authorization_header))
        result = handle_response(response)

        if result['data']
          result['data'].map { |deployment| deployment['id'] }
        else
          [@deployment_name].compact
        end
      end

      private

      def authorization_header
        { 'api-key' => config.api_key }
      end

      def format_messages(messages)
        messages.map do |msg|
          if msg.is_a?(Hash)
            msg
          else
            { role: 'user', content: msg.to_s }
          end
        end
      end

      def format_response(response)
        choice = response.dig('choices', 0)
        return nil unless choice

        {
          content: choice.dig('message', 'content'),
          role: choice.dig('message', 'role'),
          finish_reason: choice['finish_reason'],
          usage: response['usage'],
          model: @deployment_name,
          provider: :azure
        }
      end

      def format_completion_response(response)
        choice = response.dig('choices', 0)
        return nil unless choice

        {
          content: choice['text'],
          finish_reason: choice['finish_reason'],
          usage: response['usage'],
          model: @deployment_name,
          provider: :azure
        }
      end

      def validate_config!
        super
        raise ConfigurationError, "Azure API key is required" unless config.azure_api_key || config.api_key
        raise ConfigurationError, "Azure base URL or endpoint is required" unless config.base_url || azure_endpoint
        raise ConfigurationError, "Azure deployment name is required" unless config.deployment_name || config.model
      end

      def build_azure_url
        endpoint = azure_endpoint
        return nil unless endpoint

        # Remove trailing slash and add proper path
        endpoint = endpoint.chomp('/')
        "#{endpoint}"
      end

      def azure_endpoint
        # Try multiple environment variable names
        ENV['AZURE_OPENAI_ENDPOINT'] ||
          ENV['AZURE_ENDPOINT'] ||
          config.instance_variable_get(:@azure_endpoint)
      end

      def handle_response(response)
        case response.status
        when 200..299
          response.body
        when 400
          error_details = response.body.dig('error', 'message') || response.body
          raise APIError, "Bad request: #{error_details}"
        when 401
          raise AuthenticationError, "Invalid API key or authentication failed"
        when 403
          raise AuthenticationError, "Access denied - check your API key and permissions"
        when 404
          raise ModelNotFoundError, "Deployment '#{@deployment_name}' not found"
        when 429
          raise RateLimitError, "Rate limit exceeded or quota exhausted"
        when 500..599
          raise APIError, "Azure OpenAI service error: #{response.status}"
        else
          raise APIError, "Unexpected response: #{response.status}"
        end
      end
    end
  end
end
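
For orientation, a minimal usage sketch of the Azure client above, not a definitive recipe: the gem's own RCrewAI::Configuration (data/lib/rcrewai/configuration.rb) is not part of this excerpt, so an OpenStruct carrying the attributes the client reads stands in for it, and the environment variable name, resource URL, and deployment name below are placeholders.

require 'ostruct'
require 'rcrewai'  # assumed to load RCrewAI::LLMClients::Azure and its Base class

# Stand-in for RCrewAI::Configuration, carrying only the attributes the Azure
# client reads; the env var name and resource URL here are illustrative.
config = OpenStruct.new(
  api_key: ENV['AZURE_OPENAI_API_KEY'],              # sent as the 'api-key' header
  base_url: 'https://my-resource.openai.azure.com',  # or rely on AZURE_OPENAI_ENDPOINT
  deployment_name: 'gpt-4o-deployment',              # used in the /openai/deployments/... path
  model: 'gpt-4o',
  api_version: '2024-02-01',
  temperature: 0.7,
  max_tokens: 512,
  timeout: 30
)

client = RCrewAI::LLMClients::Azure.new(config)
result = client.chat(messages: [{ role: 'user', content: 'Say hello in one word.' }])
puts result[:content]        # choices[0].message.content from the Azure response
puts result[:finish_reason]  # e.g. "stop"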

data/lib/rcrewai/llm_clients/base.rb
@@ -0,0 +1,82 @@
# frozen_string_literal: true

require 'faraday'
require 'json'
require 'logger'

module RCrewAI
  module LLMClients
    class Base
      attr_reader :config, :logger

      def initialize(config = RCrewAI.configuration)
        @config = config
        @logger = Logger.new($stdout)
        @logger.level = Logger::INFO
        validate_config!
      end

      def chat(messages:, **options)
        raise NotImplementedError, "Subclasses must implement #chat method"
      end

      def complete(prompt:, **options)
        chat(messages: [{ role: 'user', content: prompt }], **options)
      end

      protected

      def validate_config!
        raise ConfigurationError, "API key is required" unless config.api_key
        raise ConfigurationError, "Model is required" unless config.model
      end

      def build_headers
        {
          'Content-Type' => 'application/json',
          'User-Agent' => "rcrewai/#{RCrewAI::VERSION}"
        }
      end

      def http_client
        @http_client ||= Faraday.new do |f|
          f.request :json
          f.response :json
          f.adapter Faraday.default_adapter
          f.options.timeout = config.timeout
        end
      end

      def handle_response(response)
        case response.status
        when 200..299
          response.body
        when 400
          raise APIError, "Bad request: #{response.body}"
        when 401
          raise AuthenticationError, "Invalid API key"
        when 429
          raise RateLimitError, "Rate limit exceeded"
        when 500..599
          raise APIError, "Server error: #{response.status}"
        else
          raise APIError, "Unexpected response: #{response.status}"
        end
      end

      def log_request(method, url, payload = nil)
        logger.info "#{method.upcase} #{url}"
        logger.debug "Payload: #{payload}" if payload
      end

      def log_response(response)
        logger.debug "Response: #{response.status} - #{response.body}"
      end
    end

    class APIError < RCrewAI::Error; end
    class AuthenticationError < APIError; end
    class RateLimitError < APIError; end
    class ModelNotFoundError < APIError; end
  end
end
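
Base defines the provider contract: subclasses must implement #chat, while #complete, #http_client, #handle_response, #build_headers, and the logging helpers are inherited. A hypothetical subclass, sketched only to show how those pieces connect; the provider name, endpoint, and response field below are invented and not part of the gem.

require 'rcrewai'  # assumed to define RCrewAI::Error, RCrewAI::VERSION, and Base

module RCrewAI
  module LLMClients
    # Hypothetical provider, shown only to illustrate the Base contract.
    class ExampleProvider < Base
      def chat(messages:, **options)
        payload = {
          model: config.model,
          messages: messages,
          temperature: options[:temperature] || config.temperature
        }

        url = 'https://api.example.invalid/v1/chat'  # invented endpoint
        log_request(:post, url, payload)

        response = http_client.post(url, payload, build_headers)
        log_response(response)

        body = handle_response(response)  # raises APIError / RateLimitError on non-2xx
        { content: body['text'], provider: :example }
      end

      # The inherited #complete wraps a plain prompt into a single user message
      # and delegates to this #chat; Base#validate_config! already enforces
      # that api_key and model are configured.
    end
  end
end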

data/lib/rcrewai/llm_clients/google.rb
@@ -0,0 +1,158 @@
# frozen_string_literal: true

require_relative 'base'

module RCrewAI
  module LLMClients
    class Google < Base
      BASE_URL = 'https://generativelanguage.googleapis.com/v1beta'

      def initialize(config = RCrewAI.configuration)
        super
        @base_url = BASE_URL
      end

      def chat(messages:, **options)
        # Convert messages to Gemini format
        formatted_contents = format_messages(messages)

        payload = {
          contents: formatted_contents,
          generationConfig: {
            temperature: options[:temperature] || config.temperature,
            maxOutputTokens: options[:max_tokens] || config.max_tokens || 2048,
            topP: options[:top_p] || 0.8,
            topK: options[:top_k] || 10
          }
        }

        # Add safety settings if provided
        if options[:safety_settings]
          payload[:safetySettings] = options[:safety_settings]
        end

        # Add stop sequences if provided
        if options[:stop_sequences]
          payload[:generationConfig][:stopSequences] = options[:stop_sequences]
        end

        url = "#{@base_url}/models/#{config.model}:generateContent?key=#{config.api_key}"
        log_request(:post, url, payload)

        response = http_client.post(url, payload, build_headers)
        log_response(response)

        result = handle_response(response)
        format_response(result)
      end

      def models
        # Google AI Studio doesn't provide a models list endpoint with API key auth
        # Return known Gemini models
        [
          'gemini-pro',
          'gemini-pro-vision',
          'gemini-1.5-pro',
          'gemini-1.5-flash',
          'text-bison-001',
          'chat-bison-001'
        ]
      end

      private

      def format_messages(messages)
        contents = []

        messages.each do |msg|
          role = case msg[:role]
                 when 'user'
                   'user'
                 when 'assistant'
                   'model'
                 when 'system'
                   # Gemini doesn't have system role, prepend to first user message
                   next
                 else
                   'user'
                 end

          content = if msg.is_a?(Hash)
                      msg[:content]
                    else
                      msg.to_s
                    end

          contents << {
            role: role,
            parts: [{ text: content }]
          }
        end

        # Handle system message by prepending to first user message
        system_msg = messages.find { |m| m[:role] == 'system' }
        if system_msg && contents.any?
          first_user_content = contents.find { |c| c[:role] == 'user' }
          if first_user_content
            original_text = first_user_content[:parts].first[:text]
            first_user_content[:parts].first[:text] = "#{system_msg[:content]}\n\n#{original_text}"
          end
        end

        contents
      end

      def format_response(response)
        candidate = response.dig('candidates', 0)
        return nil unless candidate

        content = candidate.dig('content', 'parts', 0, 'text')
        finish_reason = candidate['finishReason']

        # Extract usage information if available
        usage_metadata = response['usageMetadata']
        usage = if usage_metadata
                  {
                    'prompt_tokens' => usage_metadata['promptTokenCount'],
                    'completion_tokens' => usage_metadata['candidatesTokenCount'],
                    'total_tokens' => usage_metadata['totalTokenCount']
                  }
                end

        {
          content: content,
          role: 'assistant',
          finish_reason: finish_reason,
          usage: usage,
          model: config.model,
          provider: :google
        }
      end

      def validate_config!
        super
        raise ConfigurationError, "Google API key is required" unless config.google_api_key || config.api_key
      end

      def handle_response(response)
        case response.status
        when 200..299
          response.body
        when 400
          error_details = response.body.dig('error', 'message') || response.body
          raise APIError, "Bad request: #{error_details}"
        when 401
          raise AuthenticationError, "Invalid API key"
        when 403
          raise AuthenticationError, "API key does not have permission"
        when 429
          raise RateLimitError, "Rate limit exceeded or quota exhausted"
        when 500..599
          raise APIError, "Server error: #{response.status}"
        else
          raise APIError, "Unexpected response: #{response.status}"
        end
      end
    end
  end
end
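
A small sketch of what the private format_messages conversion above produces: OpenAI-style roles become Gemini contents, 'assistant' is renamed to 'model', and a 'system' message is skipped as a turn and instead folded into the first user part. The input shape assumes symbol-keyed hash messages, as the rest of the client does; the example text is illustrative only.

# Input messages (OpenAI-style, symbol-keyed hashes):
messages = [
  { role: 'system',    content: 'You are terse.' },
  { role: 'user',      content: 'Name one Ruby web framework.' },
  { role: 'assistant', content: 'Rails.' }
]

# Resulting Gemini `contents` built by format_messages: the system turn is
# dropped from the list and its text is prepended to the first user part.
contents = [
  { role: 'user',  parts: [{ text: "You are terse.\n\nName one Ruby web framework." }] },
  { role: 'model', parts: [{ text: 'Rails.' }] }
]
pp contents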

data/lib/rcrewai/llm_clients/ollama.rb
@@ -0,0 +1,199 @@
# frozen_string_literal: true

require_relative 'base'

module RCrewAI
  module LLMClients
    class Ollama < Base
      DEFAULT_URL = 'http://localhost:11434'

      def initialize(config = RCrewAI.configuration)
        super
        @base_url = config.base_url || ollama_url || DEFAULT_URL
      end

      def chat(messages:, **options)
        payload = {
          model: config.model,
          messages: format_messages(messages),
          options: {
            temperature: options[:temperature] || config.temperature,
            num_predict: options[:max_tokens] || config.max_tokens,
            top_p: options[:top_p],
            top_k: options[:top_k],
            repeat_penalty: options[:repeat_penalty]
          }.compact
        }

        # Add stop sequences if provided
        payload[:options][:stop] = options[:stop] if options[:stop]

        url = "#{@base_url}/api/chat"
        log_request(:post, url, payload)

        response = http_client.post(url, payload, build_headers)
        log_response(response)

        result = handle_response(response)
        format_response(result)
      end

      def complete(prompt:, **options)
        payload = {
          model: config.model,
          prompt: prompt,
          options: {
            temperature: options[:temperature] || config.temperature,
            num_predict: options[:max_tokens] || config.max_tokens,
            top_p: options[:top_p],
            top_k: options[:top_k],
            repeat_penalty: options[:repeat_penalty]
          }.compact
        }

        payload[:options][:stop] = options[:stop] if options[:stop]

        url = "#{@base_url}/api/generate"
        log_request(:post, url, payload)

        response = http_client.post(url, payload, build_headers)
        log_response(response)

        result = handle_response(response)
        format_completion_response(result)
      end

      def models
        url = "#{@base_url}/api/tags"
        response = http_client.get(url, {}, build_headers)
        result = handle_response(response)

        if result['models']
          result['models'].map { |model| model['name'] }
        else
          []
        end
      rescue => e
        logger.warn "Failed to fetch Ollama models: #{e.message}"
        []
      end

      def pull_model(model_name)
        payload = { name: model_name }
        url = "#{@base_url}/api/pull"

        response = http_client.post(url, payload, build_headers)
        handle_response(response)
      end

      def model_info(model_name = nil)
        model_name ||= config.model
        payload = { name: model_name }
        url = "#{@base_url}/api/show"

        response = http_client.post(url, payload, build_headers)
        handle_response(response)
      rescue => e
        logger.warn "Failed to get model info for #{model_name}: #{e.message}"
        nil
      end

      private

      def format_messages(messages)
        messages.map do |msg|
          if msg.is_a?(Hash)
            {
              role: msg[:role],
              content: msg[:content]
            }
          else
            { role: 'user', content: msg.to_s }
          end
        end
      end

      def format_response(response)
        message = response['message']
        return nil unless message

        # Ollama doesn't provide detailed usage stats by default
        usage = {
          'prompt_tokens' => response['prompt_eval_count'],
          'completion_tokens' => response['eval_count'],
          'total_tokens' => (response['prompt_eval_count'] || 0) + (response['eval_count'] || 0)
        }.compact

        {
          content: message['content'],
          role: message['role'] || 'assistant',
          finish_reason: response['done'] ? 'stop' : nil,
          usage: usage,
          model: response['model'] || config.model,
          provider: :ollama
        }
      end

      def format_completion_response(response)
        {
          content: response['response'],
          finish_reason: response['done'] ? 'stop' : nil,
          usage: {
            'prompt_tokens' => response['prompt_eval_count'],
            'completion_tokens' => response['eval_count'],
            'total_tokens' => (response['prompt_eval_count'] || 0) + (response['eval_count'] || 0)
          }.compact,
          model: response['model'] || config.model,
          provider: :ollama
        }
      end

      def validate_config!
        # Ollama doesn't require an API key
        raise ConfigurationError, "Model is required" unless config.model

        # Test connection to Ollama server
        test_connection
      end

      def test_connection
        url = "#{@base_url}/api/tags"
        response = http_client.get(url, {}, build_headers)

        unless (200..299).include?(response.status)
          raise ConfigurationError, "Cannot connect to Ollama server at #{@base_url}"
        end
      rescue Faraday::ConnectionFailed
        raise ConfigurationError, "Cannot connect to Ollama server at #{@base_url}. Is Ollama running?"
      end

      def ollama_url
        ENV['OLLAMA_HOST'] || ENV['OLLAMA_URL']
      end

      def build_headers
        # Ollama doesn't require special headers
        {
          'Content-Type' => 'application/json',
          'User-Agent' => "rcrewai/#{RCrewAI::VERSION}"
        }
      end

      def handle_response(response)
        case response.status
        when 200..299
          response.body
        when 400
          error_details = response.body['error'] || response.body
          raise APIError, "Bad request: #{error_details}"
        when 404
          raise ModelNotFoundError, "Model '#{config.model}' not found. Try running: ollama pull #{config.model}"
        when 500..599
          raise APIError, "Ollama server error: #{response.status}"
        else
          raise APIError, "Unexpected response: #{response.status}"
        end
      end
    end
  end
end
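
For reference, a sketch of the request and response shapes the Ollama client maps between, as read off #chat, #format_messages, and #format_response above. The concrete model name, message text, and token counts are illustrative only; no API key is involved, and #validate_config! probes the server instead.

# Body POSTed to "#{base_url}/api/chat" by #chat (nil options removed by .compact):
request = {
  model: 'llama3',
  messages: [{ role: 'user', content: 'Hi there' }],
  options: { temperature: 0.2, num_predict: 256 }
}

# Fields #format_response reads from the Ollama reply:
reply = {
  'model' => 'llama3',
  'message' => { 'role' => 'assistant', 'content' => 'Hello!' },
  'done' => true,            # mapped to finish_reason: 'stop'
  'prompt_eval_count' => 12, # becomes usage['prompt_tokens']
  'eval_count' => 7          # becomes usage['completion_tokens']
}

# Normalized hash handed back to callers:
result = {
  content: 'Hello!',
  role: 'assistant',
  finish_reason: 'stop',
  usage: { 'prompt_tokens' => 12, 'completion_tokens' => 7, 'total_tokens' => 19 },
  model: 'llama3',
  provider: :ollama
}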