llm.rb 8.1.0 → 10.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +196 -6
  3. data/README.md +233 -518
  4. data/data/anthropic.json +278 -258
  5. data/data/bedrock.json +1288 -1561
  6. data/data/deepseek.json +38 -38
  7. data/data/google.json +656 -579
  8. data/data/openai.json +860 -818
  9. data/data/xai.json +243 -552
  10. data/data/zai.json +168 -168
  11. data/lib/llm/active_record/acts_as_agent.rb +5 -0
  12. data/lib/llm/active_record/acts_as_llm.rb +7 -8
  13. data/lib/llm/active_record.rb +1 -6
  14. data/lib/llm/agent.rb +121 -82
  15. data/lib/llm/context.rb +79 -74
  16. data/lib/llm/contract/completion.rb +45 -0
  17. data/lib/llm/cost.rb +81 -4
  18. data/lib/llm/error.rb +1 -1
  19. data/lib/llm/function/array.rb +8 -5
  20. data/lib/llm/function/call_group.rb +39 -0
  21. data/lib/llm/function/call_task.rb +46 -0
  22. data/lib/llm/function/fork/task.rb +6 -0
  23. data/lib/llm/function/ractor/task.rb +6 -0
  24. data/lib/llm/function/task.rb +10 -0
  25. data/lib/llm/function.rb +28 -1
  26. data/lib/llm/mcp/transport/http.rb +26 -46
  27. data/lib/llm/mcp/transport/stdio.rb +0 -8
  28. data/lib/llm/mcp.rb +6 -23
  29. data/lib/llm/provider.rb +30 -20
  30. data/lib/llm/providers/anthropic/error_handler.rb +6 -7
  31. data/lib/llm/providers/anthropic/files.rb +2 -2
  32. data/lib/llm/providers/anthropic/response_adapter/completion.rb +30 -0
  33. data/lib/llm/providers/anthropic/stream_parser.rb +2 -2
  34. data/lib/llm/providers/anthropic.rb +1 -1
  35. data/lib/llm/providers/bedrock/error_handler.rb +8 -9
  36. data/lib/llm/providers/bedrock/models.rb +13 -13
  37. data/lib/llm/providers/bedrock/response_adapter/completion.rb +30 -0
  38. data/lib/llm/providers/bedrock/stream_parser.rb +2 -2
  39. data/lib/llm/providers/bedrock.rb +1 -1
  40. data/lib/llm/providers/google/error_handler.rb +6 -7
  41. data/lib/llm/providers/google/files.rb +2 -4
  42. data/lib/llm/providers/google/images.rb +1 -1
  43. data/lib/llm/providers/google/models.rb +0 -2
  44. data/lib/llm/providers/google/response_adapter/completion.rb +30 -0
  45. data/lib/llm/providers/google/stream_parser.rb +2 -2
  46. data/lib/llm/providers/google.rb +1 -1
  47. data/lib/llm/providers/ollama/error_handler.rb +6 -7
  48. data/lib/llm/providers/ollama/models.rb +0 -2
  49. data/lib/llm/providers/ollama/response_adapter/completion.rb +30 -0
  50. data/lib/llm/providers/ollama.rb +1 -1
  51. data/lib/llm/providers/openai/audio.rb +3 -3
  52. data/lib/llm/providers/openai/error_handler.rb +6 -7
  53. data/lib/llm/providers/openai/files.rb +2 -2
  54. data/lib/llm/providers/openai/images.rb +3 -3
  55. data/lib/llm/providers/openai/models.rb +1 -1
  56. data/lib/llm/providers/openai/response_adapter/completion.rb +42 -0
  57. data/lib/llm/providers/openai/response_adapter/responds.rb +39 -0
  58. data/lib/llm/providers/openai/responses/stream_parser.rb +2 -2
  59. data/lib/llm/providers/openai/responses.rb +2 -2
  60. data/lib/llm/providers/openai/stream_parser.rb +2 -2
  61. data/lib/llm/providers/openai/vector_stores.rb +1 -1
  62. data/lib/llm/providers/openai.rb +1 -1
  63. data/lib/llm/response.rb +10 -8
  64. data/lib/llm/schema.rb +11 -0
  65. data/lib/llm/sequel/agent.rb +5 -0
  66. data/lib/llm/sequel/plugin.rb +8 -14
  67. data/lib/llm/stream/queue.rb +15 -42
  68. data/lib/llm/stream.rb +15 -40
  69. data/lib/llm/tool/param.rb +1 -8
  70. data/lib/llm/transport/execution.rb +67 -0
  71. data/lib/llm/transport/http.rb +134 -0
  72. data/lib/llm/transport/persistent_http.rb +152 -0
  73. data/lib/llm/transport/response/http.rb +113 -0
  74. data/lib/llm/transport/response.rb +112 -0
  75. data/lib/llm/{provider/transport/http → transport}/stream_decoder.rb +8 -4
  76. data/lib/llm/transport.rb +139 -0
  77. data/lib/llm/usage.rb +14 -5
  78. data/lib/llm/utils.rb +24 -14
  79. data/lib/llm/version.rb +1 -1
  80. data/lib/llm.rb +3 -12
  81. data/llm.gemspec +2 -16
  82. metadata +13 -20
  83. data/lib/llm/bot.rb +0 -3
  84. data/lib/llm/provider/transport/http/execution.rb +0 -115
  85. data/lib/llm/provider/transport/http/interruptible.rb +0 -114
  86. data/lib/llm/provider/transport/http.rb +0 -145
