desiru 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +1 -0
- data/.rubocop.yml +55 -0
- data/CLAUDE.md +22 -0
- data/Gemfile +36 -0
- data/Gemfile.lock +255 -0
- data/LICENSE +21 -0
- data/README.md +343 -0
- data/Rakefile +18 -0
- data/desiru.gemspec +44 -0
- data/examples/README.md +55 -0
- data/examples/async_processing.rb +135 -0
- data/examples/few_shot_learning.rb +66 -0
- data/examples/graphql_api.rb +190 -0
- data/examples/graphql_integration.rb +114 -0
- data/examples/rag_retrieval.rb +80 -0
- data/examples/simple_qa.rb +31 -0
- data/examples/typed_signatures.rb +45 -0
- data/lib/desiru/async_capable.rb +170 -0
- data/lib/desiru/cache.rb +116 -0
- data/lib/desiru/configuration.rb +40 -0
- data/lib/desiru/field.rb +171 -0
- data/lib/desiru/graphql/data_loader.rb +210 -0
- data/lib/desiru/graphql/executor.rb +115 -0
- data/lib/desiru/graphql/schema_generator.rb +301 -0
- data/lib/desiru/jobs/async_predict.rb +52 -0
- data/lib/desiru/jobs/base.rb +53 -0
- data/lib/desiru/jobs/batch_processor.rb +71 -0
- data/lib/desiru/jobs/optimizer_job.rb +45 -0
- data/lib/desiru/models/base.rb +112 -0
- data/lib/desiru/models/raix_adapter.rb +210 -0
- data/lib/desiru/module.rb +204 -0
- data/lib/desiru/modules/chain_of_thought.rb +106 -0
- data/lib/desiru/modules/predict.rb +142 -0
- data/lib/desiru/modules/retrieve.rb +199 -0
- data/lib/desiru/optimizers/base.rb +130 -0
- data/lib/desiru/optimizers/bootstrap_few_shot.rb +212 -0
- data/lib/desiru/program.rb +106 -0
- data/lib/desiru/registry.rb +74 -0
- data/lib/desiru/signature.rb +322 -0
- data/lib/desiru/version.rb +5 -0
- data/lib/desiru.rb +67 -0
- metadata +184 -0
# frozen_string_literal: true

require 'graphql'
require_relative 'data_loader'

