zuno 0.1.4 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +158 -0
- data/lib/zuno/version.rb +1 -1
- data/lib/zuno.rb +1415 -10
- metadata +30 -29
- data/lib/providers/openai.rb +0 -58
- data/lib/zuno/chat.rb +0 -39
- data/lib/zuno/configuration.rb +0 -11
data/lib/zuno.rb
CHANGED
|
@@ -1,19 +1,1424 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
+
require "json"
|
|
4
|
+
require "securerandom"
|
|
5
|
+
require "cgi"
|
|
6
|
+
require "typhoeus"
|
|
7
|
+
|
|
3
8
|
require_relative "zuno/version"
|
|
4
|
-
require_relative "zuno/configuration"
|
|
5
|
-
require "zuno/chat"
|
|
6
|
-
require "zuno/transcription"
|
|
7
|
-
require "zuno/translation"
|
|
8
|
-
require "faraday"
|
|
9
9
|
|
|
10
10
|
module Zuno
|
|
11
|
-
class
|
|
12
|
-
|
|
11
|
+
# Root of the library's exception hierarchy; every class below derives from it.
class Error < StandardError
end

# Raised for provider-level failures, e.g. an unsupported provider name.
class ProviderError < Error
end

# Raised for invalid tool definitions or a tool without an execute block.
class ToolError < Error
end

# Raised when the tool loop uses up its iteration budget without a final answer.
class MaxIterationsExceeded < Error
end

# Raised for failures while streaming a response.
class StreamingError < Error
end
|
|
16
|
+
# Stop flag handed along with callback invocations. Once #stopped? is true the
# generation methods return early with a "stopped" result.
class CallbackControl
  # Reason given to the most recent #stop! call; nil until then.
  attr_reader :stop_reason

  def initialize
    @stop_reason = nil
    @stopped = false
  end

  # Requests a stop and records an optional reason.
  #
  # @param reason [Object, nil] free-form explanation
  # @return [Object, nil] the reason that was stored
  def stop!(reason: nil)
    @stopped = true
    @stop_reason = reason
  end

  # @return [Boolean] whether #stop! has been called
  def stopped?
    @stopped
  end
end
|
|
33
|
+
|
|
34
|
+
# Value object naming a model: its id, the provider that serves it and any
# provider options attached to the model itself.
ModelDescriptor = Struct.new(:id, :provider, :provider_options, keyword_init: true) do
  # Coerces the id to a String and the provider to a Symbol; anything other
  # than a Hash given as provider_options is replaced by an empty Hash.
  def initialize(id:, provider:, provider_options: {})
    options = provider_options.is_a?(Hash) ? provider_options : {}
    super(id: id.to_s, provider: provider.to_sym, provider_options: options)
  end
end
|
|
43
|
+
|
|
44
|
+
# A tool that can be offered to the model: its name, a description, a Hash
# describing its inputs, and the callable that performs the work.
ToolDefinition = Struct.new(:name, :description, :input_schema, :execute_proc, keyword_init: true) do
  # Coerces name and description to Strings; an input_schema that is not a
  # Hash is replaced by an empty Hash.
  def initialize(name:, description:, input_schema:, execute_proc:)
    super(
      name: name.to_s,
      description: description.to_s,
      input_schema: input_schema.is_a?(Hash) ? input_schema : {},
      execute_proc: execute_proc
    )
  end

  # @return [Hash] the "function" tool entry placed in the request payload
  def as_provider_tool
    {
      type: "function",
      function: {
        name: name,
        description: description,
        parameters: input_schema
      }
    }
  end

  # Invokes the tool exactly once.
  #
  # The calling convention is chosen up front from the callable's parameter
  # list instead of "call positionally, retry with keywords on ArgumentError":
  # the retry ran a tool twice whenever its own body raised ArgumentError or
  # TypeError, and a proc declaring only optional keywords silently lost its
  # arguments.
  #
  # @param arguments [Hash, nil] tool input; keys are symbolized when the
  #   callable is invoked with keywords
  # @return [Object] whatever the callable returns
  # @raise [ToolError] when execute_proc does not respond to #call
  def execute(arguments)
    raise ToolError, "Tool '#{name}' is missing an execute block" unless execute_proc.respond_to?(:call)

    signature =
      if execute_proc.respond_to?(:parameters)
        execute_proc.parameters
      else
        execute_proc.method(:call).parameters
      end

    # Any positional slot means the callable wants the arguments Hash as-is.
    takes_positional = signature.any? { |kind, _name| %i[req opt rest].include?(kind) }
    return execute_proc.call(arguments) if takes_positional

    execute_proc.call(**arguments.to_h.transform_keys(&:to_sym))
  end
end
|
|
77
|
+
|
|
78
|
+
# Provider-option keys that configure the adapter objects rather than the request.
OPENROUTER_ADAPTER_CONFIG_KEYS = [:api_key, :app_url, :title, :timeout].freeze
REPLICATE_ADAPTER_CONFIG_KEYS = [:api_key, :timeout].freeze

# Default iteration budget for the tool loop.
DEFAULT_MAX_ITERATIONS = 1

# NOTE(review): not referenced in the visible part of the file - confirm usage.
REPLICATE_PREFER_WAIT_SECONDS = 60

# Pause between polls, and total time allowed, while waiting on a Replicate prediction.
REPLICATE_POLL_INTERVAL_SECONDS = 1
REPLICATE_WAIT_TIMEOUT_SECONDS = 600

# Prediction statuses after which polling stops.
REPLICATE_TERMINAL_STATUSES = ["succeeded", "failed", "canceled", "aborted"].freeze
|
|
85
|
+
|
|
86
|
+
module_function
|
|
87
|
+
|
|
88
|
+
# Provider options applied to every call before model- and call-level options.
# Lazily initialised to an empty Hash on first access.
#
# @return [Hash]
def default_provider_options
  return @default_provider_options if @default_provider_options

  @default_provider_options = {}
end
|
|
91
|
+
|
|
92
|
+
# Replaces the default provider options. Anything that is not a Hash resets
# the defaults to an empty Hash.
#
# @param options [Hash, Object]
def default_provider_options=(options)
  @default_provider_options =
    if options.is_a?(Hash)
      options
    else
      {}
    end
end
|
|
95
|
+
|
|
96
|
+
# Builds a ModelDescriptor for the given model id.
#
# @param id [#to_s] model identifier
# @param provider [#to_sym] provider name; :openrouter unless stated otherwise
# @param provider_options [Hash] options attached to the model
# @return [ModelDescriptor]
def model(id, provider: :openrouter, provider_options: {})
  ModelDescriptor.new(
    provider_options: provider_options,
    provider: provider,
    id: id
  )
end
|
|
99
|
+
|
|
100
|
+
# Constructs an OpenRouter adapter, forwarding every keyword unchanged.
#
# @return [Providers::OpenRouter]
def openrouter(api_key: nil, app_url: nil, title: nil, timeout: Providers::OpenRouter::DEFAULT_TIMEOUT)
  settings = { api_key: api_key, app_url: app_url, title: title, timeout: timeout }
  Providers::OpenRouter.new(**settings)
end
|
|
108
|
+
|
|
109
|
+
# Constructs a Replicate adapter, forwarding both keywords unchanged.
#
# @return [Providers::Replicate]
def replicate(api_key: nil, timeout: Providers::Replicate::DEFAULT_TIMEOUT)
  settings = { api_key: api_key, timeout: timeout }
  Providers::Replicate.new(**settings)