@@ -42,6 +42,45 @@ module LLM::OpenAI::ResponseAdapter
42
42
  &.reasoning_tokens || 0
43
43
  end
44
44
 
45
+ ##
46
+ # (see LLM::Contract::Completion#input_audio_tokens)
47
+ def input_audio_tokens
48
+ body
49
+ .usage
50
+ &.input_tokens_details
51
+ &.audio_tokens || 0
52
+ end
53
+
54
+ ##
55
+ # (see LLM::Contract::Completion#output_audio_tokens)
56
+ def output_audio_tokens
57
+ body
58
+ .usage
59
+ &.output_tokens_details
60
+ &.audio_tokens || 0
61
+ end
62
+
63
+ ##
64
+ # (see LLM::Contract::Completion#input_image_tokens)
65
+ def input_image_tokens
66
+ super
67
+ end
68
+
69
+ ##
70
+ # (see LLM::Contract::Completion#cache_read_tokens)
71
+ def cache_read_tokens
72
+ body
73
+ .usage
74
+ &.input_tokens_details
75
+ &.cached_tokens || 0
76
+ end
77
+
78
+ ##
79
+ # (see LLM::Contract::Completion#cache_write_tokens)
80
+ def cache_write_tokens
81
+ 0
82
+ end
83
+
45
84
  ##
46
85
  # (see LLM::Contract::Completion#total_tokens)
47
86
  def total_tokens
@@ -269,14 +269,14 @@ class LLM::OpenAI
269
269
  # @group Resolvers
270
270
 
271
271
  def resolve_tool(tool, arguments)
272
- registered = @stream.find_tool(tool["name"])
272
+ registered = @stream.__find__(tool["name"])
273
273
  fn = (registered || LLM::Function.new(tool["name"])).dup.tap do |fn|
274
274
  fn.id = tool["call_id"]
275
275
  fn.arguments = arguments
276
276
  fn.tracer = @stream.extra[:tracer]
277
277
  fn.model = @stream.extra[:model]
278
278
  end
279
- [fn, (registered ? nil : @stream.tool_not_found(fn))]
279
+ [fn, (registered ? nil : fn.unavailable)]
280
280
  end
281
281
 
282
282
  def parse_arguments(arguments)
@@ -44,7 +44,7 @@ class LLM::OpenAI
44
44
  messages = build_complete_messages(prompt, params, role)
45
45
  @provider.tracer.set_request_metadata(user_input: extract_user_input(messages, fallback: prompt))
46
46
  body = LLM.json.dump({input: [adapt(messages, mode: :response)].flatten}.merge!(params))
47
- set_body_stream(req, StringIO.new(body))
47
+ transport.set_body_stream(req, StringIO.new(body))
48
48
  res, span, tracer = execute(request: req, stream:, stream_parser:, operation: "chat", model: params[:model])
49
49
  res = ResponseAdapter.adapt(res, type: :responds)
50
50
  .extend(Module.new { define_method(:__tools__) { tools } })
@@ -85,7 +85,7 @@ class LLM::OpenAI
85
85
 
86
86
  private