module Desiru
  module GraphQL
    # Generates GraphQL schemas from Desiru signatures.
    #
    # Usage: register signatures and/or module instances under operation
    # names, then call #generate_schema to obtain an anonymous
    # GraphQL::Schema subclass whose Query type exposes one field per
    # registered signature.
    class SchemaGenerator
      attr_reader :signatures, :modules, :data_loader

      def initialize
        # operation name => Desiru signature
        @signatures = {}
        # operation name => module instance that executes the signature
        @modules = {}
        # memoizes generated GraphQL object/enum types by name
        @type_cache = {}
        @schema_class = nil
        @data_loader = DataLoader.new
      end

      # Register a signature with a name for GraphQL query/mutation
      def register_signature(name, signature)
        @signatures[name] = signature
      end

      # Register a module instance to handle a specific operation
      def register_module(name, module_instance)
        @modules[name] = module_instance
        # Auto-register signature if module has one
        @signatures[name] ||= module_instance.signature if module_instance.respond_to?(:signature)
      end

      # Register multiple modules at once
      def register_modules(modules_hash)
        modules_hash.each { |name, mod| register_module(name, mod) }
      end

      # Generate a GraphQL schema from registered signatures
      #
      # NOTE(review): the guard below only reuses the cached schema when
      # no signatures are registered, so any non-empty generator rebuilds
      # the schema on every call — confirm whether this is intentional
      # (e.g. to pick up late registrations) or a caching bug.
      def generate_schema
        return @schema_class if @schema_class && @signatures.empty?

        query_class = build_query_type

        @schema_class = Class.new(::GraphQL::Schema) do
          query(query_class) if query_class
        end

        @schema_class
      end

      private

      # Builds the anonymous Query object type, or nil when there is
      # nothing to expose.
      def build_query_type
        return nil if @signatures.empty?

        query_fields = build_query_fields

        Class.new(::GraphQL::Schema::Object) do
          graphql_name 'Query'
          description 'Desiru query operations'

          query_fields.each do |field_name, field_def|
            # Create a resolver class for each field
            resolver_class = Class.new(::GraphQL::Schema::Resolver) do
              # Set the return type
              type field_def[:type], null: false

              # Add arguments
              field_def[:arguments].each do |arg_name, arg_def|
                argument arg_name, arg_def[:type], required: arg_def[:required]
              end

              # Define resolve method
              # (closes over field_def; the lambda inside it closes over
              # the generator, so execution is delegated back to it)
              define_method :resolve do |**args|
                field_def[:resolver].call(args)
              end
            end

            # Add field with resolver
            field field_name, resolver: resolver_class, description: field_def[:description]
          end
        end
      end

      # Maps every registered signature to a field definition hash:
      # { type:, description:, arguments:, resolver: }.
      def build_query_fields
        fields = {}

        @signatures.each do |operation_name, signature|
          output_type = build_output_type(signature)

          arguments = {}
          signature.input_fields.each do |field_name, field|
            arguments[camelcase_field_name(field_name)] = {
              type: graphql_type_for_field(field),
              required: !field.optional
            }
          end

          fields[operation_name.to_sym] = {
            type: output_type,
            description: "Generated from signature: #{signature.raw_signature}",
            arguments: arguments,
            # self here is the generator, so the lambda dispatches to
            # execute_signature with the generator's registered modules.
            resolver: ->(args) { execute_signature(operation_name, signature, args) }
          }
        end

        fields
      end

      def build_mutation_type
        # Mutations could be added for signatures that modify state
        nil
      end

      # Builds (and caches) an output object type for a signature.
      # NOTE(review): the cache key uses signature.object_id, so two
      # equivalent signature instances produce two distinct GraphQL
      # types — verify this is acceptable for schema size/printing.
      def build_output_type(signature)
        type_name = "Output#{signature.object_id}"
        return @type_cache[type_name] if @type_cache[type_name]

        output_field_defs = {}
        signature.output_fields.each do |field_name, field|
          output_field_defs[camelcase_field_name(field_name)] = {
            type: graphql_type_for_field(field),
            null: field.optional,
            description: field.description
          }
        end

        output_type = Class.new(::GraphQL::Schema::Object) do
          graphql_name type_name
          description 'Generated output type'

          output_field_defs.each do |field_name, field_def|
            field field_name, field_def[:type],
                  null: field_def[:null],
                  description: field_def[:description]
          end
        end

        @type_cache[type_name] = output_type
      end

      # Translates a Desiru field type symbol to a GraphQL type,
      # wrapping in non-null unless the field is optional. Unknown
      # types fall back to String.
      def graphql_type_for_field(field)
        base_type = case field.type
                    when :string
                      ::GraphQL::Types::String
                    when :int, :integer
                      ::GraphQL::Types::Int
                    when :float
                      ::GraphQL::Types::Float
                    when :bool, :boolean
                      ::GraphQL::Types::Boolean
                    when :list
                      # Handle list types
                      element_type = graphql_type_for_element(field.element_type)
                      [element_type]
                    when :literal
                      # Create enum type for literal values
                      create_enum_type(field)
                    else
                      ::GraphQL::Types::String
                    end

        if field.optional
          base_type
        else
          # Arrays are already wrapped, so handle them differently
          base_type.is_a?(Array) ? [base_type.first, { null: false }] : base_type.to_non_null_type
        end
      end

      # Resolves the GraphQL type for a list's element type, which may
      # be a simple symbol or a hash describing a literal/enum.
      def graphql_type_for_element(element_type)
        case element_type
        when Hash
          # Handle typed arrays like List[Literal['yes', 'no']]
          if element_type[:type] == :literal
            create_enum_type_from_values(element_type[:values])
          else
            ::GraphQL::Types::String
          end
        else
          # Simple types
          case element_type
          when :string then ::GraphQL::Types::String
          when :int, :integer then ::GraphQL::Types::Int
          when :float then ::GraphQL::Types::Float
          when :bool, :boolean then ::GraphQL::Types::Boolean
          else ::GraphQL::Types::String
          end
        end
      end

      # Builds (and caches) an enum type for a :literal field, named
      # after the field (e.g. field :color -> "ColorEnum").
      def create_enum_type(field)
        enum_name = "#{field.name.to_s.capitalize}Enum"
        return @type_cache[enum_name] if @type_cache[enum_name]

        # Extract literal values from the field's validator
        values = extract_literal_values(field)

        enum_type = Class.new(::GraphQL::Schema::Enum) do
          graphql_name enum_name
          description "Enum for #{field.name}"

          values.each do |val|
            # GraphQL enum value names must match /[A-Z0-9_]/
            value val.upcase.gsub(/[^A-Z0-9_]/, '_'), value: val
          end
        end

        @type_cache[enum_name] = enum_type
      end

      # Builds (and caches) an enum type from a raw list of literal
      # values (used for list element literals).
      def create_enum_type_from_values(values)
        enum_name = "Literal#{values.map(&:capitalize).join}Enum"
        return @type_cache[enum_name] if @type_cache[enum_name]

        enum_type = Class.new(::GraphQL::Schema::Enum) do
          graphql_name enum_name

          values.each do |val|
            value val.upcase.gsub(/[^A-Z0-9_]/, '_'), value: val
          end
        end

        @type_cache[enum_name] = enum_type
      end

      # Best-effort extraction of a literal field's allowed values;
      # returns [] when they cannot be discovered.
      # NOTE(review): reaching into the validator's @values ivar couples
      # this to the validator's internals — confirm against Desiru::Field.
      def extract_literal_values(field)
        # Try to extract values from the field's validator
        if field.respond_to?(:validator) && field.validator.respond_to?(:instance_variable_get)
          field.validator.instance_variable_get(:@values) || []
        elsif field.respond_to?(:element_type) && field.element_type.is_a?(Hash)
          field.element_type[:values] || []
        else
          []
        end
      end

      # Executes an operation: prefers a registered module (batched via
      # the DataLoader), otherwise instantiates an inferred module class.
      def execute_signature(operation_name, signature, args)
        # Convert GraphQL arguments from camelCase to snake_case
        inputs = transform_graphql_args(args)

        # Check if we have a registered module for this operation
        if @modules[operation_name]
          # Use DataLoader for batch optimization
          loader = @data_loader.for(@modules[operation_name].class)
          promise = loader.load(inputs)

          # In a real GraphQL implementation, this would be handled by the executor
          # For now, we'll resolve immediately
          result = promise.value

          # Transform module result to GraphQL response format
          transform_module_result(result)
        else
          # Fallback: create a module instance on the fly
          module_class = infer_module_class(signature)
          module_instance = module_class.new(signature)
          result = module_instance.call(inputs)
          transform_module_result(result)
        end
      end

      # camelCase -> snake_case key conversion for incoming arguments.
      def transform_graphql_args(args)
        # Convert camelCase keys to snake_case, but handle single-word keys correctly
        args.transform_keys do |key|
          key_str = key.to_s
          # Only convert if there's actually a capital letter after the first character
          if key_str =~ /[a-z][A-Z]/
            key_str.gsub(/([A-Z])/, '_\1').downcase.to_sym
          else
            key_str.downcase.to_sym
          end
        end
      end

      # snake_case -> camelCase key conversion for outgoing results;
      # passes through results that don't respond to #to_h.
      def transform_module_result(result)
        # Convert ModuleResult to hash with camelCase keys
        if result.respond_to?(:to_h)
          result.to_h.transform_keys { |key| camelcase_field_name(key) }
        else
          result
        end
      end

      # Heuristic: signatures mentioning 'reasoning' get ChainOfThought,
      # everything else gets Predict.
      def infer_module_class(signature)
        # Infer the appropriate module class based on signature characteristics
        if signature.raw_signature.include?('reasoning')
          Desiru::Modules::ChainOfThought
        else
          Desiru::Modules::Predict
        end
      end

      def camelcase_field_name(field_name)
        # Convert snake_case to camelCase for GraphQL conventions
        # Remove trailing '?' for optional fields
        clean_name = field_name.to_s.gsub('?', '')
        parts = clean_name.split('_')
        parts[0] + parts[1..-1].map(&:capitalize).join
      end
    end
  end