end
|
|
115
|
+
|
|
116
|
+
# Declares a tool whose behaviour is the given block.
#
# @param name [#to_s] tool name
# @param description [#to_s] description passed to the tool definition
# @param input_schema [Hash] description of the tool's input
# @yield the block stored as the tool's execute callable
# @return [ToolDefinition]
# @raise [ToolError] when no block is supplied
def tool(name:, description:, input_schema:, &execute)
  raise ToolError, "A block is required for tool '#{name}'" if execute.nil?

  definition = {
    name: name,
    description: description,
    input_schema: input_schema,
    execute_proc: execute
  }
  ToolDefinition.new(**definition)
end
|
|
126
|
+
|
|
127
|
+
# Performs a single generation with the model's provider.
#
# For :openrouter the chat-style arguments (messages/system/prompt, tools,
# tool_choice, temperature, max_tokens) are used; for :replicate only `input`
# is used and tools/tool_choice/webhook options are rejected.
#
# before_generation runs first and may stop the call through the control
# object; after_generation is invoked exactly once, with either
# `{ ok: true, result: }` or `{ ok: false, error: }`.
#
# @param model [ModelDescriptor, String, Hash] model to use
# @param provider_options [Hash] call-level options merged over model and default options
# @return [Object] the provider result, or the "stopped" result when a callback stops the call
# @raise [Error] library errors are re-raised unchanged, so subclasses such as
#   ToolError stay distinguishable; any other StandardError is wrapped in Error
def generate(
  model:,
  messages: nil,
  system: nil,
  prompt: nil,
  input: nil,
  tools: {},
  tool_choice: nil,
  temperature: nil,
  max_tokens: nil,
  provider_options: {},
  before_tool_execution: nil,
  after_tool_execution: nil,
  before_generation: nil,
  after_generation: nil
)
  callback_control = CallbackControl.new
  after_generation_called = false

  model_descriptor = normalize_model(model)
  resolved_provider_options = merge_provider_options(
    model_descriptor.provider_options,
    provider_options
  )
  provider = model_descriptor.provider.to_sym

  call_callback!(
    before_generation,
    {
      model: model_descriptor,
      mode: "single",
      provider: provider
    },
    callback_control
  )

  result =
    if callback_control.stopped?
      callback_stopped_result(
        control: callback_control,
        iterations: [],
        message: {},
        usage: nil,
        raw_response: nil
      )
    else
      case provider
      when :openrouter
        adapter = provider_adapter(provider, resolved_provider_options)
        generate_openrouter_single(
          model_descriptor: model_descriptor,
          adapter: adapter,
          messages: messages,
          system: system,
          prompt: prompt,
          tools: tools,
          tool_choice: tool_choice,
          temperature: temperature,
          max_tokens: max_tokens,
          provider_options: resolved_provider_options,
          before_tool_execution: before_tool_execution,
          after_tool_execution: after_tool_execution
        )
      when :replicate
        raise Error, "tools are not supported for replicate generate" unless normalize_tools(tools).empty?
        raise Error, "tool_choice is not supported for replicate generate" unless tool_choice.nil?

        validate_no_webhook_support!(resolved_provider_options)
        adapter = provider_adapter(provider, resolved_provider_options)
        generate_replicate_single(
          model_descriptor: model_descriptor,
          adapter: adapter,
          input: input,
          provider_options: resolved_provider_options
        )
      else
        raise ProviderError, "Unsupported provider: #{provider}"
      end
    end

  # Flag is set before the callback so a raising callback is not invoked twice.
  after_generation_called = true
  call_callback!(after_generation, { ok: true, result: result }, callback_control)
  result
rescue Error => e
  # Library errors keep their class and backtrace.
  unless after_generation_called
    after_generation_called = true
    call_callback!(after_generation, { ok: false, error: e }, callback_control)
  end
  raise
rescue StandardError => e
  unless after_generation_called
    after_generation_called = true
    call_callback!(after_generation, { ok: false, error: e }, callback_control)
  end
  raise Error, e.message
end
|
|
227
|
+
|
|
228
|
+
# Runs a tool-calling conversation against OpenRouter: the model is queried,
# any tool calls it makes are executed and fed back as "tool" messages, and
# the cycle repeats until the model answers without tool calls, a stop_when
# condition matches an executed tool, a callback stops the run, or the
# iteration budget is spent.
#
# after_generation is invoked exactly once, with `{ ok: true, result: }` or
# `{ ok: false, error: }`.
#
# @param model [ModelDescriptor, String, Hash] must resolve to the :openrouter provider
# @param max_iterations [Object] normalized by normalize_max_iterations; the
#   normalized value :infinite removes the bound
# @return [Object] the final result Hash, or the "stopped" result when a callback stops the run
# @raise [MaxIterationsExceeded] when the budget is spent without a final response
# @raise [Error] library errors are re-raised unchanged (subclasses such as
#   ToolError are preserved); any other StandardError is wrapped in Error
def loop(
  model:,
  messages: nil,
  system: nil,
  prompt: nil,
  tools: {},
  tool_choice: nil,
  stop_when: nil,
  max_iterations: DEFAULT_MAX_ITERATIONS,
  temperature: nil,
  max_tokens: nil,
  provider_options: {},
  before_tool_execution: nil,
  after_tool_execution: nil,
  before_iteration: nil,
  after_iteration: nil,
  before_generation: nil,
  after_generation: nil
)
  # Created before anything can raise so the rescue clauses never pass nil on.
  callback_control = CallbackControl.new
  after_generation_called = false
  iterations = []

  # Reports success to after_generation (once) and hands the result back.
  finish = lambda do |result|
    after_generation_called = true
    call_callback!(after_generation, { ok: true, result: result }, callback_control)
    result
  end

  # Builds the result returned when a callback asked to stop.
  stopped_result = lambda do |message, response|
    callback_stopped_result(
      control: callback_control,
      iterations: iterations,
      message: message,
      usage: response ? response["usage"] : nil,
      raw_response: response
    )
  end

  model_descriptor = normalize_model(model)
  raise Error, "loop only supports openrouter provider" unless model_descriptor.provider.to_sym == :openrouter

  resolved_provider_options = merge_provider_options(
    model_descriptor.provider_options,
    provider_options
  )
  adapter = provider_adapter(model_descriptor.provider, resolved_provider_options)
  tool_map = normalize_tools(tools)
  llm_messages = normalize_messages(messages: messages, system: system, prompt: prompt)
  resolved_tool_choice = normalize_tool_choice(
    explicit_tool_choice: tool_choice,
    provider_options: resolved_provider_options,
    tools: tool_map
  )
  resolved_stop_when = normalize_stop_when(stop_when)
  resolved_max_iterations = normalize_max_iterations(max_iterations)

  call_callback!(
    before_generation,
    {
      model: model_descriptor,
      messages: llm_messages,
      tool_names: tool_map.keys,
      tool_choice: resolved_tool_choice,
      max_iterations: resolved_max_iterations,
      stop_when: resolved_stop_when
    },
    callback_control
  )
  return finish.call(stopped_result.call({}, nil)) if callback_control.stopped?

  iteration_count = 0
  infinite_iterations = resolved_max_iterations == :infinite

  while infinite_iterations || iteration_count < resolved_max_iterations
    current_iteration = iteration_count + 1
    call_callback!(
      before_iteration,
      {
        iteration_index: current_iteration,
        messages: llm_messages
      },
      callback_control
    )
    return finish.call(stopped_result.call({}, nil)) if callback_control.stopped?

    payload = build_payload(
      model_id: model_descriptor.id,
      messages: llm_messages,
      tools: tool_map,
      tool_choice: resolved_tool_choice,
      temperature: temperature,
      max_tokens: max_tokens,
      provider_options: resolved_provider_options
    )

    response = adapter.chat(payload)
    message = response.dig("choices", 0, "message") || {}
    tool_calls = Array(message["tool_calls"])

    iteration_record = {
      index: current_iteration,
      message: message,
      tool_calls: tool_calls,
      usage: response["usage"],
      finish_reason: response.dig("choices", 0, "finish_reason"),
      tool_results: []
    }

    # No tool calls: this is the model's final answer.
    if tool_calls.empty?
      iterations << iteration_record
      call_callback!(
        after_iteration,
        {
          iteration_index: current_iteration,
          iteration: iteration_record
        },
        callback_control
      )
      return finish.call(stopped_result.call(message, response)) if callback_control.stopped?

      return finish.call(
        {
          text: extract_message_text(message),
          message: message,
          usage: response["usage"],
          finish_reason: response.dig("choices", 0, "finish_reason"),
          iterations: iterations,
          raw_response: response
        }
      )
    end

    llm_messages << build_assistant_tool_call_message(message: message, tool_calls: tool_calls)
    stop_triggered = false
    stop_triggered_tool_name = nil

    tool_calls.each do |tool_call|
      tool_call_id = normalize_tool_call_id(tool_call["id"])
      arguments = parse_arguments(tool_call.dig("function", "arguments"))
      tool_name = tool_call.dig("function", "name").to_s

      call_callback!(
        before_tool_execution,
        {
          iteration_index: current_iteration,
          tool_call_id: tool_call_id,
          tool_name: tool_name,
          input: arguments,
          raw_tool_call: tool_call
        },
        callback_control
      )

      tool_result = execute_tool_call(
        tool_call: tool_call,
        tools: tool_map,
        tool_call_id: tool_call_id,
        arguments: arguments
      )

      iteration_record[:tool_results] << tool_result
      call_callback!(
        after_tool_execution,
        tool_result.merge(iteration_index: current_iteration),
        callback_control
      )
      if tool_stop_condition_met?(resolved_stop_when, tool_result)
        stop_triggered = true
        # Only the first matching tool is reported as the stop reason.
        stop_triggered_tool_name ||= tool_result[:tool_name]
      end

      llm_messages << {
        "role" => "tool",
        "tool_call_id" => tool_result[:tool_call_id],
        "content" => serialize_tool_content(tool_result[:output])
      }
    end

    iterations << iteration_record
    call_callback!(
      after_iteration,
      {
        iteration_index: current_iteration,
        iteration: iteration_record
      },
      callback_control
    )
    return finish.call(stopped_result.call(message, response)) if callback_control.stopped?

    if stop_triggered
      return finish.call(
        {
          text: extract_message_text(message),
          message: message,
          usage: response["usage"],
          finish_reason: "stop_when_tool_called",
          stop_reason: {
            type: "tool_called",
            tool_name: stop_triggered_tool_name
          },
          iterations: iterations,
          raw_response: response
        }
      )
    end

    iteration_count += 1
  end

  raise MaxIterationsExceeded,
        "Reached max_iterations=#{resolved_max_iterations} without a final assistant response" unless infinite_iterations