87
87
 
88
- [:path, :headers, :execute, :set_body_stream, :resolve_tools].each do |m|
88
+ [:path, :headers, :execute, :transport, :resolve_tools].each do |m|
89
89
  define_method(m) { |*args, **kwargs, &b| @provider.send(m, *args, **kwargs, &b) }
90
90
  end
91
91
 
@@ -185,14 +185,14 @@ class LLM::OpenAI
185
185
  end
186
186
 
187
187
  def resolve_tool(tool, function, arguments)
188
- registered = @stream.find_tool(function["name"])
188
+ registered = @stream.__find__(function["name"])
189
189
  fn = (registered || LLM::Function.new(function["name"])).dup.tap do |fn|
190
190
  fn.id = tool["id"]
191
191
  fn.arguments = arguments
192
192
  fn.tracer = @stream.extra[:tracer]
193
193
  fn.model = @stream.extra[:model]
194
194
  end
195
- [fn, (registered ? nil : @stream.tool_not_found(fn))]
195
+ [fn, (registered ? nil : fn.unavailable)]
196
196
  end
197
197
 
198
198
  def parse_arguments(arguments)
@@ -259,7 +259,7 @@ class LLM::OpenAI
259
259
 
260
260
  private
261
261
 
262
- [:path, :headers, :execute, :set_body_stream].each do |m|
262
+ [:path, :headers, :execute].each do |m|
263
263
  define_method(m) { |*args, **kwargs, &b| @provider.send(m, *args, **kwargs, &b) }
264
264
  end
265
265
  end
@@ -223,7 +223,7 @@ module LLM
223
223
  messages = build_complete_messages(prompt, params, role)
224
224
  body = LLM.json.dump({messages: adapt(messages, mode: :complete).flatten}.merge!(params))
225
225
  req = Net::HTTP::Post.new(completions_path, headers)
226
- set_body_stream(req, StringIO.new(body))
226
+ transport.set_body_stream(req, StringIO.new(body))
227
227
  [req, messages]
228
228
  end
229
229
 
data/lib/llm/response.rb CHANGED
@@ -10,25 +10,27 @@ module LLM
10
10
  # handling can share one common surface without flattening away
11
11
  # specialized behavior.
12
12
  #
13
- # The normalized response still keeps the original
14
- # {Net::HTTPResponse Net::HTTPResponse} available through {#res}
15
- # when callers need direct access to raw HTTP details such as
16
- # headers, status codes, or unadapted bodies.
13
+ # The normalized response keeps the transport response available
14
+ # through {#res}. When the default net/http transport is in use,
15
+ # {LLM::Transport::Response::HTTP
16
+ # LLM::Transport::Response::HTTP} keeps the
17
+ # original {Net::HTTPResponse Net::HTTPResponse} available through
18
+ # its own {LLM::Transport::Response::HTTP#res #res}.
17
19
  class Response
18
20
  require "json"
19
21
 
20
22
  ##
21
23
  # Returns the HTTP response
22
- # @return [Net::HTTPResponse]
24
+ # @return [LLM::Transport::Response]
23
25
  attr_reader :res
24
26
 
25
27
  ##
26
- # @param [Net::HTTPResponse] res
28
+ # @param [LLM::Transport::Response] res
27
29
  # HTTP response
28
30
  # @return [LLM::Response]
29
31
  # Returns an instance of LLM::Response
30
32
  def initialize(res)
31
- @res = res
33
+ @res = LLM::Transport::Response.from(res)
32
34
  end
33
35
 
34
36
  ##
@@ -51,7 +53,7 @@ module LLM
51
53
  # Returns true if the response is successful
52
54
  # @return [Boolean]
53
55
  def ok?
54
- Net::HTTPSuccess === @res
56
+ @res.success?
55
57
  end
56
58
 
57
59
  ##
data/lib/llm/schema.rb CHANGED
@@ -56,6 +56,8 @@ class LLM::Schema
56
56
  def resolve(schema, type)
57
57
  if LLM::Schema::Leaf === type
58
58
  type
59
+ elsif ::Array === type
60
+ resolve_array(schema, type)
59
61
  elsif Class === type && type.respond_to?(:object)
60
62
  type.object
61
63
  else
@@ -63,6 +65,15 @@ class LLM::Schema
63
65
  schema.public_send(target)