end
# frozen_string_literal: true

require_relative 'base'

module Desiru
  module Jobs
    # Sidekiq job that runs a single Desiru module call in the
    # background, recording status transitions and the final result
    # in Redis under the given job_id.
    class AsyncPredict < Base
      sidekiq_options queue: 'critical'

      # @param job_id [String] key used for status/result storage
      # @param module_class_name [String] fully qualified module class
      # @param signature_str [String] signature passed to the module
      # @param inputs [Hash] keyword inputs for the module call
      # @param options [Hash] string-keyed init options
      #   ('model_class', 'model_config', 'config', 'demos')
      def perform(job_id, module_class_name, signature_str, inputs, options = {})
        update_status(job_id, 'running', message: 'Initializing module')

        runner = build_module(module_class_name, signature_str, options)

        update_status(job_id, 'running', progress: 50, message: 'Processing request')
        outcome = runner.call(**inputs)

        update_status(job_id, 'completed', progress: 100, message: 'Request completed successfully')
        store_result(job_id, {
          success: true,
          result: outcome.to_h,
          completed_at: Time.now.iso8601
        })
      rescue StandardError => e
        update_status(job_id, 'failed', message: "Error: #{e.message}")
        store_result(job_id, {
          success: false,
          error: e.message,
          error_class: e.class.name,
          completed_at: Time.now.iso8601
        })
        raise
      end

      private

      # Instantiates the target module class, optionally constructing a
      # model adapter first. Consumes (deletes) the recognized keys from
      # the options hash, matching Sidekiq's JSON string keys.
      def build_module(module_class_name, signature_str, options)
        klass = Object.const_get(module_class_name)

        model_class = options.delete('model_class')
        model_config = options.delete('model_config') || {}
        config = options.delete('config') || {}
        demos = options.delete('demos') || []

        model = (Object.const_get(model_class).new(**model_config) if model_class && model_config)

        klass.new(
          signature_str,
          model: model,
          config: config,
          demos: demos
        )
      end
    end
  end