rescue Error => e
  # Library errors (ProviderError, ToolError, MaxIterationsExceeded, ...) keep
  # their class and backtrace.
  unless after_generation_called
    after_generation_called = true
    call_callback!(after_generation, { ok: false, error: e }, callback_control)
  end
  raise
rescue StandardError => e
  unless after_generation_called
    after_generation_called = true
    call_callback!(after_generation, { ok: false, error: e }, callback_control)
  end
  raise Error, e.message
end
|
|
486
|
+
|
|
487
|
+
# Streams a chat completion from OpenRouter, reporting progress to the block
# as event Hashes with a :type of :start, :usage, :text_delta,
# :tool_call_delta, :finish or :error.
#
# Every failure leaves this method as a StreamingError. That includes the
# missing-block ArgumentError and the unsupported-provider ProviderError
# raised below, because both are raised inside the rescued body.
#
# @param model [ModelDescriptor, String, Hash] must resolve to the :openrouter provider
# @yieldparam event [Hash] one streaming event
# @return [true] once the stream has been consumed
# @raise [StreamingError]
def stream(
  model:,
  messages: nil,
  system: nil,
  prompt: nil,
  temperature: nil,
  max_tokens: nil,
  provider_options: {},
  &block
)
  raise ArgumentError, "stream requires a block callback" unless block_given?

  model_descriptor = normalize_model(model)
  unless model_descriptor.provider.to_sym == :openrouter
    raise ProviderError, "stream only supports openrouter provider"
  end

  resolved_provider_options = merge_provider_options(model_descriptor.provider_options, provider_options)
  adapter = provider_adapter(model_descriptor.provider, resolved_provider_options)
  llm_messages = normalize_messages(messages: messages, system: system, prompt: prompt)

  base_payload = build_payload(
    model_id: model_descriptor.id,
    messages: llm_messages,
    tools: {},
    tool_choice: nil,
    temperature: temperature,
    max_tokens: max_tokens,
    provider_options: resolved_provider_options
  )
  payload = base_payload.merge("stream" => true)

  block.call(type: :start, model: model_descriptor.id, provider: model_descriptor.provider)

  emitted_finish = false

  adapter.stream(payload) do |raw_data|
    next if raw_data == "[DONE]"

    parsed = parse_json(raw_data)
    raise StreamingError, "Malformed SSE payload: #{raw_data}" if parsed.nil?

    usage = parsed["usage"]
    block.call(type: :usage, usage: usage, raw: parsed) if usage.is_a?(Hash)

    choice = Array(parsed["choices"]).first || {}
    delta = choice["delta"] || {}

    text_delta = extract_text_delta(delta)
    unless text_delta.nil? || text_delta == false || text_delta.empty?
      block.call(type: :text_delta, text: text_delta, raw: parsed)
    end

    Array(delta["tool_calls"]).each do |fragment|
      block.call(
        type: :tool_call_delta,
        tool_call: {
          index: fragment["index"],
          id: fragment["id"],
          type: fragment["type"],
          name: fragment.dig("function", "name"),
          arguments_delta: fragment.dig("function", "arguments")
        },
        raw: parsed
      )
    end

    finish_reason = choice["finish_reason"]
    next if finish_reason.nil? || finish_reason.to_s.empty?

    emitted_finish = true
    block.call(type: :finish, finish_reason: finish_reason, raw: parsed)
  end

  # Synthesize a finish event when the stream never reported one.
  block.call(type: :finish, finish_reason: "stop") unless emitted_finish
  true
rescue StreamingError => e
  block.call(type: :error, error: e.message) if block_given?
  raise
rescue StandardError => e
  # ProviderError and everything else are reported and re-raised as StreamingError.
  wrapped_error = StreamingError.new(e.message)
  block.call(type: :error, error: wrapped_error.message) if block_given?
  raise wrapped_error
end
|
|
570
|
+
|
|
571
|
+
# Turns the accepted model notations into a ModelDescriptor.
#
# A descriptor is returned as-is, a String is treated as an OpenRouter model
# id, and a Hash is read through its :id / :provider / :provider_options keys
# (symbol or string), defaulting the provider to :openrouter.
#
# @raise [Error] for any other kind of value
def normalize_model(input)
  case input
  when ModelDescriptor
    input
  when String
    model(input, provider: :openrouter)
  when Hash
    model(
      input[:id] || input["id"],
      provider: input[:provider] || input["provider"] || :openrouter,
      provider_options: input[:provider_options] || input["provider_options"] || {}
    )
  else
    raise Error, "Unsupported model value: #{input.inspect}"
  end