64
66
  end
65
67
  end
68
+
69
+ def resolve_array(schema, values)
70
+ item = if values.size == 1
71
+ resolve(schema, values[0])
72
+ else
73
+ schema.any_of(*values.map { resolve(schema, _1) })
74
+ end
75
+ schema.array(item)
76
+ end
66
77
  end
67
78
 
68
79
  ##
@@ -58,6 +58,11 @@ module LLM::Sequel
58
58
  agent.concurrency(concurrency)
59
59
  end
60
60
 
61
+ def confirm(*tool_names, &block)
62
+ return agent.confirm if tool_names.empty? && !block
63
+ agent.confirm(*tool_names, &block)
64
+ end
65
+
61
66
  def tracer(tracer = nil, &block)
62
67
  return agent.tracer if tracer.nil? && !block
63
68
  agent.tracer(tracer, &block)
@@ -30,12 +30,7 @@ module LLM::Sequel
30
30
  # Resolves a single configured option against a model instance.
31
31
  # @return [Object]
32
32
  def self.resolve_option(obj, option)
33
- case option
34
- when Proc then obj.instance_exec(&option)
35
- when Symbol then obj.send(option)
36
- when Hash then option.dup
37
- else option
38
- end
33
+ LLM::Utils.resolve_option(obj, option)
39
34
  end
40
35
 
41
36
  ##
@@ -184,14 +179,6 @@ module LLM::Sequel
184
179
  ctx.wait(...)
185
180
  end
186
181
 
187
- ##
188
- # Calls into the stored context.
189
- # @see LLM::Context#call
190
- # @return [Object]
191
- def call(...)
192
- ctx.call(...)
193
- end
194
-
195
182
  ##
196
183
  # @see LLM::Context#mode
197
184
  # @return [Symbol]
@@ -222,6 +209,13 @@ module LLM::Sequel
222
209
  ctx.functions
223
210
  end
224
211
 
212
+ ##
213
+ # @see LLM::Context#functions?
214
+ # @return [Boolean]
215
+ def functions?
216
+ ctx.functions?
217
+ end
218
+
225
219
  ##
226
220
  # @see LLM::Context#returns
227
221
  # @return [Array<LLM::Function::Return>]
@@ -4,7 +4,7 @@ class LLM::Stream
4
4
  ##
5
5
  # A small queue for collecting streamed tool work. Values can be immediate
6
6
  # {LLM::Function::Return} objects or concurrent handles returned by
7
- # {LLM::Function#spawn}. Calling {#wait(strategy)} resolves queued work and
7
+ # {LLM::Function#spawn}. Calling {#wait} resolves queued work and
8
8
  # returns an array of {LLM::Function::Return} values.
9
9
  class Queue
10
10
  ##
@@ -41,56 +41,29 @@ class LLM::Stream
41
41
 
42
42
  ##
43
43
  # Waits for queued work to finish and returns function results.
44
- # @param [Symbol, Array<Symbol>] strategy
45
- # Controls concurrency strategy, or lists the possible concurrency strategies
46
- # to wait on:
47
- # - `:thread`: Use threads
48
- # - `:task`: Use async tasks (requires async gem)
49
- # - `:fiber`: Use scheduler-backed fibers (requires Fiber.scheduler)
50
- # - `:ractor`: Use Ruby ractors (class-based tools only; MCP tools are not supported)
51
- # - `[:thread, :ractor]`: Wait for any queued thread or ractor work, in the
52
- # given order. This is useful when different tools were spawned with
53
- # different concurrency strategies.
44
+ #
45
+ # Queued work is waited on according to the actual task types that were
46
+ # enqueued, so callers do not need to provide a strategy here.
47
+ #
54
48
  # @return [Array<LLM::Function::Return>]
55
- def wait(strategy)
49
+ def wait
56
50
  returns, tasks = @items.shift(@items.length).partition { LLM::Function::Return === _1 }
57
- results = wait_tasks(tasks, strategy)
51
+ results = wait_tasks(tasks)
58
52
  returns.concat fire_hooks(tasks, results)
59
53
  end
60
54
  alias_method :value, :wait
61
55
 
62
56
  private
63
57
 