end
# frozen_string_literal: true

require 'sidekiq'
require 'redis'
require 'json'

module Desiru
  module Jobs
    # Common base for Desiru background jobs: wires up Sidekiq, and
    # provides Redis-backed helpers for persisting job results and
    # status updates keyed by job id.
    class Base
      include Sidekiq::Job

      sidekiq_options retry: 3, dead: true

      # Subclasses must supply the actual work.
      def perform(*)
        raise NotImplementedError, "#{self.class} must implement #perform"
      end

      protected

      # Persists a result payload as JSON with an expiry (default 1h).
      def store_result(job_id, result, ttl: 3600)
        redis.setex(result_key(job_id), ttl, result.to_json)
      end

      # Reads a previously stored result; nil when absent/expired.
      def fetch_result(job_id)
        raw = redis.get(result_key(job_id))
        return nil unless raw

        JSON.parse(raw, symbolize_names: true)
      end

      def result_key(job_id)
        "desiru:results:#{job_id}"
      end

      # Lazily opens one Redis connection per job instance; the URL
      # comes from Desiru configuration, falling back to REDIS_URL.
      def redis
        @redis ||= Redis.new(url: Desiru.configuration.redis_url || ENV.fetch('REDIS_URL', nil))
      end

      # Writes a status snapshot (JSON, 24h expiry). progress/message
      # are only included when provided.
      def update_status(job_id, status, progress: nil, message: nil)
        snapshot = { status: status, updated_at: Time.now.iso8601 }
        snapshot[:progress] = progress if progress
        snapshot[:message] = message if message

        redis.setex(status_key(job_id), 86_400, snapshot.to_json)
      end

      def status_key(job_id)
        "desiru:status:#{job_id}"
      end
    end
  end