end
|
|
585
|
+
private_class_method :normalize_model
|
|
586
|
+
|
|
587
|
+
# One OpenRouter chat request. Tool calls in the reply are executed once, but
# only when tools were supplied, and their results are attached to the result
# rather than sent back to the model.
#
# @return [Hash] :text, :message, :usage, :finish_reason, :tool_calls,
#   :raw_response, plus :tool_results when any tool ran
def generate_openrouter_single(
  model_descriptor:,
  adapter:,
  messages:,
  system:,
  prompt:,
  tools:,
  tool_choice:,
  temperature:,
  max_tokens:,
  provider_options:,
  before_tool_execution:,
  after_tool_execution:
)
  tool_map = normalize_tools(tools)
  llm_messages = normalize_messages(messages: messages, system: system, prompt: prompt)
  resolved_tool_choice = normalize_tool_choice(
    explicit_tool_choice: tool_choice,
    provider_options: provider_options,
    tools: tool_map
  )

  response = adapter.chat(
    build_payload(
      model_id: model_descriptor.id,
      messages: llm_messages,
      tools: tool_map,
      tool_choice: resolved_tool_choice,
      temperature: temperature,
      max_tokens: max_tokens,
      provider_options: provider_options
    )
  )
  message = response.dig("choices", 0, "message") || {}
  tool_calls = Array(message["tool_calls"])

  tool_results =
    if tool_calls.empty? || tool_map.empty?
      []
    else
      tool_calls.map do |tool_call|
        tool_call_id = normalize_tool_call_id(tool_call["id"])
        arguments = parse_arguments(tool_call.dig("function", "arguments"))
        tool_name = tool_call.dig("function", "name").to_s

        # A single generation is reported to callbacks as iteration 1.
        call_callback!(
          before_tool_execution,
          {
            iteration_index: 1,
            tool_call_id: tool_call_id,
            tool_name: tool_name,
            input: arguments,
            raw_tool_call: tool_call
          }
        )

        outcome = execute_tool_call(
          tool_call: tool_call,
          tools: tool_map,
          tool_call_id: tool_call_id,
          arguments: arguments
        )

        call_callback!(
          after_tool_execution,
          outcome.merge(iteration_index: 1)
        )
        outcome
      end
    end

  result = {
    text: extract_message_text(message),
    message: message,
    usage: response["usage"],
    finish_reason: response.dig("choices", 0, "finish_reason"),
    tool_calls: tool_calls,
    raw_response: response
  }
  result[:tool_results] = tool_results unless tool_results.empty?
  result
end
|
|
667
|
+
private_class_method :generate_openrouter_single
|
|
668
|
+
|
|
669
|
+
# Creates a Replicate prediction and polls it until it reaches a terminal
# status or the wait timeout elapses.
#
# @param input [Hash] prediction input
# @return [Hash] selected prediction fields plus the full prediction under :raw_response
# @raise [Error] when input is not a Hash
# @raise [ProviderError] when the prediction does not finish in time
def generate_replicate_single(model_descriptor:, adapter:, input:, provider_options:)
  raise Error, "generate with replicate requires input: Hash" unless input.is_a?(Hash)

  reference = normalize_replicate_reference(
    model_descriptor: model_descriptor,
    provider_options: provider_options
  )
  prediction = adapter.create_prediction(reference: reference, input: input)

  # Monotonic clock so wall-clock adjustments cannot affect the timeout.
  give_up_at = Process.clock_gettime(Process::CLOCK_MONOTONIC) + REPLICATE_WAIT_TIMEOUT_SECONDS

  until replicate_terminal_status?(prediction["status"])
    timed_out = Process.clock_gettime(Process::CLOCK_MONOTONIC) >= give_up_at
    raise ProviderError, "Replicate prediction did not finish within #{REPLICATE_WAIT_TIMEOUT_SECONDS} seconds" if timed_out

    sleep(REPLICATE_POLL_INTERVAL_SECONDS)
    prediction = adapter.get_prediction(prediction: prediction)
  end

  summary = %w[id status output error logs metrics urls].each_with_object({}) do |field, acc|
    acc[field.to_sym] = prediction[field]
  end
  summary[:raw_response] = prediction
  summary
end
|
|
704
|
+
private_class_method :generate_replicate_single
|
|
705
|
+
|
|
706
|
+
# Works out what the model id refers to on Replicate.
#
# The target comes from the :replicate_target provider option (symbol or
# string key) and defaults to :model. :model and :deployment ids must look
# like "owner/name"; :version ids are not format-checked.
#
# @return [Hash] `{ type: Symbol, id: String }`
# @raise [Error] for an empty id, a malformed id or an unknown target
def normalize_replicate_reference(model_descriptor:, provider_options:)
  requested = provider_options[:replicate_target] || provider_options["replicate_target"] || :model
  target = requested.to_sym
  model_id = model_descriptor.id.to_s.strip
  raise Error, "Replicate model id is required" if model_id.empty?

  case target
  when :model, :deployment
    owner, name, extra = model_id.split("/", 3)
    well_formed = !owner.to_s.empty? && !name.to_s.empty? && extra.nil?
    raise Error, "Replicate #{target} id must be in 'owner/name' format" unless well_formed
  when :version
    # Version ids carry no owner/name structure to validate.
  else
    raise Error, "Unsupported replicate_target: #{target}"
  end

  { type: target, id: model_id }
end
|
|
726
|
+
private_class_method :normalize_replicate_reference
|
|
727
|
+
|
|
728
|
+
# @param status [#to_s] prediction status as reported by Replicate
# @return [Boolean] whether the status is one of REPLICATE_TERMINAL_STATUSES
def replicate_terminal_status?(status)
  label = status.to_s
  REPLICATE_TERMINAL_STATUSES.member?(label)
end
|
|
731
|
+
private_class_method :replicate_terminal_status?
|
|
732
|
+
|
|
733
|
+
# Rejects provider options that mention webhooks. Only the presence of the
# keys matters, not their values; non-Hash options are ignored.
#
# @return [nil] when nothing objectionable is present
# @raise [Error] when :webhook or :webhook_events_filter (symbol or string) is set
def validate_no_webhook_support!(provider_options)
  return unless provider_options.is_a?(Hash)

  forbidden_keys = [:webhook, "webhook", :webhook_events_filter, "webhook_events_filter"]
  return if forbidden_keys.none? { |key| provider_options.key?(key) }

  raise Error, "webhook and webhook_events_filter are not supported"
end
|
|
744
|
+
private_class_method :validate_no_webhook_support!
|
|
745
|
+
|
|
746
|
+
# Converts the accepted `tools:` notations into a Hash of name => ToolDefinition.
#
# nil yields an empty Hash. An Array keeps its ToolDefinition entries (keyed
# by their own name) and silently drops everything else. A Hash is normalized
# entry by entry under the stringified key.
#
# @raise [ToolError] when tools is neither nil, an Array nor a Hash
def normalize_tools(tools)
  case tools
  when nil
    {}
  when Array
    tools.grep(ToolDefinition).each_with_object({}) do |definition, by_name|
      by_name[definition.name] = definition
    end
  when Hash
    pairs = tools.map do |name, value|
      key = name.to_s
      [key, normalize_tool_entry(key, value)]
    end
    pairs.to_h
  else
    raise ToolError, "tools must be a Hash or Array"
  end
end
|
|
764
|
+
private_class_method :normalize_tools
|
|
765
|
+
|
|
766
|
+
# Normalizes one value of a `tools:` Hash.
#
# A ToolDefinition passes through untouched. A Hash is read through its
# :execute, :description and :input_schema (or :parameters) entries, each
# accepted under a symbol or a string key.
#
# @raise [ToolError] when the value is neither a ToolDefinition nor a Hash
def normalize_tool_entry(name, value)
  return value if value.is_a?(ToolDefinition)
  raise ToolError, "Tool '#{name}' has invalid definition" unless value.is_a?(Hash)

  # Symbol key first, then the string key.
  fetch_either = ->(key) { value[key.to_sym] || value[key] }

  callable = fetch_either.call("execute")
  ToolDefinition.new(
    name: name,
    description: fetch_either.call("description") || "",
    input_schema: fetch_either.call("input_schema") || fetch_either.call("parameters") || {},
    execute_proc: callable
  )