64
- def wait_tasks(tasks, strategy)
65
- strategies = Array(strategy)
66
- return wait_group(tasks, strategies.first) unless strategies.length > 1
67
- grouped = strategies.to_h { [_1, []] }
68
- tasks.each do |task|
69
- grouped[task_strategy(task)] << task
70
- end
71
- strategies.flat_map do |name|
72
- selected = grouped.fetch(name)
73
- selected.empty? ? [] : wait_group(selected, name)
74
- end
75
- end
76
-
77
- def wait_group(tasks, strategy)
78
- case strategy
79
- when :thread then LLM::Function::ThreadGroup.new(tasks).wait
80
- when :task then LLM::Function::TaskGroup.new(tasks).wait
81
- when :fiber then LLM::Function::FiberGroup.new(tasks).wait
82
- when :ractor then LLM::Function::Ractor::Group.new(tasks).wait
83
- else raise ArgumentError, "Unknown strategy: #{strategy.inspect}. Expected :thread, :task, :fiber, or :ractor"
84
- end
85
- end
86
-
87
- def task_strategy(task)
88
- case task.task
89
- when Thread then :thread
90
- when Fiber then :fiber
91
- when LLM::Function::Ractor::Task then :ractor
92
- else :task
58
+ def wait_tasks(tasks)
59
+ return [] if tasks.empty?
60
+ results = {}
61
+ grouped_tasks = tasks.group_by(&:group_class)
62
+ grouped_tasks.each do |group_class, group|
63
+ returns = group_class.new(group).wait
64
+ returns.each.with_index { results[group[_2]] = _1 }
93
65
  end
66
+ tasks.map { results[_1] }
94
67
  end
95
68
 
96
69
  def fire_hooks(tasks, results)
data/lib/llm/stream.rb CHANGED
@@ -9,8 +9,7 @@ module LLM
9
9
  # subclass that overrides the callbacks it needs. For basic streaming,
10
10
  # llm.rb also accepts any object that implements `#<<`. {#queue} provides
11
11
  # a small helper for collecting asynchronous tool work started from a
12
- # callback, and {#tool_not_found} returns an in-band tool error when a
13
- # streamed tool cannot be resolved.
12
+ # callback.
14
13
  #
15
14
  # @note The `on_*` callbacks run inline with the streaming parser. They
16
15
  # therefore block streaming progress and should generally return as
@@ -46,11 +45,11 @@ module LLM
46
45
 
47
46
  ##
48
47
  # Waits for queued tool work to finish and returns function results.
49
- # @param [Symbol] strategy
50
- # The concurrency strategy to use
48
+ # Any passed arguments are ignored because queued work is waited on according
49
+ # to the actual task types already present in the queue.
51
50
  # @return [Array<LLM::Function::Return>]
52
- def wait(strategy)
53
- queue.wait(strategy)
51
+ def wait(*)
52
+ queue.wait
54
53
  end
55
54
 
56
55
  # @group Public callbacks
@@ -150,48 +149,24 @@ module LLM
150
149
 
151
150
  # @endgroup
152
151
 
153
- # @group Error handlers
154
-
155
- ##
156
- # Returns a function return describing a streamed tool that could not
157
- # be resolved.
158
- # @note This is mainly useful as a fallback from {#on_tool_call}. It
159
- # should be uncommon in normal use, since streamed tool callbacks only
160
- # run for tools already defined in the context.
161
- # @param [LLM::Function] tool
162
- # @return [LLM::Function::Return]
163
- def tool_not_found(tool)
164
- LLM::Function::Return.new(tool.id, tool.name, {
165
- error: true, type: LLM::NoSuchToolError.name, message: "tool not found"
166
- })
167
- end
168
-
169
- ##
170
- # Returns the tool definitions available for the current streamed request.
171
- # This prefers request-local tools attached to the stream and falls back
172
- # to the current context defaults when present.
173
- # @return [Array<LLM::Function, LLM::Tool>]
174
- def tools
175
- extra[:tools] || ctx&.params&.dig(:tools) || []
176
- end
152
+ # @group Finders
177
153
 
178
154
  ##
179
155
  # Resolves a streamed tool call against the current request tools first,
180
156
  # then falls back to the global function registry.
181
157
  # @param [String] name
182
158
  # @return [LLM::Function, nil]