end
# frozen_string_literal: true

require_relative 'base'

module Desiru
  module Jobs
    # Sidekiq job that runs one Desiru module over an array of inputs,
    # tracking per-item success/failure and overall progress in Redis.
    class BatchProcessor < Base
      sidekiq_options queue: 'default'

      # @param batch_id [String] key used for status/result storage
      # @param module_class_name [String] fully qualified module class
      # @param signature_str [String] signature passed to the module
      # @param inputs_array [Array<Hash>] one keyword-input hash per item
      # @param options [Hash] string-keyed init options
      #   ('model_class', 'model_config', 'config', 'demos')
      def perform(batch_id, module_class_name, signature_str, inputs_array, options = {})
        total_items = inputs_array.size
        update_status(batch_id, 'running', progress: 0, message: "Processing #{total_items} items")

        module_class = Object.const_get(module_class_name)

        # Extract module initialization parameters (Sidekiq round-trips
        # args through JSON, hence string keys)
        model_class = options.delete('model_class')
        model_config = options.delete('model_config') || {}
        config = options.delete('config') || {}
        demos = options.delete('demos') || []

        # Initialize model if provided
        model = (Object.const_get(model_class).new(**model_config) if model_class && model_config)

        module_instance = module_class.new(
          signature_str,
          model: model,
          config: config,
          demos: demos
        )

        results = []
        errors = []

        inputs_array.each_with_index do |inputs, index|
          progress = ((index + 1).to_f / total_items * 100).round
          update_status(batch_id, 'running', progress: progress,
                                             message: "Processing item #{index + 1} of #{total_items}")

          result = module_instance.call(**inputs)
          results << {
            index: index,
            success: true,
            result: result.to_h
          }
        rescue StandardError => e
          # Per-item failures are recorded and the batch continues.
          errors << {
            index: index,
            success: false,
            error: e.message,
            error_class: e.class.name
          }
        end

        final_status = errors.empty? ? 'completed' : 'completed_with_errors'
        update_status(batch_id, final_status, progress: 100,
                                              message: "Processed #{results.size} successfully, #{errors.size} failed")

        store_result(batch_id, {
          success: errors.empty?,
          total: inputs_array.size,
          successful: results.size,
          failed: errors.size,
          results: results,
          errors: errors,
          completed_at: Time.now.iso8601
        }, ttl: 7200) # 2 hours TTL for batch results
      rescue StandardError => e
        # Bug fix: setup failures (bad class name, model/module init)
        # previously escaped without any status/result update, leaving
        # the batch stuck at 'running'. Record the failure the same way
        # AsyncPredict and OptimizerJob do, then re-raise for Sidekiq's
        # retry machinery.
        update_status(batch_id, 'failed', message: "Error: #{e.message}")
        store_result(batch_id, {
          success: false,
          error: e.message,
          error_class: e.class.name,
          completed_at: Time.now.iso8601
        })
        raise
      end
    end
  end
end
# frozen_string_literal: true

require_relative 'base'