end
|
|
781
|
+
private_class_method :normalize_tool_entry
|
|
782
|
+
|
|
783
|
+
# Builds the message list sent to the provider.
#
# The given messages (string-keyed via deep_stringify) form the middle of the
# list; a system message is put in front when `system` is given and a user
# message is appended when `prompt` is given.
#
# @return [Array<Hash>]
def normalize_messages(messages:, system:, prompt:)
  conversation = messages.nil? || messages.empty? ? [] : deep_stringify(messages)
  conversation.unshift({ "role" => "system", "content" => system.to_s }) if system
  conversation << { "role" => "user", "content" => prompt.to_s } if prompt
  conversation
end
|
|
796
|
+
private_class_method :normalize_messages
|
|
797
|
+
|
|
798
|
+
# Assembles the request payload with string keys.
#
# temperature, max_tokens, tools and tool_choice are included only when set.
# The remaining provider options, minus the adapter configuration keys and
# :tool_choice, are merged in last and therefore win on key clashes.
#
# @return [Hash]
def build_payload(model_id:, messages:, tools:, tool_choice:, temperature:, max_tokens:, provider_options:)
  payload = { "model" => model_id, "messages" => messages }

  # compact drops only nil, so an explicit false would still be sent.
  payload.merge!({ "temperature" => temperature, "max_tokens" => max_tokens }.compact)
  payload["tools"] = tools.values.map { |definition| definition.as_provider_tool } unless tools.empty?
  payload["tool_choice"] = deep_stringify(tool_choice) unless tool_choice.nil?

  excluded = OPENROUTER_ADAPTER_CONFIG_KEYS + [ :tool_choice ]
  passthrough = reject_keys(provider_options, excluded)
  payload.merge!(deep_stringify(passthrough)) if passthrough.is_a?(Hash)
  payload
end
|
|
813
|
+
private_class_method :build_payload
|
|
814
|
+
|
|
815
|
+
# Layers provider options: module-wide defaults first, then the model's
# options, then the per-call options, with later layers overriding earlier
# ones. Layers that are not Hashes are skipped.
#
# @return [Hash] a new Hash; none of the inputs is modified
def merge_provider_options(model_provider_options, call_provider_options)
  layers = [default_provider_options, model_provider_options, call_provider_options]
  layers
    .select { |layer| layer.is_a?(Hash) }
    .reduce({}) { |combined, layer| combined.merge(layer) }
end
|
|
822
|
+
private_class_method :merge_provider_options
|
|
823
|
+
|
|
824
|
+
# Instantiates the adapter for +provider+, handing it only the configuration
# keys that adapter's constructor is given here.
#
# @param provider [#to_sym] :openrouter or :replicate
# @param provider_options [Hash] options the adapter configuration is picked from
# @raise [ProviderError] for any other provider
def provider_adapter(provider, provider_options)
  provider_key = provider.to_sym

  if provider_key == :openrouter
    Providers::OpenRouter.new(**pick_keys(provider_options, OPENROUTER_ADAPTER_CONFIG_KEYS))
  elsif provider_key == :replicate
    Providers::Replicate.new(**pick_keys(provider_options, REPLICATE_ADAPTER_CONFIG_KEYS))
  else
    raise ProviderError, "Unsupported provider: #{provider}"
  end
end
|
|
836
|
+
private_class_method :provider_adapter
|
|
837
|
+
|
|
838
|
+
# Builds the assistant message that records the tool calls the model asked
# for: the text extracted from +message+ plus +tool_calls+ as given.
#
# @return [Hash] with "role", "content" and "tool_calls" keys
def build_assistant_tool_call_message(message:, tool_calls:)
  assistant_message = { "role" => "assistant" }
  assistant_message["content"] = extract_message_text(message)
  assistant_message["tool_calls"] = tool_calls
  assistant_message
end
|
|
845
|
+
private_class_method :build_assistant_tool_call_message
|
|
846
|
+
|
|
847
|
+
# Runs one tool call and reports the outcome as a Hash with the keys
# :tool_call_id, :tool_name, :input, :ok and :output.
#
# An explicit +tool_call_id+ / +arguments+ wins over the values carried by
# +tool_call+. An unknown tool name, or any StandardError raised along the
# way, yields ok: false with an error description as output instead of
# propagating.
#
# @param tool_call [Hash] provider tool call ("id", "function" => "name"/"arguments")
# @param tools [Hash] tools keyed by name; each must respond to #execute
# @return [Hash]
def execute_tool_call(tool_call:, tools:, tool_call_id: nil, arguments: nil)
  tool_name = tool_call.dig("function", "name").to_s
  tool_call_id = normalize_tool_call_id(tool_call_id || tool_call["id"])
  arguments = parse_arguments(tool_call.dig("function", "arguments")) if arguments.nil?

  handler = tools[tool_name]
  if handler
    ok = true
    output = normalize_output_payload(handler.execute(arguments))
  else
    ok = false
    output = { "ok" => false, "error" => "Unknown tool: #{tool_name}" }
  end

  { tool_call_id: tool_call_id, tool_name: tool_name, input: arguments, ok: ok, output: output }
rescue StandardError => e
  # Report whatever was resolved before the failure; locals not yet assigned are nil.
  failure = { "ok" => false, "error" => "Tool execution failed: #{e.message}" }
  { tool_call_id: tool_call_id, tool_name: tool_name, input: arguments, ok: false, output: failure }
end
|
|
880
|
+
private_class_method :execute_tool_call
|
|
881
|
+
|
|
882
|
+
# Returns +value+ as a stripped String, or a freshly generated
# "call_<uuid>" identifier when that String is empty.
def normalize_tool_call_id(value)
  candidate = value.to_s.strip
  candidate.empty? ? "call_#{SecureRandom.uuid}" : candidate
end
|
|
888
|
+
private_class_method :normalize_tool_call_id
|
|
889
|
+
|
|
890
|
+
# Validates the iteration limit.
#
# @param value [Integer, Symbol, Float, nil]
# @return [Integer, Symbol] DEFAULT_MAX_ITERATIONS for nil, :infinite for
#   :infinite or Float::INFINITY, otherwise the positive Integer itself
# @raise [Error] for anything else
def normalize_max_iterations(value)
  if value.nil?
    DEFAULT_MAX_ITERATIONS
  elsif value == :infinite || value == Float::INFINITY
    :infinite
  elsif value.is_a?(Integer) && value.positive?
    value
  else
    raise Error, "max_iterations must be a positive Integer or :infinite"
  end
end
|
|
897
|
+
private_class_method :normalize_max_iterations
|
|
898
|
+
|
|
899
|
+
# Resolves the tool_choice to send to the provider.
#
# The explicit argument wins; when it is nil the :tool_choice / "tool_choice"
# entry of +provider_options+ is consulted. With no request at all the result
# is nil without tools and "auto" with tools. A request of "auto" or "none"
# without tools resolves to nil; any other request without tools is an error,
# as is a specific tool that is not a key of +tools+.
#
# @return [String, Hash, nil]
# @raise [Error] when the choice is invalid for the given tools
def normalize_tool_choice(explicit_tool_choice:, provider_options:, tools:)
  requested = explicit_tool_choice
  if requested.nil? && provider_options.is_a?(Hash)
    requested = provider_options[:tool_choice] || provider_options["tool_choice"]
  end

  return (tools.empty? ? nil : "auto") if requested.nil?

  choice = normalize_tool_choice_value(requested)

  if tools.empty?
    raise Error, "tool_choice requires at least one tool" unless choice == "auto" || choice == "none"

    return nil
  end

  if choice.is_a?(Hash)
    chosen_name = choice.dig("function", "name").to_s
    raise Error, "tool_choice references unknown tool '#{chosen_name}'" unless tools.key?(chosen_name)
  end

  choice