183
- def find_tool(name)
184
- tool = tools.find do |candidate|
185
- candidate_name =
186
- if candidate.respond_to?(:function)
187
- candidate.function.name
188
- else
189
- candidate.name
190
- end
191
- candidate_name.to_s == name.to_s
159
+ def __find__(name)
160
+ tools = extra[:tools] || ctx&.params&.dig(:tools) || []
161
+ tool = tools.find do
162
+ candidate = _1.respond_to?(:function) ? _1.function.name : _1.name
163
+ candidate.to_s == name.to_s
192
164
  end
193
- tool&.then { _1.respond_to?(:function) ? _1.function : _1 } ||
165
+ if tool
166
+ tool.respond_to?(:function) ? tool.function : tool
167
+ else
194
168
  LLM::Function.find_by_name(name)
169
+ end
195
170
  end
196
171
 
197
172
  # @endgroup
@@ -62,14 +62,7 @@ class LLM::Tool
62
62
  extend self
63
63
 
64
64
  def resolve(schema, type)
65
- if LLM::Schema::Leaf === type
66
- type
67
- elsif Class === type && type.respond_to?(:object)
68
- type.object
69
- else
70
- target = type.name.split("::").last.downcase
71
- schema.public_send(target)
72
- end
65
+ LLM::Schema::Utils.resolve(schema, type)
73
66
  end
74
67
 
75
68
  def setup(leaf, description, options)
