desiru 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.env.example +34 -0
- data/.rubocop.yml +7 -4
- data/.ruby-version +1 -0
- data/CLAUDE.md +4 -0
- data/Gemfile +21 -2
- data/Gemfile.lock +87 -12
- data/README.md +295 -2
- data/Rakefile +1 -0
- data/db/migrations/001_create_initial_tables.rb +96 -0
- data/db/migrations/002_create_job_results.rb +39 -0
- data/desiru.db +0 -0
- data/desiru.gemspec +2 -5
- data/docs/background_processing_roadmap.md +87 -0
- data/docs/job_scheduling.md +167 -0
- data/dspy-analysis-swarm.yml +60 -0
- data/dspy-feature-analysis.md +121 -0
- data/examples/README.md +69 -0
- data/examples/api_with_persistence.rb +122 -0
- data/examples/assertions_example.rb +232 -0
- data/examples/async_processing.rb +2 -0
- data/examples/few_shot_learning.rb +1 -2
- data/examples/graphql_api.rb +4 -2
- data/examples/graphql_integration.rb +3 -3
- data/examples/graphql_optimization_summary.md +143 -0
- data/examples/graphql_performance_benchmark.rb +247 -0
- data/examples/persistence_example.rb +102 -0
- data/examples/react_agent.rb +203 -0
- data/examples/rest_api.rb +173 -0
- data/examples/rest_api_advanced.rb +333 -0
- data/examples/scheduled_job_example.rb +116 -0
- data/examples/simple_qa.rb +1 -2
- data/examples/sinatra_api.rb +109 -0
- data/examples/typed_signatures.rb +1 -2
- data/graphql_optimization_summary.md +53 -0
- data/lib/desiru/api/grape_integration.rb +284 -0
- data/lib/desiru/api/persistence_middleware.rb +148 -0
- data/lib/desiru/api/sinatra_integration.rb +217 -0
- data/lib/desiru/api.rb +42 -0
- data/lib/desiru/assertions.rb +74 -0
- data/lib/desiru/async_status.rb +65 -0
- data/lib/desiru/cache.rb +1 -1
- data/lib/desiru/configuration.rb +2 -1
- data/lib/desiru/errors.rb +160 -0
- data/lib/desiru/field.rb +17 -14
- data/lib/desiru/graphql/batch_loader.rb +85 -0
- data/lib/desiru/graphql/data_loader.rb +242 -75
- data/lib/desiru/graphql/enum_builder.rb +75 -0
- data/lib/desiru/graphql/executor.rb +37 -4
- data/lib/desiru/graphql/schema_generator.rb +62 -158
- data/lib/desiru/graphql/type_builder.rb +138 -0
- data/lib/desiru/graphql/type_cache_warmer.rb +91 -0
- data/lib/desiru/jobs/async_predict.rb +1 -1
- data/lib/desiru/jobs/base.rb +67 -0
- data/lib/desiru/jobs/batch_processor.rb +6 -6
- data/lib/desiru/jobs/retriable.rb +119 -0
- data/lib/desiru/jobs/retry_strategies.rb +169 -0
- data/lib/desiru/jobs/scheduler.rb +219 -0
- data/lib/desiru/jobs/webhook_notifier.rb +242 -0
- data/lib/desiru/models/anthropic.rb +164 -0
- data/lib/desiru/models/base.rb +37 -3
- data/lib/desiru/models/open_ai.rb +151 -0
- data/lib/desiru/models/open_router.rb +161 -0
- data/lib/desiru/module.rb +59 -9
- data/lib/desiru/modules/chain_of_thought.rb +3 -3
- data/lib/desiru/modules/majority.rb +51 -0
- data/lib/desiru/modules/multi_chain_comparison.rb +204 -0
- data/lib/desiru/modules/predict.rb +8 -1
- data/lib/desiru/modules/program_of_thought.rb +139 -0
- data/lib/desiru/modules/react.rb +273 -0
- data/lib/desiru/modules/retrieve.rb +4 -2
- data/lib/desiru/optimizers/base.rb +2 -4
- data/lib/desiru/optimizers/bootstrap_few_shot.rb +2 -2
- data/lib/desiru/optimizers/copro.rb +268 -0
- data/lib/desiru/optimizers/knn_few_shot.rb +185 -0
- data/lib/desiru/persistence/database.rb +71 -0
- data/lib/desiru/persistence/models/api_request.rb +38 -0
- data/lib/desiru/persistence/models/job_result.rb +138 -0
- data/lib/desiru/persistence/models/module_execution.rb +37 -0
- data/lib/desiru/persistence/models/optimization_result.rb +28 -0
- data/lib/desiru/persistence/models/training_example.rb +25 -0
- data/lib/desiru/persistence/models.rb +11 -0
- data/lib/desiru/persistence/repositories/api_request_repository.rb +98 -0
- data/lib/desiru/persistence/repositories/base_repository.rb +77 -0
- data/lib/desiru/persistence/repositories/job_result_repository.rb +116 -0
- data/lib/desiru/persistence/repositories/module_execution_repository.rb +85 -0
- data/lib/desiru/persistence/repositories/optimization_result_repository.rb +67 -0
- data/lib/desiru/persistence/repositories/training_example_repository.rb +102 -0
- data/lib/desiru/persistence/repository.rb +29 -0
- data/lib/desiru/persistence/setup.rb +77 -0
- data/lib/desiru/persistence.rb +49 -0
- data/lib/desiru/registry.rb +3 -5
- data/lib/desiru/signature.rb +91 -24
- data/lib/desiru/version.rb +1 -1
- data/lib/desiru.rb +23 -8
- data/missing-features-analysis.md +192 -0
- metadata +63 -45
- data/lib/desiru/models/raix_adapter.rb +0 -210

data/lib/desiru/jobs/webhook_notifier.rb
ADDED
@@ -0,0 +1,242 @@
# frozen_string_literal: true

require 'net/http'
require 'uri'
require 'json'

module Desiru
  module Jobs
    # Handles webhook notifications for job events
    class WebhookNotifier
      attr_reader :config

      def initialize(config = {})
        @config = {
          timeout: 30,
          retry_count: 3,
          retry_delay: 1,
          headers: {
            'Content-Type' => 'application/json',
            'User-Agent' => "Desiru/#{Desiru::VERSION}"
          }
        }.merge(config)
      end

      # Send a webhook notification
      # @param url [String] the webhook URL
      # @param payload [Hash] the payload to send
      # @param options [Hash] additional options
      # @return [WebhookResult] the result of the webhook call
      def notify(url, payload, options = {})
        uri = URI.parse(url)
        headers = config[:headers].merge(options[:headers] || {})

        # Add signature if secret is provided
        if options[:secret]
          signature = generate_signature(payload, options[:secret])
          headers['X-Desiru-Signature'] = signature
        end

        attempt = 0
        last_error = nil

        while attempt < config[:retry_count]
          attempt += 1

          begin
            response = send_request(uri, payload, headers)

            if response.code.to_i >= 200 && response.code.to_i < 300
              return WebhookResult.new(
                success: true,
                status_code: response.code.to_i,
                body: response.body,
                headers: response.to_hash,
                attempts: attempt
              )
            else
              last_error = "HTTP #{response.code}: #{response.body}"
              Desiru.logger.warn("Webhook failed (attempt #{attempt}/#{config[:retry_count]}): #{last_error}")
            end
          rescue StandardError => e
            last_error = e.message
            Desiru.logger.error("Webhook error (attempt #{attempt}/#{config[:retry_count]}): #{e.message}")
          end

          # Retry with delay if not the last attempt
          if attempt < config[:retry_count]
            sleep(config[:retry_delay] * attempt) # Exponential backoff
          end
        end

        # All attempts failed
        WebhookResult.new(
          success: false,
          error: last_error,
          attempts: attempt
        )
      end

      private

      def send_request(uri, payload, headers)
        http = Net::HTTP.new(uri.host, uri.port)
        http.use_ssl = uri.scheme == 'https'
        http.read_timeout = config[:timeout]
        http.open_timeout = config[:timeout]

        request = Net::HTTP::Post.new(uri.request_uri)
        headers.each { |key, value| request[key] = value }
        request.body = payload.to_json

        http.request(request)
      end

      def generate_signature(payload, secret)
        require 'openssl'
        digest = OpenSSL::Digest.new('sha256')
        OpenSSL::HMAC.hexdigest(digest, secret, payload.to_json)
      end
    end

    # Result of a webhook notification
    class WebhookResult
      attr_reader :success, :status_code, :body, :headers, :error, :attempts

      def initialize(success:, status_code: nil, body: nil, headers: nil, error: nil, attempts: 1)
        @success = success
        @status_code = status_code
        @body = body
        @headers = headers
        @error = error
        @attempts = attempts
      end

      def success?
        @success
      end

      def failed?
        !@success
      end
    end

    # Configuration for webhook notifications
    class WebhookConfig
      attr_accessor :enabled, :url, :secret, :events, :include_payload, :custom_headers

      def initialize
        @enabled = false
        @url = nil
        @secret = nil
        @events = %i[completed failed] # Which events to notify on
        @include_payload = true # Include job result in webhook
        @custom_headers = {}
      end

      def valid?
        enabled && url && !url.empty?
      end
    end

    # Mixin to add webhook support to jobs
    module Webhookable
      def self.included(base)
        base.extend(ClassMethods)
        base.instance_variable_set(:@webhook_config, WebhookConfig.new)
      end

      def self.prepended(base)
        base.extend(ClassMethods)
        base.instance_variable_set(:@webhook_config, WebhookConfig.new)
      end

      def perform(*args)
        job_id = args.first || "job-#{Time.now.to_i}"
        result = nil
        error = nil
        status = :completed

        begin
          # Call the original perform method
          result = super
        rescue StandardError => e
          error = e
          status = :failed
          raise # Re-raise to maintain normal error handling
        ensure
          # Send webhook notification if configured
          send_webhook_notification(job_id, status, result, error) if should_notify_webhook?(status)
        end

        result
      end

      module ClassMethods
        def webhook_config
          @webhook_config ||= WebhookConfig.new
        end

        def configure_webhook
          yield(webhook_config) if block_given?
        end

        def webhook_enabled?
          webhook_config.valid?
        end
      end

      private

      def should_notify_webhook?(status)
        self.class.webhook_enabled? &&
          self.class.webhook_config.events.include?(status)
      end

      def send_webhook_notification(job_id, status, result, error)
        payload = build_webhook_payload(job_id, status, result, error)

        notifier = WebhookNotifier.new
        webhook_result = notifier.notify(
          self.class.webhook_config.url,
          payload,
          secret: self.class.webhook_config.secret,
          headers: self.class.webhook_config.custom_headers
        )

        if webhook_result.failed?
          Desiru.logger.error("Failed to send webhook for job #{job_id}: #{webhook_result.error}")
        else
          Desiru.logger.info("Webhook notification sent for job #{job_id}")
        end
      rescue StandardError => e
        # Don't let webhook failures affect job execution
        Desiru.logger.error("Webhook notification error: #{e.message}")
      end

      def build_webhook_payload(job_id, status, result, error)
        payload = {
          job_id: job_id,
          job_class: self.class.name,
          status: status.to_s,
          timestamp: Time.now.iso8601,
          environment: ENV['RACK_ENV'] || ENV['RAILS_ENV'] || 'development'
        }

        if self.class.webhook_config.include_payload
          if status == :completed && result
            payload[:result] = result
          elsif status == :failed && error
            payload[:error] = {
              class: error.class.name,
              message: error.message,
              backtrace: error.backtrace&.first(5) # Limit backtrace size
            }
          end
        end

        payload
      end
    end
  end
end
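
The Webhookable mixin above wraps perform, so a job opts in by prepending the module and configuring the class-level WebhookConfig. The sketch below is illustrative only: the job class, its arguments, and the receiver URL are assumptions, not part of this release.

class ReportJob < Desiru::Jobs::Base
  prepend Desiru::Jobs::Webhookable

  # Class-level settings yielded by ClassMethods#configure_webhook
  configure_webhook do |config|
    config.enabled = true
    config.url     = 'https://example.com/hooks/desiru'  # hypothetical receiver
    config.secret  = ENV['DESIRU_WEBHOOK_SECRET']         # adds X-Desiru-Signature (HMAC-SHA256)
    config.events  = %i[completed failed]
  end

  def perform(job_id, params = {})
    # Webhookable#perform calls this via super and reports :completed or :failed
    { report: "generated for #{params[:name]}" }
  end
end

On the receiving side, the signature can be checked by recomputing the HMAC-SHA256 hex digest of the raw request body with the shared secret and comparing it to the X-Desiru-Signature header, since the notifier signs the same JSON string it posts.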
data/lib/desiru/models/anthropic.rb
ADDED
@@ -0,0 +1,164 @@
# frozen_string_literal: true

require 'anthropic'

module Desiru
  module Models
    # Anthropic Claude model adapter
    class Anthropic < Base
      DEFAULT_MODEL = 'claude-3-haiku-20240307'

      def initialize(config = {})
        super
        @api_key = config[:api_key] || ENV.fetch('ANTHROPIC_API_KEY', nil)
        raise ArgumentError, 'Anthropic API key is required' unless @api_key

        @client = ::Anthropic::Client.new(access_token: @api_key)
      end

      def models
        # Anthropic doesn't provide a models endpoint, so we maintain a list
        # This list should be updated periodically as new models are released
        @models ||= {
          'claude-3-haiku-20240307' => {
            name: 'Claude 3 Haiku',
            max_tokens: 200_000,
            description: 'Fast and efficient for simple tasks'
          },
          'claude-3-sonnet-20240229' => {
            name: 'Claude 3 Sonnet',
            max_tokens: 200_000,
            description: 'Balanced performance and capability'
          },
          'claude-3-opus-20240229' => {
            name: 'Claude 3 Opus',
            max_tokens: 200_000,
            description: 'Most capable model for complex tasks'
          },
          'claude-3-5-sonnet-20241022' => {
            name: 'Claude 3.5 Sonnet',
            max_tokens: 200_000,
            description: 'Latest Sonnet with improved capabilities'
          },
          'claude-3-5-haiku-20241022' => {
            name: 'Claude 3.5 Haiku',
            max_tokens: 200_000,
            description: 'Latest Haiku with enhanced speed'
          }
        }
      end

      protected

      def perform_completion(messages, options)
        model = options[:model] || @config[:model] || DEFAULT_MODEL
        temperature = options[:temperature] || @config[:temperature] || 0.7
        max_tokens = options[:max_tokens] || @config[:max_tokens] || 4096

        # Convert messages to Anthropic format
        system_message, user_messages = format_messages(messages)

        # Prepare request parameters
        params = {
          model: model,
          messages: user_messages,
          max_tokens: max_tokens,
          temperature: temperature
        }

        # Add system message if present
        params[:system] = system_message if system_message

        # Add tools if provided
        if options[:tools]
          params[:tools] = format_tools(options[:tools])
          params[:tool_choice] = options[:tool_choice] if options[:tool_choice]
        end

        # Make API call
        response = @client.messages(parameters: params)

        # Format response
        format_response(response, model)
      rescue ::Faraday::Error => e
        handle_api_error(e)
      end

      private

      def format_messages(messages)
        system_message = nil
        user_messages = []

        messages.each do |msg|
          case msg[:role]
          when 'system'
            system_message = msg[:content]
          when 'user'
            user_messages << { role: 'user', content: msg[:content] }
          when 'assistant'
            user_messages << { role: 'assistant', content: msg[:content] }
          end
        end

        [system_message, user_messages]
      end

      def format_tools(tools)
        tools.map do |tool|
          {
            name: tool[:function][:name],
            description: tool[:function][:description],
            input_schema: tool[:function][:parameters]
          }
        end
      end

      def format_response(response, model)
        content = extract_content(response)

        {
          content: content,
          raw: response,
          model: model,
          usage: {
            prompt_tokens: response.dig('usage', 'input_tokens') || 0,
            completion_tokens: response.dig('usage', 'output_tokens') || 0,
            total_tokens: (response.dig('usage', 'input_tokens') || 0) + (response.dig('usage', 'output_tokens') || 0)
          }
        }
      end

      def extract_content(response)
        # Handle different response formats
        if response.is_a?(Hash)
          # Direct API response
          if response['content'].is_a?(Array)
            response['content'].map { |c| c['text'] }.join
          else
            response['content'] || response['completion'] || ''
          end
        else
          # Client wrapper response
          response.content.first.text
        end
      rescue StandardError => e
        Desiru.logger.error("Failed to extract content from Anthropic response: #{e.message}")
        ''
      end

      def handle_api_error(error)
        case error
        when ::Faraday::UnauthorizedError
          raise AuthenticationError, 'Invalid Anthropic API key'
        when ::Faraday::BadRequestError
          raise InvalidRequestError, "Invalid request: #{error.message}"
        when ::Faraday::TooManyRequestsError
          raise RateLimitError, 'Anthropic API rate limit exceeded'
        else
          raise APIError, "Anthropic API error: #{error.message}"
        end
      end
    end
  end
end
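
A brief usage sketch for the new adapter, assuming it is instantiated directly and that the Base configuration accepts these options; the prompt and model choice are illustrative.

model = Desiru::Models::Anthropic.new(
  api_key: ENV['ANTHROPIC_API_KEY'],
  model: 'claude-3-5-haiku-20241022'
)

# Base#complete (see the base.rb change below) turns the Hash prompt into
# system/user messages and delegates to perform_completion above.
response = model.complete(
  { system: 'You are a terse assistant.', user: 'Name one Ruby web framework.' }
)

response[:content]               # => model text
response[:usage][:total_tokens]  # => input + output token count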
data/lib/desiru/models/base.rb
CHANGED
@@ -16,9 +16,15 @@ module Desiru
         validate_config!
       end

-      # Main interface method -
+      # Main interface method - calls perform_completion with proper message formatting
       def complete(prompt, **options)
-
+        messages = prepare_messages(prompt, options[:messages])
+
+        with_retry do
+          response = perform_completion(messages, options)
+          increment_stats(response[:usage][:total_tokens]) if response[:usage]
+          response
+        end
       end

       # Stream completion - optional implementation
@@ -59,7 +65,7 @@ module Desiru
         {
           model: nil,
           temperature: 0.7,
-          max_tokens:
+          max_tokens: 4096,
           timeout: 30,
           retry_on_failure: true,
           max_retries: 3
@@ -107,6 +113,34 @@ module Desiru
         jitter = rand(0..1.0)
         base_delay + jitter
       end
+
+      # Prepare messages in the expected format
+      def prepare_messages(prompt, additional_messages = nil)
+        messages = []
+
+        # Handle different prompt formats
+        case prompt
+        when String
+          messages << { role: 'user', content: prompt }
+        when Hash
+          messages << { role: 'system', content: prompt[:system] } if prompt[:system]
+          if prompt[:user]
+            messages << { role: 'user', content: prompt[:user] }
+          elsif prompt[:content]
+            messages << { role: 'user', content: prompt[:content] }
+          end
+        end
+
+        # Add any additional messages
+        messages.concat(additional_messages) if additional_messages
+
+        messages
+      end
+
+      # Subclasses must implement this method
+      def perform_completion(messages, options)
+        raise NotImplementedError, 'Subclasses must implement #perform_completion'
+      end
     end
   end
 end
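
The new template-method contract is easiest to see with a toy subclass: Base#complete now handles message preparation, retries, and token accounting, while the subclass supplies only perform_completion. The Echo backend below is purely illustrative and assumes validate_config! accepts an empty configuration.

module Desiru
  module Models
    # Hypothetical backend used only to illustrate the contract.
    class Echo < Base
      protected

      def perform_completion(messages, _options)
        {
          content: messages.last[:content],  # echo the user prompt back
          raw: messages,
          model: 'echo',
          usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
        }
      end
    end
  end
end

Desiru::Models::Echo.new.complete('ping')[:content]  # => "ping"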
data/lib/desiru/models/open_ai.rb
ADDED
@@ -0,0 +1,151 @@
# frozen_string_literal: true

require 'openai'

module Desiru
  module Models
    # OpenAI GPT model adapter
    class OpenAI < Base
      DEFAULT_MODEL = 'gpt-4o-mini'

      def initialize(config = {})
        super
        @api_key = config[:api_key] || ENV.fetch('OPENAI_API_KEY', nil)
        raise ArgumentError, 'OpenAI API key is required' unless @api_key

        @client = ::OpenAI::Client.new(access_token: @api_key)
        @models_cache = nil
        @models_fetched_at = nil
      end

      def models
        # Cache models for 1 hour
        fetch_models if @models_cache.nil? || @models_fetched_at.nil? || (Time.now - @models_fetched_at) > 3600
        @models_cache
      end

      protected

      def perform_completion(messages, options)
        model = options[:model] || @config[:model] || DEFAULT_MODEL
        temperature = options[:temperature] || @config[:temperature] || 0.7
        max_tokens = options[:max_tokens] || @config[:max_tokens] || 4096

        # Prepare request parameters
        params = {
          model: model,
          messages: messages,
          temperature: temperature,
          max_tokens: max_tokens
        }

        # Add response format if specified
        params[:response_format] = options[:response_format] if options[:response_format]

        # Add tools if provided
        if options[:tools]
          params[:tools] = options[:tools]
          params[:tool_choice] = options[:tool_choice] if options[:tool_choice]
        end

        # Make API call
        response = @client.chat(parameters: params)

        # Format response
        format_response(response, model)
      rescue ::Faraday::Error => e
        handle_api_error(e)
      end

      def stream_complete(prompt, **options, &block)
        messages = prepare_messages(prompt, options[:messages])
        model = options[:model] || @config[:model] || DEFAULT_MODEL
        temperature = options[:temperature] || @config[:temperature] || 0.7
        max_tokens = options[:max_tokens] || @config[:max_tokens] || 4096

        # Prepare streaming request
        params = {
          model: model,
          messages: messages,
          temperature: temperature,
          max_tokens: max_tokens,
          stream: proc do |chunk, _bytesize|
            # Extract content from chunk
            if chunk.dig('choices', 0, 'delta', 'content')
              content = chunk.dig('choices', 0, 'delta', 'content')
              block.call(content) if block_given?
            end
          end
        }

        # Make streaming API call
        @client.chat(parameters: params)
      rescue ::Faraday::Error => e
        handle_api_error(e)
      end

      private

      def fetch_models
        response = @client.models.list

        @models_cache = {}
        response['data'].each do |model|
          # Filter for chat models only
          next unless model['id'].include?('gpt') || model['id'].include?('o1')

          @models_cache[model['id']] = {
            name: model['id'],
            created: model['created'],
            owned_by: model['owned_by']
          }
        end

        @models_fetched_at = Time.now
        @models_cache
      rescue StandardError => e
        Desiru.logger.warn("Failed to fetch OpenAI models: #{e.message}")
        # Fallback to commonly used models
        @models_cache = {
          'gpt-4o-mini' => { name: 'GPT-4o Mini' },
          'gpt-4o' => { name: 'GPT-4o' },
          'gpt-4-turbo' => { name: 'GPT-4 Turbo' },
          'gpt-4' => { name: 'GPT-4' },
          'gpt-3.5-turbo' => { name: 'GPT-3.5 Turbo' }
        }
        @models_fetched_at = Time.now
        @models_cache
      end

      def format_response(response, model)
        # Extract content and usage regardless of response structure
        content = response.dig('choices', 0, 'message', 'content') || ''
        usage = response['usage'] || {}

        {
          content: content,
          raw: response,
          model: model,
          usage: {
            prompt_tokens: usage['prompt_tokens'] || 0,
            completion_tokens: usage['completion_tokens'] || 0,
            total_tokens: usage['total_tokens'] || 0
          }
        }
      end

      def handle_api_error(error)
        case error
        when ::Faraday::UnauthorizedError
          raise AuthenticationError, 'Invalid OpenAI API key'
        when ::Faraday::BadRequestError
          raise InvalidRequestError, "Invalid request: #{error.message}"
        when ::Faraday::TooManyRequestsError
          raise RateLimitError, 'OpenAI API rate limit exceeded'
        else
          raise APIError, "OpenAI API error: #{error.message}"
        end
      end
    end
  end
end
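
A closing usage sketch for the OpenAI adapter, exercising the response_format passthrough added above; the prompt and JSON-mode request are illustrative and assume OPENAI_API_KEY is set.

model = Desiru::Models::OpenAI.new(api_key: ENV.fetch('OPENAI_API_KEY'))

response = model.complete(
  'Reply with a JSON object containing the keys "language" and "mascot" for Ruby.',
  model: 'gpt-4o-mini',
  response_format: { type: 'json_object' }  # forwarded to params[:response_format]
)

puts response[:content]               # JSON string returned by the model
puts response[:usage][:total_tokens]  # tracked by Base#complete via increment_stats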