end
|
|
927
|
+
private_class_method :normalize_tool_choice
|
|
928
|
+
|
|
929
|
+
# Converts a user-supplied tool choice into the provider wire format.
#
# Strings and Symbols must be one of auto / required / none (after
# stripping) and are returned as Strings. Hashes must carry a type of
# "tool" (name under tool_name / toolName) or "function" (name under
# function -> name), with Symbol or String keys; both produce
# { "type" => "function", "function" => { "name" => <name> } }.
#
# @raise [Error] for unsupported values, types, or a missing tool name
def normalize_tool_choice_value(value)
  if value.is_a?(Symbol) || value.is_a?(String)
    keyword = value.to_s.strip
    return keyword if %w[auto required none].include?(keyword)

    raise Error, "tool_choice must be one of auto, required, none, or { type: 'tool', toolName: '...' }"
  end

  raise Error, "tool_choice must be a String, Symbol, or Hash" unless value.is_a?(Hash)

  kind = (value[:type] || value["type"]).to_s
  raw_name, missing_name_message =
    case kind
    when "tool"
      [
        value[:tool_name] || value["tool_name"] || value[:toolName] || value["toolName"],
        "tool_choice[:toolName] is required when type is 'tool'"
      ]
    when "function"
      [
        value.dig(:function, :name) || value.dig("function", "name"),
        "tool_choice function name is required when type is 'function'"
      ]
    else
      raise Error, "tool_choice hash must use type: 'tool' (or provider-native type: 'function')"
    end

  function_name = raw_name.to_s.strip
  raise Error, missing_name_message if function_name.empty?

  { "type" => "function", "function" => { "name" => function_name } }
end
|
|
972
|
+
private_class_method :normalize_tool_choice_value
|
|
973
|
+
|
|
974
|
+
# Validates the stop_when option and reduces it to canonical form.
#
# Only the :tool_called key (Symbol or String) is supported. Its value may be
# a String, a Symbol or an Array; names are stringified, stripped, blanks are
# dropped and duplicates removed.
#
# @return [Hash] {} when nothing is configured, otherwise
#   { tool_called: [<name>, ...] }
# @raise [Error] for a non-Hash value, unknown keys, an unsupported
#   :tool_called type, or when no usable name remains
def normalize_stop_when(value)
  return {} if value.nil?
  raise Error, "stop_when must be a Hash when provided" unless value.is_a?(Hash)

  unsupported = value.keys.map(&:to_sym).reject { |key| key == :tool_called }
  raise Error, "stop_when only supports :tool_called" unless unsupported.empty?

  trigger = value[:tool_called] || value["tool_called"]
  return {} if trigger.nil?

  unless trigger.is_a?(String) || trigger.is_a?(Symbol) || trigger.is_a?(Array)
    raise Error, "stop_when[:tool_called] must be a String, Symbol, or Array"
  end

  candidates = trigger.is_a?(Array) ? trigger : [ trigger ]
  names = candidates.map { |name| name.to_s.strip }.reject(&:empty?).uniq
  raise Error, "stop_when[:tool_called] must include at least one tool name" if names.empty?

  { tool_called: names }
end
|
|
999
|
+
private_class_method :normalize_stop_when
|
|
1000
|
+
|
|
1001
|
+
# True when +tool_result+ is a successful result of one of the tools listed
# under stop_when[:tool_called]. A non-Hash +stop_when+, an empty list, or a
# result whose :ok is falsy all yield false.
def tool_stop_condition_met?(stop_when, tool_result)
  return false unless stop_when.is_a?(Hash)

  watched = Array(stop_when[:tool_called])
  # The emptiness check comes first so tool_result is not touched when nothing is watched.
  return false if watched.empty? || !tool_result[:ok]

  watched.include?(tool_result[:tool_name].to_s)
end
|
|
1010
|
+
private_class_method :tool_stop_condition_met?
|
|
1011
|
+
|
|
1012
|
+
# Builds the result Hash returned when a callback asked to stop.
#
# finish_reason is always "stopped_by_callback"; stop_reason carries the
# reason stored on +control+. The remaining values are passed through, with
# :text derived from +message+.
#
# @return [Hash]
def callback_stopped_result(control:, iterations:, message:, usage:, raw_response:)
  text = extract_message_text(message)
  stop_details = { type: "callback", reason: control.stop_reason }

  {
    text: text,
    message: message,
    usage: usage,
    finish_reason: "stopped_by_callback",
    stop_reason: stop_details,
    iterations: iterations,
    raw_response: raw_response
  }
end
|
|
1026
|
+
private_class_method :callback_stopped_result
|
|
1027
|
+
|
|
1028
|
+
# Invokes an optional callback with +payload+.
#
# +control+ is passed as a second argument only when it is given and the
# callback is judged able to accept it. A nil callback is a no-op.
#
# @return [Object, nil] whatever the callback returns
# @raise [Error] when the callback does not respond to #call
def call_callback!(callback, payload, control = nil)
  return if callback.nil?
  raise Error, "Callback must respond to #call" unless callback.respond_to?(:call)

  pass_control = control && callback_accepts_control?(callback)
  return callback.call(payload) unless pass_control

  callback.call(payload, control)
end
|
|
1038
|
+
private_class_method :call_callback!
|
|
1039
|
+
|
|
1040
|
+
# Decides whether +callback+ can be invoked with a second (control) argument.
#
# Non-lambda procs always qualify because they ignore surplus arguments.
# Lambdas, Method objects and other objects exposing #call are strict about
# arity, so their parameter list is inspected: a splat, or at least two
# positional (required or optional) parameters, is needed.
#
# Only Proc implements #lambda?, so Method objects and plain callable objects
# are inspected through their parameter lists instead of being sent #lambda?,
# which would raise NoMethodError.
#
# @param callback [#call]
# @return [Boolean]
def callback_accepts_control?(callback)
  return true if callback.is_a?(Proc) && !callback.lambda?

  params =
    case callback
    when Proc, Method
      callback.parameters
    else
      # Arbitrary callable: look at the signature of its #call method.
      callback.method(:call).parameters
    end
  return true if params.any? { |param_type, _| param_type == :rest }

  params.count { |param_type, _| param_type == :req || param_type == :opt } >= 2
end
|
|
1049
|
+
private_class_method :callback_accepts_control?
|
|
1050
|
+
|
|
1051
|
+
# Gives Hash and Array tool outputs string keys via +deep_stringify+; every
# other value is returned untouched.
def normalize_output_payload(payload)
  return deep_stringify(payload) if payload.is_a?(Hash) || payload.is_a?(Array)

  payload
end
|
|
1059
|
+
private_class_method :normalize_output_payload
|
|
1060
|
+
|
|
1061
|
+
# Extracts the text of a message's "content".
#
# A String is returned as is. For an Array the truthy "text" values of its
# Hash parts are concatenated. For a Hash with a truthy "text" that value is
# stringified. Anything else (including nil) is returned via +to_s+.
#
# @param message [Hash] message with string keys
# @return [String]
def extract_message_text(message)
  content = message["content"]

  case content
  when String
    content
  when Array
    text_parts = content.select { |part| part.is_a?(Hash) && part["text"] }
    text_parts.map { |part| part["text"] }.join
  when Hash
    content["text"] ? content["text"].to_s : content.to_s
  else
    content.to_s
  end