@@ -0,0 +1,67 @@
1
+ # frozen_string_literal: true
2
+
3
+ class LLM::Transport
4
+ ##
5
+ # Internal request execution methods for {LLM::Provider}.
6
+ #
7
+ # This module handles provider-side transport execution, response
8
+ # parsing, streaming, and request body setup.
9
+ #
10
+ # @api private
11
+ module Execution
12
+ private
13
+
14
+ ##
15
+ # Executes an HTTP request
16
+ # @param [Net::HTTPRequest] request
17
+ # The request to send
18
+ # @param [Proc] b
19
+ # A block to yield the response to (optional)
20
+ # @return [Array(LLM::Transport::Response, Object, Object)]
21
+ # The response from the server, followed by the tracer span and the tracer
22
+ # @raise [LLM::Error::Unauthorized]
23
+ # When authentication fails
24
+ # @raise [LLM::Error::RateLimit]
25
+ # When the rate limit is exceeded
26
+ # @raise [LLM::Error]
27
+ # When any other unsuccessful status code is returned
28
+ # @raise [SystemCallError]
29
+ # When there is a network error at the operating system level
30
+ # @return [Array(LLM::Transport::Response, Object, Object)]
31
+ def execute(request:, operation:, stream: nil, stream_parser: self.stream_parser, model: nil, inputs: nil, &b)
32
+ stream &&= LLM::Object.from(streamer: stream, parser: stream_parser, decoder: stream_decoder)
33
+ owner = transport.request_owner
34
+ tracer = self.tracer
35
+ span = tracer.on_request_start(operation:, model:, inputs:)
36
+ res = transport.request(request, owner:, stream:, &b)
37
+ res = LLM::Transport::Response.from(res)
38
+ [handle_response(res, tracer, span), span, tracer]
39
+ rescue *transport.interrupt_errors
40
+ raise LLM::Interrupt, "request interrupted" if transport.interrupted?(owner)
41
+ raise
42
+ end
43
+
44
+ ##
45
+ # Handles the response from a request
46
+ # @param [LLM::Transport::Response] res
47
+ # The response to handle
48
+ # @param [Object, nil] span
49
+ # The span
50
+ # @return [LLM::Transport::Response]
51
+ def handle_response(res, tracer, span)
52
+ res.ok? ? res.body = parse_response(res) : error_handler.new(tracer, span, res).raise_error!
53
+ res
54
+ end
55
+
56
+ ##
57
+ # Parse a HTTP response
58
+ # @param [LLM::Transport::Response] res
59
+ # @return [LLM::Object, String]
60
+ def parse_response(res)
61
+ case res["content-type"]
62
+ when %r{\Aapplication/json\s*} then LLM::Object.from(LLM.json.load(res.body))
63
+ else res.body
64
+ end
65
+ end
66
+ end
67
+ end
@@ -0,0 +1,134 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "net/http"
4
+
5
+ class LLM::Transport
6
+ ##
7
+ # The {LLM::Transport::HTTP LLM::Transport::HTTP} transport is the
8
+ # built-in adapter for Ruby's {Net::HTTP Net::HTTP}. It manages
9
+ # transient HTTP connections, tracks active requests by owner, and
10
+ # interrupts in-flight requests when needed.
11
+ #
12
+ # @api private
13
+ class HTTP < self
14
+ INTERRUPT_ERRORS = [::IOError, ::EOFError, Errno::EBADF].freeze
15
+ Request = Struct.new(:client, keyword_init: true)
16
+
17
+ ##
18
+ # @param [String] host
19
+ # @param [Integer] port
20
+ # @param [Integer] timeout
21
+ # @param [Boolean] ssl
22
+ # @return [LLM::Transport::HTTP]
23
+ def initialize(host:, port:, timeout:, ssl:)
24
+ @host = host
25
+ @port = port
26
+ @timeout = timeout
27
+ @ssl = ssl
28
+ @base_uri = URI("#{ssl ? "https" : "http"}://#{host}:#{port}/")
29
+ @monitor = Monitor.new
30
+ end
31
+
32
+ ##
33
+ # Returns the current request owner.
34
+ # @return [Object]
35
+ def request_owner
36
+ return Fiber.current unless defined?(::Async)
37
+ Async::Task.current? ? Async::Task.current : Fiber.current
38
+ end
39
+
40
+ ##
41
+ # @return [Array<Class<Exception>>]
42
+ def interrupt_errors
43
+ [*INTERRUPT_ERRORS, *optional_interrupt_errors]
44
+ end
45
+
46
+ ##
47
+ # Interrupt an active request, if any.
48
+ # @param [Fiber] owner
49
+ # @return [nil]
50
+ def interrupt!(owner)
51
+ req = request_for(owner) or return
52
+ lock { (@interrupts ||= {})[owner] = true }
53
+ close_socket(req.client)
54
+ req.client.finish if req.client.active?
55
+ owner.stop if owner.respond_to?(:stop)
56
+ rescue *interrupt_errors
57
+ nil
58
+ end
59
+
60
+ ##
61
+ # Returns whether an execution owner was interrupted.
62
+ # @param [Fiber] owner
63
+ # @return [Boolean, nil]
64
+ def interrupted?(owner)
65
+ lock { @interrupts&.delete(owner) }
66
+ end
67
+
68
+ ##
69
+ # Performs a request on the current HTTP transport.
70
+ # @param [Net::HTTPRequest] request
71
+ # @param [Fiber] owner
72
+ # @param [LLM::Object, nil] stream
73
+ # @yieldparam [LLM::Transport::Response] response
74
+ # @return [Object]
75
+ def request(request, owner:, stream: nil, &b)
76
+ client = client()
77
+ set_request(Request.new(client:), owner)
78
+ perform_request(client, request, stream, &b)
79
+ ensure
80
+ clear_request(owner)
81
+ end
82
+
83
+ ##
84
+ # @return [String]
85
+ def inspect
86
+ "#<#{self.class.name}:0x#{object_id.to_s(16)}>"
87
+ end
88
+
89
+ private
90
+
91
+ attr_reader :host, :port, :timeout, :ssl, :base_uri
92
+
93
+ def client
94
+ client = Net::HTTP.new(host, port)
95
+ client.read_timeout = timeout
96
+ client.use_ssl = ssl
97
+ client
98
+ end
99
+
100
+ def close_socket(http)
101
+ socket = http&.instance_variable_get(:@socket) or return
102
+ socket = socket.io if socket.respond_to?(:io)
103
+ socket.close
104
+ rescue *interrupt_errors
105
+ nil
106
+ end
107
+
108
+ def request_for(owner)
109
+ lock do
110
+ @requests ||= {}
111
+ @requests[owner]
112
+ end
113
+ end
114
+
115
+ def set_request(req, owner)
116
+ lock do
117
+ @requests ||= {}
118
+ @requests[owner] = req
119
+ end
120
+ end
121
+
122
+ def clear_request(owner)
123
+ lock { @requests&.delete(owner) }
124
+ end
125
+
126
+ def lock(&)
127
+ @monitor.synchronize(&)
128
+ end
129
+
130
+ def optional_interrupt_errors
131
+ defined?(::Async::Stop) ? [Async::Stop] : []
132
+ end
133
+ end
134
+ end