module Desiru
  module Jobs
    # Sidekiq job that compiles (optimizes) a Desiru program with a
    # given optimizer and training set, reporting progress and storing
    # the optimized configuration in Redis.
    class OptimizerJob < Base
      sidekiq_options queue: 'low', retry: 1

      # @param job_id [String] key used for status/result storage
      # @param optimizer_class_name [String] fully qualified optimizer class
      # @param program_class_name [String] fully qualified program class
      # @param trainset [Array] training examples passed to #compile
      # @param optimizer_options [Hash] keyword options for the optimizer
      def perform(job_id, optimizer_class_name, program_class_name, trainset, optimizer_options = {})
        tuner = Object.const_get(optimizer_class_name).new(**optimizer_options)
        target = Object.const_get(program_class_name).new

        # Store initial status
        update_status(job_id, 'running', progress: 0, message: 'Starting optimization')

        # Compile the program; the block receives progress callbacks
        compiled = tuner.compile(target, trainset: trainset) do |progress|
          update_status(job_id, 'running', progress: progress, message: "Optimizing... #{progress}% complete")
        end

        # Persist the optimized program configuration for 24 hours
        store_result(
          job_id,
          {
            success: true,
            optimized_config: compiled.to_config,
            metrics: tuner.final_metrics,
            completed_at: Time.now.iso8601
          },
          ttl: 86_400
        )

        update_status(job_id, 'completed', progress: 100, message: 'Optimization completed successfully')
      rescue StandardError => e
        store_result(
          job_id,
          {
            success: false,
            error: e.message,
            error_class: e.class.name,
            completed_at: Time.now.iso8601
          }
        )
        update_status(job_id, 'failed', message: "Optimization failed: #{e.message}")
        raise
      end
    end
  end
end
# frozen_string_literal: true

module Desiru
  module Models
    # Base adapter class for language model integrations.
    # Defines the interface all model adapters must implement
    # (#complete, #models, optionally #stream_complete) and provides
    # shared config handling, usage counters, and retry helpers.
    class Base
      attr_reader :config, :client

      # @param config [Hash] symbol-keyed overrides merged over
      #   #default_config (string keys would NOT override the symbol
      #   defaults — callers must use symbols)
      def initialize(config = {})
        @config = default_config.merge(config)
        @client = build_client
        @request_count = 0
        @token_count = 0

        validate_config!
      end

      # Main interface method - must be implemented by subclasses
      def complete(prompt, **options)
        raise NotImplementedError, 'Subclasses must implement #complete'
      end

      # Stream completion - optional implementation
      def stream_complete(prompt, **options, &)
        raise NotImplementedError, "Streaming not supported by #{self.class.name}"
      end

      # Get available models
      def models
        raise NotImplementedError, 'Subclasses must implement #models'
      end

      # Health check: true when #models succeeds, false otherwise.
      # Bug fix: NotImplementedError descends from ScriptError, not
      # StandardError, so a bare `rescue StandardError` let the base
      # implementation raise out of this boolean predicate instead of
      # returning false. Rescue it explicitly.
      def healthy?
        models
        true
      rescue StandardError, NotImplementedError
        false
      end

      # Usage statistics accumulated by #increment_stats
      def stats
        {
          request_count: @request_count,
          token_count: @token_count,
          model: config[:model]
        }
      end

      def reset_stats
        @request_count = 0
        @token_count = 0
      end

      protected

      # Defaults merged under the user-supplied config in #initialize.
      def default_config
        {
          model: nil,
          temperature: 0.7,
          max_tokens: 1000,
          timeout: 30,
          retry_on_failure: true,
          max_retries: 3
        }
      end

      def build_client
        # Override in subclasses to build the actual client
        nil
      end

      def validate_config!
        # Override in subclasses for specific validation
      end

      # Record one request and the tokens it consumed.
      def increment_stats(tokens_used = 0)
        @request_count += 1
        @token_count += tokens_used
      end

      # Runs the block, retrying retryable errors up to max_attempts
      # (defaults to config[:max_retries]) with exponential backoff.
      # Non-retryable errors, or exhausted attempts, re-raise.
      def with_retry(max_attempts = nil)
        max_attempts ||= config[:max_retries]
        attempts = 0

        begin
          attempts += 1
          yield
        rescue StandardError => e
          raise unless attempts < max_attempts && retryable_error?(e)

          sleep(retry_delay(attempts))
          retry
        end
      end

      # Message-substring heuristic; override in subclasses with the
      # provider's specific exception classes.
      def retryable_error?(error)
        error.message.include?('timeout') || error.message.include?('rate limit')
      end

      # Exponential backoff with jitter: 2**attempt + rand(0..1.0) seconds.
      def retry_delay(attempt)
        base_delay = 2**attempt
        jitter = rand(0..1.0)
        base_delay + jitter
      end
    end
  end
end