end
|
|
1073
|
+
private_class_method :extract_message_text
|
|
1074
|
+
|
|
1075
|
+
# Extracts the text carried by a streaming delta's "content".
#
# @param delta [Hash] delta with string keys
# @return [String, nil] the String itself, the concatenated truthy "text"
#   values of an Array's Hash parts, or nil for any other content
def extract_text_delta(delta)
  content = delta["content"]

  case content
  when String
    content
  when Array
    text_parts = content.select { |part| part.is_a?(Hash) && part["text"] }
    text_parts.map { |part| part["text"] }.join
  end
end
|
|
1085
|
+
private_class_method :extract_text_delta
|
|
1086
|
+
|
|
1087
|
+
# Normalizes tool-call arguments.
#
# A Hash is returned as is; nil or blank input becomes {}; anything else is
# handed to +parse_json+, falling back to {} when that yields nil or false.
def parse_arguments(arguments)
  return arguments if arguments.is_a?(Hash)

  blank = arguments.nil? || arguments.to_s.strip.empty?
  blank ? {} : (parse_json(arguments) || {})
end
|
|
1093
|
+
private_class_method :parse_arguments
|
|
1094
|
+
|
|
1095
|
+
# Parses +value+ as JSON, answering nil instead of raising when the input is
# malformed (JSON::ParserError) or not parseable as a String (TypeError).
def parse_json(value)
  begin
    JSON.parse(value)
  rescue JSON::ParserError, TypeError
    nil
  end
end
|
|
1100
|
+
private_class_method :parse_json
|
|
1101
|
+
|
|
1102
|
+
# Turns a tool output into message content: Strings pass through, everything
# else is JSON-encoded, and +to_s+ is the fallback when encoding raises.
def serialize_tool_content(output)
  output.is_a?(String) ? output : JSON.generate(output)
rescue StandardError
  output.to_s
end
|
|
1109
|
+
private_class_method :serialize_tool_content
|
|
1110
|
+
|
|
1111
|
+
# Returns the entries of +hash+ whose symbolized key is listed in +keys+,
# keyed by that Symbol. A non-Hash input yields {}.
#
# @param hash [Hash, Object]
# @param keys [Array<Symbol>]
# @return [Hash{Symbol => Object}]
def pick_keys(hash, keys)
  return {} unless hash.is_a?(Hash)

  selected = hash.filter_map do |key, value|
    symbolized = key.to_sym
    [ symbolized, value ] if keys.include?(symbolized)
  end
  selected.to_h
end
|
|
1119
|
+
private_class_method :pick_keys
|
|
1120
|
+
|
|
1121
|
+
# Returns a new Hash with the entries of +hash+ whose symbolized key is NOT
# listed in +keys+; the original keys are kept as they are. A non-Hash input
# yields {}.
#
# @param hash [Hash, Object]
# @param keys [Array<Symbol>]
# @return [Hash]
def reject_keys(hash, keys)
  return {} unless hash.is_a?(Hash)

  remaining = hash.filter_map do |key, value|
    [ key, value ] unless keys.include?(key.to_sym)
  end
  remaining.to_h
end
|
|
1129
|
+
private_class_method :reject_keys
|
|
1130
|
+
|
|
1131
|
+
# Recursively copies Hashes and Arrays, converting every Hash key with +to_s+.
# Values that are neither Hash nor Array are returned as they are.
def deep_stringify(value)
  if value.is_a?(Hash)
    value.map { |key, item| [ key.to_s, deep_stringify(item) ] }.to_h
  elsif value.is_a?(Array)
    value.map { |item| deep_stringify(item) }
  else
    value
  end
end
|
|
1143
|
+
private_class_method :deep_stringify
|
|
1144
|
+
|
|
1145
|
+
module Providers
|
|
1146
|
+
# Adapter for the OpenRouter chat-completions endpoint.
#
# Offers a blocking #chat call and a streaming #stream call, and produces
# ModelDescriptor values that carry this adapter's configuration.
#
# NOTE(review): @timeout is handed to Typhoeus as +timeout:+ and defaults to
# 120_000. Confirm which unit Typhoeus applies to that option - a default
# this large suggests milliseconds were intended.
class OpenRouter
  CHAT_COMPLETIONS_URL = "https://openrouter.ai/api/v1/chat/completions".freeze
  DEFAULT_TIMEOUT = 120_000

  # @param api_key [String] required; nil or empty raises ProviderError
  # @param app_url [String, nil] sent as HTTP-Referer, defaults to "http://localhost"
  # @param title [String, nil] sent as X-Title, defaults to "zuno-ruby"
  # @param timeout [Integer] forwarded to Typhoeus as +timeout:+
  def initialize(api_key: nil, app_url: nil, title: nil, timeout: DEFAULT_TIMEOUT)
    @api_key = api_key
    key_missing = @api_key.nil? || @api_key.to_s.empty?
    raise ProviderError, "OpenRouter API key not configured" if key_missing

    @app_url = app_url || "http://localhost"
    @title = title || "zuno-ruby"
    @timeout = timeout
  end

  # Describes +model_id+ as an OpenRouter model bound to this configuration.
  def model(model_id)
    ModelDescriptor.new(id: model_id, provider: :openrouter, provider_options: provider_options)
  end

  # Posts +payload+ and returns the decoded JSON object.
  #
  # @raise [ProviderError] on transport/HTTP failure, unparseable JSON, or a
  #   JSON document that is not an object
  def chat(payload)
    response = Typhoeus.post(
      CHAT_COMPLETIONS_URL,
      headers: headers,
      body: JSON.generate(payload),
      timeout: @timeout
    )
    validate_response!(response)

    decoded = JSON.parse(response.body)
    return decoded if decoded.is_a?(Hash)

    raise ProviderError, "OpenRouter returned invalid JSON"
  rescue JSON::ParserError => e
    raise ProviderError, "Failed to parse OpenRouter response: #{e.message}"
  end

  # Posts +payload+ and yields the data of every SSE event to the block as
  # the body arrives. The response is validated once the request finished,
  # after which any buffered trailing event is flushed.
  #
  # @raise [ArgumentError] without a block
  # @raise [ProviderError] on transport/HTTP failure
  def stream(payload)
    raise ArgumentError, "stream requires a block callback" unless block_given?

    request = Typhoeus::Request.new(
      CHAT_COMPLETIONS_URL,
      method: :post,
      headers: headers,
      body: JSON.generate(payload),
      timeout: @timeout
    )
    sse = SseParser.new { |data| yield(data) }

    request.on_body do |chunk|
      sse.push(chunk)
      nil
    end

    request.run
    validate_response!(request.response)
    sse.flush
  end

  private

  # Constructor arguments in the shape stored on a ModelDescriptor.
  def provider_options
    { api_key: @api_key, app_url: @app_url, title: @title, timeout: @timeout }
  end

  # Request headers sent with every call.
  def headers
    {
      "Authorization" => "Bearer #{@api_key}",
      "Content-Type" => "application/json",
      "HTTP-Referer" => @app_url,
      "X-Title" => @title
    }
  end

  # Raises ProviderError unless +response+ is a 2xx answer. The response
  # body is quoted in the error, truncated to 300 characters.
  def validate_response!(response)
    raise ProviderError, "No response returned from OpenRouter" if response.nil?
    raise ProviderError, "OpenRouter request timed out" if response.timed_out?

    status = response.code.to_i
    raw_body = response.body.to_s
    excerpt = raw_body.length > 300 ? "#{raw_body[0, 300]}..." : raw_body

    return if status.between?(200, 299)

    raise ProviderError, "OpenRouter responded with HTTP #{status}: #{excerpt}" if status.positive?

    # No positive status code: report the transport-level return code instead.
    detail = excerpt.empty? ? "" : ": #{excerpt}"
    raise ProviderError, "OpenRouter request failed: #{response.return_code}#{detail}"
  end
end
|
|
1245
|
+
|
|
1246
|
+
# Adapter for the Replicate predictions API.
#
# Creates predictions for a model, a version or a deployment, fetches an
# existing prediction, and produces ModelDescriptor values that carry this
# adapter's configuration plus the kind of target they refer to.
#
# NOTE(review): @timeout is handed to Typhoeus as +timeout:+ and defaults to
# 120_000. Confirm which unit Typhoeus applies to that option - a default
# this large suggests milliseconds were intended.
class Replicate
  API_BASE_URL = "https://api.replicate.com/v1".freeze
  DEFAULT_TIMEOUT = 120_000

  # @param api_key [String] required; nil or empty raises ProviderError
  # @param timeout [Integer] forwarded to Typhoeus as +timeout:+
  def initialize(api_key: nil, timeout: DEFAULT_TIMEOUT)
    @api_key = api_key
    key_missing = @api_key.nil? || @api_key.to_s.empty?
    raise ProviderError, "Replicate API key not configured" if key_missing

    @timeout = timeout
  end

  # Descriptor targeting a model ("owner/name").
  def model(model_id)
    model_descriptor(model_id: model_id, target: :model)
  end

  # Descriptor targeting a specific version id.
  def version(version_id)
    model_descriptor(model_id: version_id, target: :version)
  end

  # Descriptor targeting a deployment ("owner/name").
  def deployment(deployment_id)
    model_descriptor(model_id: deployment_id, target: :deployment)
  end

  # Creates a prediction and returns the decoded response object.
  #
  # @param reference [Hash] :type (version / model / deployment) and :id
  # @param input [Object] sent as the "input" of the request body
  # @raise [ProviderError] on an unsupported type or a failed request
  def create_prediction(reference:, input:)
    path, payload = build_create_request(reference: reference, input: input)

    response = Typhoeus.post(
      "#{API_BASE_URL}#{path}",
      headers: headers.merge("Prefer" => "wait=#{REPLICATE_PREFER_WAIT_SECONDS}"),
      body: JSON.generate(payload),
      timeout: @timeout
    )
    parse_response(response)
  end

  # Fetches the current state of +prediction+. Its "urls" -> "get" link is
  # used when present, otherwise the URL is built from its "id".
  #
  # @raise [ProviderError] when neither is available or the request fails
  def get_prediction(prediction:)
    response = Typhoeus.get(prediction_url(prediction), headers: headers, timeout: @timeout)
    parse_response(response)
  end

  private

  # Resolves the URL a prediction can be fetched from.
  def prediction_url(prediction)
    link = prediction.dig("urls", "get")
    return link unless link.nil? || link.to_s.strip.empty?

    prediction_id = prediction["id"].to_s
    raise ProviderError, "Replicate prediction id is missing" if prediction_id.empty?

    "#{API_BASE_URL}/predictions/#{CGI.escape(prediction_id)}"
  end

  def model_descriptor(model_id:, target:)
    ModelDescriptor.new(
      id: model_id,
      provider: :replicate,
      provider_options: provider_options(target: target)
    )
  end

  # Constructor arguments plus the target kind, as stored on a ModelDescriptor.
  def provider_options(target:)
    { api_key: @api_key, timeout: @timeout, replicate_target: target }
  end

  # Maps a reference to the [path, request body] pair of the create call.
  def build_create_request(reference:, input:)
    kind = reference[:type].to_sym
    identifier = reference[:id].to_s

    if kind == :version
      [ "/predictions", { "version" => identifier, "input" => input } ]
    elsif kind == :model
      [ "/models/#{escape_owner_and_name(identifier)}/predictions", { "input" => input } ]
    elsif kind == :deployment
      [ "/deployments/#{escape_owner_and_name(identifier)}/predictions", { "input" => input } ]
    else
      raise ProviderError, "Unsupported Replicate reference type: #{kind}"
    end
  end

  # Escapes the two halves of an "owner/name" identifier separately so the
  # separating slash is kept; a missing half becomes an empty string.
  def escape_owner_and_name(value)
    segments = value.split("/", 2)
    [ segments[0], segments[1] ].map { |segment| CGI.escape(segment.to_s) }.join("/")
  end

  # Validates +response+ and decodes its body, which must be a JSON object.
  def parse_response(response)
    validate_response!(response)

    decoded = JSON.parse(response.body)
    return decoded if decoded.is_a?(Hash)

    raise ProviderError, "Replicate returned invalid JSON"
  rescue JSON::ParserError => e
    raise ProviderError, "Failed to parse Replicate response: #{e.message}"
  end

  # Request headers sent with every call.
  def headers
    { "Authorization" => "Bearer #{@api_key}", "Content-Type" => "application/json" }
  end

  # Raises ProviderError unless +response+ is a 2xx answer. The response
  # body is quoted in the error, truncated to 300 characters.
  def validate_response!(response)
    raise ProviderError, "No response returned from Replicate" if response.nil?
    raise ProviderError, "Replicate request timed out" if response.timed_out?

    status = response.code.to_i
    raw_body = response.body.to_s
    excerpt = raw_body.length > 300 ? "#{raw_body[0, 300]}..." : raw_body

    return if status.between?(200, 299)

    raise ProviderError, "Replicate responded with HTTP #{status}: #{excerpt}" if status.positive?

    # No positive status code: report the transport-level return code instead.
    detail = excerpt.empty? ? "" : ": #{excerpt}"
    raise ProviderError, "Replicate request failed: #{response.return_code}#{detail}"
  end
end
|
|
1373
|
+
end
|
|
1374
|
+
|
|
1375
|
+
# Incremental parser for a Server-Sent Events stream.
#
# Raw chunks are fed in with #push in whatever sizes the transport delivers
# them. Each completed event (terminated by a blank line) has the values of
# its +data:+ lines joined with "\n" and passed to the block given to the
# constructor. Events without any +data:+ line are dropped. #flush emits a
# final event that was not terminated by a blank line.
#
# CRLF line endings are folded to LF before events are split, so streams
# framed with "\r\n\r\n" are dispatched event by event as they arrive rather
# than accumulating until #flush and being merged into a single payload.
class SseParser
  EVENT_DELIMITER = "\n\n".freeze

  # @yieldparam data [String] the joined data lines of one event
  def initialize(&on_data)
    @on_data = on_data
    @buffer = +""
  end

  # Appends +chunk+ to the buffer and dispatches every event completed by it.
  # nil and empty chunks are ignored.
  def push(chunk)
    return if chunk.nil? || chunk.empty?

    @buffer << chunk
    parse_buffer
  end

  # Dispatches completed events plus whatever remains buffered, then empties
  # the buffer.
  def flush
    parse_buffer(final: true)
  end

  private

  def parse_buffer(final: false)
    # Fold CRLF after appending, so a "\r" and "\n" split across two chunks is
    # still recognised. A string pattern (not a Regexp) keeps this usable on
    # buffers that currently end in the middle of a multibyte character.
    @buffer.gsub!("\r\n", "\n") if @buffer.include?("\r")

    while (boundary = @buffer.index(EVENT_DELIMITER))
      emit_event(@buffer.slice!(0, boundary + EVENT_DELIMITER.length))
    end

    return unless final

    emit_event(@buffer.dup) unless @buffer.empty?
    @buffer.clear
  end

  # Passes the event's data to the callback; lines that do not start with
  # "data:" (after stripping surrounding whitespace) are ignored.
  def emit_event(raw_event)
    payload_lines = raw_event.split("\n").filter_map do |line|
      stripped = line.strip
      next unless stripped.start_with?("data:")

      # Drop the field name and at most one whitespace character after it.
      stripped.sub(/\Adata:\s?/, "")
    end
    return if payload_lines.empty?

    @on_data.call(payload_lines.join("\n"))
  end
end
|
|
19
1424
|
end
|