embedding_util 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,153 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "net/http"
5
+ require "uri"
6
+
7
+ require_relative "../provider"
8
+ require_relative "../result"
9
+
10
module EmbeddingUtil
  module Providers
    # Provider that calls already-running HTTP inference servers.
    # Embeddings go to POST <embedding endpoint>/v1/embeddings.
    # Reranking goes to POST <reranker endpoint>/v1/rerank, falling back to /rerank.
    # Base URLs come from config.embedding_endpoint_url and config.reranker_endpoint_url.
    class Endpoint < Provider
      # Transport-level failures that post_json converts into EndpointError.
      NETWORK_ERRORS = [
        Errno::ECONNREFUSED,
        Errno::ECONNRESET,
        Errno::EHOSTUNREACH,
        Errno::ENETUNREACH,
        Errno::EPIPE,
        Errno::ETIMEDOUT,
        EOFError,
        IOError,
        Net::OpenTimeout,
        Net::ReadTimeout,
        SocketError,
        Timeout::Error
      ].freeze

      # @return [Boolean] true when an embedding or a reranker URL is configured.
      def supported?
        !!(config.embedding_endpoint_url || config.reranker_endpoint_url)
      end

      # @return [Hash] provider name, support flag and the configured endpoint URLs.
      def support
        {
          provider: provider_name,
          supported: supported?,
          embedding_endpoint: config.embedding_endpoint_url,
          reranker_endpoint: config.reranker_endpoint_url
        }
      end

      # Requests embeddings for +texts+.
      #
      # @param texts sent unchanged as the request's "input" field.
      # @param profile its embedding[:model] is sent as "model".
      # @return [EmbeddingResult] vectors ordered by each item's "index".
      #   Items without an "index" use their position in the response instead.
      # @raise [UnsupportedProviderError] if no embedding endpoint is configured.
      # @raise [EndpointError] on network, HTTP-status, URL or JSON failures.
      def embed(texts, profile: config.resolved_profile)
        endpoint = require_endpoint(config.embedding_endpoint_url, "embedding")
        response = post_json(endpoint, "/v1/embeddings", {
          input: texts,
          model: profile.embedding.fetch(:model)
        })

        data = Array(response.fetch("data"))
        # Pair each item with its response position once. This avoids a linear
        # search per item. The position is both the fallback key and the
        # tie-breaker, because sort_by is not stable.
        embeddings = data.each_with_index
                         .sort_by { |item, position| [item.fetch("index", position), position] }
                         .map { |item, _position| item.fetch("embedding") }
        EmbeddingResult.new(
          embedding: embeddings,
          model: response["model"],
          profile: profile.name,
          provider: provider_name,
          metadata: { usage: response["usage"] }.compact
        )
      end

      # Reranks +documents+ against +query+.
      #
      # Tries /v1/rerank first. If that route is missing (a 404 whose body is
      # empty or not JSON), the same payload is retried on /rerank.
      #
      # @return [RerankResult] ranked documents in the order the server returned them.
      # @raise [UnsupportedProviderError] if no reranker endpoint is configured.
      # @raise [EndpointNotFoundError] if the /rerank fallback route is missing too.
      # @raise [EndpointError] on network, HTTP-status, URL or JSON failures.
      def rerank(query, documents, profile: config.resolved_profile)
        endpoint = require_endpoint(config.reranker_endpoint_url, "reranker")
        payload = rerank_payload(query, documents, profile)
        response = begin
          post_json(endpoint, "/v1/rerank", payload)
        rescue EndpointNotFoundError => e
          raise unless fallback_rerank_not_found?(e)

          post_json(endpoint, "/rerank", payload)
        end

        RerankResult.new(
          results: ranked_documents(response, documents),
          model: response["model"],
          profile: profile.name,
          provider: provider_name,
          metadata: { usage: response["usage"] }.compact
        )
      end

      private

      # Request body shared by both rerank routes.
      def rerank_payload(query, documents, profile)
        {
          query: query,
          documents: documents,
          model: profile.reranker.fetch(:model)
        }
      end

      # Converts the server's "results" list into RankedDocument values.
      # - Document: the echoed "document" when present, otherwise looked up by
      #   "index" in +documents+.
      # - Score: "relevance_score", falling back to "score".
      # - Metadata: every remaining key.
      def ranked_documents(response, documents)
        Array(response.fetch("results")).map do |item|
          index = item.fetch("index")
          RankedDocument.new(
            index: index,
            document: item["document"] || fetch_document_at(documents, index),
            score: item.fetch("relevance_score") { item.fetch("score") },
            metadata: item.reject { |key, _value| %w[index document relevance_score score].include?(key) }
          )
        end
      end

      # Returns documents[index] for a valid, non-negative position.
      # Array#fetch alone would accept negative indices (counting from the end)
      # and raise TypeError for non-Integer values, so both are rejected here.
      #
      # @raise [EndpointError] when +index+ is not a valid position in +documents+.
      def fetch_document_at(documents, index)
        return documents.fetch(index) if index.is_a?(Integer) && index >= 0 && index < documents.size

        raise EndpointError, "server returned out-of-range document index #{index.inspect} (#{documents.size} documents sent)"
      end

      # @raise [UnsupportedProviderError] when +endpoint+ is nil or false.
      def require_endpoint(endpoint, capability)
        raise UnsupportedProviderError, "no #{capability} endpoint configured" unless endpoint

        endpoint
      end

      # POSTs +payload+ as JSON to +endpoint+ + +path+ and returns the parsed response body.
      # config.timeout is used for both the open and the read timeout.
      #
      # @raise [EndpointNotFoundError] on a 404 that looks like a missing route.
      # @raise [EndpointError] on any other non-2xx status, invalid JSON, an
      #   invalid endpoint URL, or one of NETWORK_ERRORS.
      def post_json(endpoint, path, payload)
        uri = endpoint_uri(endpoint, path)
        request = Net::HTTP::Post.new(uri)
        request["Content-Type"] = "application/json"
        request.body = JSON.generate(payload)

        response = Net::HTTP.start(uri.hostname, uri.port,
                                   use_ssl: uri.scheme == "https",
                                   read_timeout: config.timeout,
                                   open_timeout: config.timeout) do |http|
          http.request(request)
        end

        raise EndpointNotFoundError.new(uri, path: path, body: response.body) if response.code.to_i == 404 && route_missing_response?(response.body)
        raise EndpointError, "#{uri} returned #{response.code}: #{response.body}" unless response.is_a?(Net::HTTPSuccess)

        JSON.parse(response.body)
      rescue JSON::ParserError => e
        raise EndpointError, "invalid JSON response from #{uri}: #{e.message}"
      rescue URI::InvalidURIError => e
        # uri is still nil here, so report the configured endpoint instead.
        raise EndpointError, "invalid endpoint URL #{endpoint.inspect}: #{e.message}"
      rescue *NETWORK_ERRORS => e
        raise EndpointError, "could not reach #{uri}: #{e.message}"
      end

      # Appends +path+ to the endpoint's own path. Leading and trailing slashes
      # are stripped from both parts before joining, e.g.
      # "http://h/api/" + "/v1/embeddings" -> "http://h/api/v1/embeddings".
      #
      # @raise [URI::InvalidURIError] unless +endpoint+ is an http(s) URL with a host.
      def endpoint_uri(endpoint, path)
        # Parse from the string form. URI() returns URI arguments as-is, so the
        # path assignment below would otherwise mutate a URI held in the config.
        uri = URI(endpoint.to_s)
        # Net::HTTP::Post.new raises ArgumentError for non-HTTP or host-less URIs.
        # Route those through the "invalid endpoint URL" error instead.
        unless uri.is_a?(URI::HTTP) && !uri.host.to_s.empty?
          raise URI::InvalidURIError, "expected an http:// or https:// URL with a host"
        end

        segments = [uri.path, path].map { |part| part.to_s.gsub(%r{\A/+|/+\z}, "") }.reject(&:empty?)
        uri.path = "/#{segments.join('/')}"
        uri
      end

      # A 404 counts as "route missing", rather than e.g. an unknown model,
      # when its body is empty or not valid JSON.
      def route_missing_response?(body)
        return true if body.to_s.strip.empty?

        JSON.parse(body)
        false
      rescue JSON::ParserError
        true
      end

      # Only a missing /v1/rerank route triggers the /rerank fallback.
      def fallback_rerank_not_found?(error)
        error.path == "/v1/rerank"
      end
    end
  end
end
@@ -0,0 +1,44 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../provider"
4
+ require_relative "../server_manager"
5
+ require_relative "endpoint"
6
+
7
module EmbeddingUtil
  module Providers
    # Provider backed by a locally launched server.
    # ServerManager starts (or reuses) the server; an Endpoint provider pointed
    # at the server's URL then performs the HTTP requests.
    class SelfHosted < Provider
      # @return [Boolean] whether ServerManager reports the configured local runtime as available.
      def supported?
        ServerManager.supported?(config)
      end

      # @return [Hash] provider name, support flag, resolved runtime,
      #   idle-shutdown setting and state directory.
      def support
        details = { provider: provider_name, supported: supported? }
        details[:runtime] = RuntimeCommand.resolve(config.runtime)
        details[:shutdown_idle] = config.shutdown_idle
        details[:state_dir] = config.state_dir
        details
      end

      # Embeds +texts+ on the local embedding server for +profile+, starting it if needed.
      def embed(texts, profile: config.resolved_profile)
        url = local_server_url(:embedding, profile)
        endpoint_provider(embedding_endpoint: url).embed(texts, profile: profile)
      end

      # Reranks +documents+ for +query+ on the local reranker server, starting it if needed.
      def rerank(query, documents, profile: config.resolved_profile)
        url = local_server_url(:reranker, profile)
        endpoint_provider(reranker_endpoint: url).rerank(query, documents, profile: profile)
      end

      private

      # URL returned by ServerManager#ensure_server, which waits until the server is healthy.
      def local_server_url(capability, profile)
        ServerManager.new(config: config).ensure_server(capability, profile: profile)
      end

      # Endpoint provider built from a dup of the config with both endpoint URLs
      # assigned. The URL not passed is set to nil.
      def endpoint_provider(embedding_endpoint: nil, reranker_endpoint: nil)
        scoped_config = config.dup.tap do |copy|
          copy.embedding_endpoint = embedding_endpoint
          copy.reranker_endpoint = reranker_endpoint
        end
        Endpoint.new(config: scoped_config)
      end
    end
  end
end
@@ -0,0 +1,7 @@
1
+ # frozen_string_literal: true
2
+
3
module EmbeddingUtil
  # Immutable result of an embedding request.
  # - embedding: the vectors returned by the server
  # - model: the model name the server reported
  # - profile / provider: names of the profile and provider used
  # - metadata: extra details such as token usage
  EmbeddingResult = Data.define(*%i[embedding model profile provider metadata])

  # One reranked document.
  # - index: the document's position in the submitted list
  # - document: the document itself
  # - score: the relevance score
  # - metadata: any other fields the server returned
  RankedDocument = Data.define(*%i[index document score metadata])

  # Immutable result of a rerank request; results holds RankedDocument values.
  RerankResult = Data.define(*%i[results model profile provider metadata])
end
@@ -0,0 +1,84 @@
1
+ # frozen_string_literal: true
2
+
3
module EmbeddingUtil
  # Builds the argv that launches a local server, and checks which runtime
  # executables are on PATH. The server is started either through
  # `ramalama --runtime=llama.cpp serve` or by running `llama-server` directly.
  class RuntimeCommand
    attr_reader :runtime, :server_model, :host, :port

    # @param runtime [Symbol, String] normalized with normalize_runtime
    #   ("llama-server" -> :llama_server).
    # @param server_model object whose #settings provides :repo, :file and :server_flags.
    def initialize(runtime:, server_model:, host:, port:)
      @runtime = self.class.normalize_runtime(runtime)
      @server_model = server_model
      @host = host
      @port = port
    end

    # Whether +runtime+ can be used:
    # - :ramalama / :llama_server need their executable on PATH
    # - :auto needs either one
    # - anything else is unavailable
    def self.available?(runtime)
      requested = normalize_runtime(runtime)
      return available?(:ramalama) || available?(:llama_server) if requested == :auto

      executable = { ramalama: "ramalama", llama_server: "llama-server" }[requested]
      return false unless executable

      !command_path(executable).nil?
    end

    # Turns :auto into the first available concrete runtime, preferring ramalama.
    # Returns :auto when neither is installed. Other values are returned normalized.
    def self.resolve(runtime)
      requested = normalize_runtime(runtime)
      return requested unless requested == :auto

      %i[ramalama llama_server].find { |candidate| available?(candidate) } || :auto
    end

    # Full path of the first executable, non-directory file named +command+ in PATH, or nil.
    def self.command_path(command)
      ENV.fetch("PATH", "").split(File::PATH_SEPARATOR).each do |directory|
        candidate = File.join(directory, command)
        return candidate if File.executable?(candidate) && !File.directory?(candidate)
      end
      nil
    end

    # Symbol form of +runtime+ with dashes turned into underscores (nil becomes :"").
    def self.normalize_runtime(runtime)
      runtime.to_s.tr("-", "_").to_sym
    end

    # @return [Array<String>] command line for the selected runtime.
    # @raise [UnsupportedProviderError] when runtime is neither :ramalama nor :llama_server.
    def argv
      return ramalama_argv if runtime == :ramalama
      return llama_server_argv if runtime == :llama_server

      raise UnsupportedProviderError, "no supported local runtime found; install ramalama or llama-server"
    end

    # Display name of the runtime ("llama-server" for :llama_server).
    def label
      return "llama-server" if runtime == :llama_server

      runtime.to_s
    end

    private

    # ramalama gets the server flags as one space-joined --runtime-args value
    # and the model as an hf:// reference.
    def ramalama_argv
      runtime_args = server_model.settings.fetch(:server_flags).join(" ")
      ["ramalama", "--runtime=llama.cpp", "serve", "--host", host, "--port", port.to_s,
       "--runtime-args=#{runtime_args}", huggingface_model]
    end

    # llama-server gets the Hugging Face repo and file through -hf / -hff,
    # followed by the profile's server flags as separate arguments.
    def llama_server_argv
      settings = server_model.settings
      ["llama-server", "--host", host, "--port", port.to_s,
       "-hf", settings.fetch(:repo), "-hff", settings.fetch(:file), *settings.fetch(:server_flags)]
    end

    # "hf://<repo>/<file>" reference passed to ramalama.
    def huggingface_model
      settings = server_model.settings
      "hf://#{settings.fetch(:repo)}/#{settings.fetch(:file)}"
    end
  end
end
@@ -0,0 +1,258 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "fileutils"
4
+ require "json"
5
+ require "net/http"
6
+ require "open3"
7
+ require "rbconfig"
8
+ require "socket"
9
+ require "time"
10
+ require "uri"
11
+
12
module EmbeddingUtil
  # Launches and tracks local inference servers.
  #
  # Each server model gets three files in config.state_dir, named after
  # ServerModel#name:
  # - "<name>.json": pid, url, runtime and port of the tracked process
  # - "<name>.log": output of servers started in the background
  # - "<name>.lock": exclusive flock that serializes state checks and server starts
  class ServerManager
    attr_reader :config

    def initialize(config: EmbeddingUtil.configuration)
      @config = config
    end

    # @return [Boolean] whether RuntimeCommand finds the configured runtime on PATH.
    def self.supported?(config = EmbeddingUtil.configuration)
      RuntimeCommand.available?(config.runtime)
    end

    # Returns the base URL of a healthy server for +capability+ and +profile+.
    # Starts one in the background when no live process is tracked.
    #
    # @param capability [Symbol, String] e.g. :embedding or :reranker
    # @return [String] the server's base URL from the state file
    # @raise [UnsupportedProviderError] if the server exits before becoming
    #   healthy, or is not healthy within config.startup_timeout seconds.
    def ensure_server(capability, profile: config.resolved_profile)
      server_model = ServerModel.for(capability, profile)

      # Check-and-start happens under the lock, so concurrent callers launch at
      # most one server per model. A running but not yet healthy process (e.g.
      # still loading) is left alone; wait_for_healthy polls it.
      with_lock(server_model) do
        start_background(server_model) unless running_server_state?(read_state(server_model))
      end

      wait_for_healthy(server_model, log_path: server_log_path(server_model))
    end

    # Runs the server in the foreground. While it runs, the method:
    # - prints its combined stdout/stderr,
    # - records it in the state file,
    # - terminates it after +shutdown_idle+ seconds without output, when
    #   shutdown_idle is positive.
    #
    # @param model [ServerModel, String] e.g. "embedding-PROFILE"
    # @param port [Integer, nil] required port. With nil, the first free port
    #   among the 100 starting at the model's default port is used.
    # @return [Integer, nil] the runtime's exit status (nil when it was ended by a signal).
    # @raise [UnsupportedProviderError] when the requested or any candidate port is taken.
    def serve(model:, runtime: config.runtime, shutdown_idle: config.shutdown_idle, host: config.host, port: nil)
      server_model = model.is_a?(ServerModel) ? model : ServerModel.parse(model)
      resolved_runtime = RuntimeCommand.resolve(runtime)
      selected_port = selected_port_for(server_model, host: host, port: port)
      command = RuntimeCommand.new(runtime: resolved_runtime, server_model: server_model, host: host, port: selected_port)
      url = "http://#{host}:#{selected_port}"
      last_output_at = Time.now

      FileUtils.mkdir_p(config.state_dir)
      puts "starting #{server_model.name} with #{command.label} on #{url}"
      puts "shutdown idle: #{shutdown_idle}s" if shutdown_idle&.positive?

      Open3.popen2e(*command.argv) do |_stdin, output, wait_thread|
        watchdog = nil
        begin
          # Take the lock so this entry lands after the provisional "starting"
          # entry, which start_background writes while holding the same lock.
          with_lock(server_model) do
            write_state(server_model, pid: wait_thread.pid, url: url, runtime: command.label, port: selected_port)
          end
          watchdog = start_watchdog(wait_thread.pid, shutdown_idle) { last_output_at }

          output.each_line do |line|
            last_output_at = Time.now
            print line
          end
        ensure
          # Runs on interrupts too, so no stale state entry is left behind.
          watchdog&.kill
          release_state(server_model, wait_thread.pid)
        end
        wait_thread.value.exitstatus
      end
    end

    private

    # Spawns `ruby <embedding_util executable> serve ...` as a detached child
    # in its own process group, appending its output to the model's log file.
    # Then records a provisional "starting" state entry with the child's pid.
    # Called by ensure_server while it holds with_lock.
    #
    # @return [String] the log file path
    def start_background(server_model)
      FileUtils.mkdir_p(config.state_dir)
      log_path = server_log_path(server_model)
      selected_port = selected_port_for(server_model, host: config.host)
      argv = [
        RbConfig.ruby, executable_path, "serve",
        "--model", server_model.name,
        "--runtime", config.runtime.to_s,
        "--host", config.host,
        "--port", selected_port.to_s
      ]
      argv.push("--shutdown-idle", config.shutdown_idle.to_s) unless config.shutdown_idle.nil?
      warn "starting #{server_model.name} in background: #{argv.join(' ')}" if config.verbose
      warn "#{server_model.name} log: #{log_path}" if config.verbose
      pid = Process.spawn(*argv, out: [log_path, "a"], err: %i[child out], pgroup: true)
      write_state(server_model, pid: pid, url: "http://#{config.host}:#{selected_port}", runtime: "starting", port: selected_port)
      Process.detach(pid)
      log_path
    end

    def server_log_path(server_model)
      File.join(config.state_dir, "#{server_model.name}.log")
    end

    # Prefers the executable in the source checkout; otherwise uses the installed gem's binary.
    def executable_path
      local_path = File.expand_path("../../exe/embedding_util", __dir__)
      return local_path if File.exist?(local_path)

      Gem.bin_path("embedding_util", "embedding_util")
    end

    def selected_port_for(server_model, host:, port: nil)
      return required_port(host, port) if port

      available_port(host, server_model.default_port(config))
    end

    # @raise [UnsupportedProviderError] when +port+ cannot be bound on +host+.
    def required_port(host, port)
      return port if port_available?(host, port)

      raise UnsupportedProviderError, "port #{host}:#{port} is already in use"
    end

    # First bindable port in preferred_port...(preferred_port + 100).
    # @raise [UnsupportedProviderError] when none is free.
    def available_port(host, preferred_port)
      (preferred_port...(preferred_port + 100)).find { |candidate| port_available?(host, candidate) } || raise(
        UnsupportedProviderError,
        "no free port found for #{host} starting at #{preferred_port}"
      )
    end

    # Advisory only: the probe socket is closed before the runtime binds the
    # port, so another process can still take it in between.
    def port_available?(host, port)
      server = TCPServer.new(host, port)
      true
    rescue Errno::EADDRINUSE, Errno::EACCES, SocketError
      false
    ensure
      server&.close
    end

    # Polls the state file every 0.25s until the tracked server answers /health.
    #
    # A missing state file means the foreground `serve` process has exited and
    # cleaned up. That is reported immediately instead of waiting out the
    # timeout. write_state is atomic, so readers never see a half-written file.
    #
    # @raise [UnsupportedProviderError] on process exit or after config.startup_timeout seconds.
    def wait_for_healthy(server_model, log_path: nil)
      deadline = Time.now + config.startup_timeout
      loop do
        state = read_state(server_model)
        return state.fetch("url") if healthy_state?(state)
        raise UnsupportedProviderError, process_exited_message(server_model, log_path) if state.nil? || tracked_process_exited?(state)
        raise UnsupportedProviderError, startup_timeout_message(server_model, log_path) if Time.now >= deadline

        sleep 0.25
      end
    end

    # Starts a thread that checks every max(shutdown_idle / 5, 1) seconds
    # whether the block's timestamp is at least shutdown_idle seconds old.
    # If so, it terminates +pid+.
    # Returns nil (no thread) when shutdown_idle is nil or not positive.
    def start_watchdog(pid, shutdown_idle)
      return unless shutdown_idle&.positive?

      Thread.new do
        loop do
          sleep [shutdown_idle / 5.0, 1].max
          next if Time.now - yield < shutdown_idle

          terminate_idle_process(pid)
        rescue Errno::ESRCH
          break
        end
      end
    end

    # Sends TERM, then KILL if the process is still running 5 seconds later.
    def terminate_idle_process(pid)
      Process.kill("TERM", pid)
      sleep 5
      Process.kill("KILL", pid) if process_running?(pid)
    end

    def startup_timeout_message(server_model, log_path)
      with_log_details("timed out after #{config.startup_timeout}s waiting for #{server_model.name} to become healthy", log_path)
    end

    def process_exited_message(server_model, log_path)
      with_log_details("#{server_model.name} server process exited before becoming healthy", log_path)
    end

    # Appends the log path, and its last lines when there are any, to +message+.
    def with_log_details(message, log_path)
      return message unless log_path

      lines = log_tail(log_path)
      details = "#{message}\nlog: #{log_path}"
      lines.empty? ? details : "#{details}\nlast log lines:\n#{lines}"
    end

    # Last 20 lines of the log, or "" when it cannot be read.
    def log_tail(log_path)
      return "" unless File.exist?(log_path)

      File.readlines(log_path).last(20).join
    rescue Errno::ENOENT, Errno::EACCES, IOError
      ""
    end

    def healthy_state?(state)
      return false unless state && state["url"] && state["pid"]
      return false unless process_running?(state.fetch("pid"))

      healthy_url?(state.fetch("url"))
    end

    def running_server_state?(state)
      state && state["url"] && state["pid"] && process_running?(state.fetch("pid"))
    end

    def tracked_process_exited?(state)
      state && state["url"] && state["pid"] && !process_running?(state.fetch("pid"))
    end

    # GET <url>/health with 2-second timeouts. Any error counts as unhealthy.
    def healthy_url?(url)
      uri = URI.join(url.end_with?("/") ? url : "#{url}/", "health")
      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https", read_timeout: 2, open_timeout: 2) { |http| http.get(uri) }
      response.is_a?(Net::HTTPSuccess)
    rescue StandardError
      false
    end

    # Signal-0 liveness probe.
    # Non-integer pids (ArgumentError/TypeError from Integer()) count as not running.
    def process_running?(pid)
      Process.kill(0, Integer(pid))
      true
    rescue Errno::ESRCH, ArgumentError, TypeError
      false
    rescue Errno::EPERM
      # Process exists but belongs to a different user; treat as running.
      true
    end

    # Runs the block while holding an exclusive flock on the model's lock file.
    def with_lock(server_model)
      FileUtils.mkdir_p(config.state_dir)
      File.open(lock_path(server_model), File::RDWR | File::CREAT, 0o644) do |file|
        file.flock(File::LOCK_EX)
        yield
      end
    end

    # Writes the state file atomically (temp file + rename), so concurrent
    # readers never see an empty or truncated JSON document.
    def write_state(server_model, pid:, url:, runtime:, port:)
      path = state_path(server_model)
      temp_path = "#{path}.#{Process.pid}.#{Thread.current.object_id}.tmp"
      File.write(temp_path, JSON.pretty_generate({
        pid: pid,
        url: url,
        profile: server_model.profile.name,
        capability: server_model.capability,
        runtime: runtime,
        port: port,
        updated_at: Time.now.utc.iso8601
      }))
      File.rename(temp_path, path)
    ensure
      # No-op after a successful rename; removes the leftover temp file after a failure.
      FileUtils.rm_f(temp_path) if temp_path
    end

    # Deletes the state file only while it still records +pid+, under the lock.
    # A server that is shutting down therefore cannot remove the entry of a
    # replacement that ensure_server started in the meantime.
    def release_state(server_model, pid)
      with_lock(server_model) do
        delete_state(server_model) if read_state(server_model)&.fetch("pid", nil) == pid
      end
    end

    # @return [Hash, nil] parsed state, or nil when missing or unparsable.
    def read_state(server_model)
      JSON.parse(File.read(state_path(server_model)))
    rescue Errno::ENOENT, JSON::ParserError
      nil
    end

    def delete_state(server_model)
      FileUtils.rm_f(state_path(server_model))
    end

    def state_path(server_model)
      File.join(config.state_dir, "#{server_model.name}.json")
    end

    def lock_path(server_model)
      File.join(config.state_dir, "#{server_model.name}.lock")
    end
  end
end
@@ -0,0 +1,46 @@
1
+ # frozen_string_literal: true
2
+
3
module EmbeddingUtil
  # Server model name prefixes accepted by ServerModel.parse, mapped to their
  # capability. "rerank" is an alias for "reranker".
  SERVER_MODEL_PREFIXES = {
    "embedding" => :embedding,
    "reranker" => :reranker,
    "rerank" => :reranker
  }.freeze

  # Identifies one local server: a capability (:embedding or :reranker) plus
  # the profile that supplies its model settings.
  # #name ("<capability>-<profile name>") is what start_background passes to
  # `serve --model`, and it names the state, log and lock files.
  ServerModel = Data.define(:capability, :profile) do
    # Parses "<prefix>-<profile name>", e.g. "reranker-PROFILE".
    #
    # @raise [ArgumentError] for an unknown prefix or a missing or empty profile name.
    def self.parse(value)
      prefix, profile_name = value.to_s.split("-", 2)
      capability = SERVER_MODEL_PREFIXES[prefix]
      # "embedding-" yields an empty profile name. Reject it here instead of
      # handing "" to Profiles.fetch.
      if capability.nil? || profile_name.nil? || profile_name.empty?
        raise ArgumentError, "unknown server model #{value.inspect}; expected embedding-PROFILE or reranker-PROFILE"
      end

      new(capability: capability, profile: Profiles.fetch(profile_name))
    end

    # Builds a ServerModel for +capability+ (Symbol or String).
    # The same aliases as parse are accepted, so "rerank" maps to :reranker.
    # Other values are converted with to_sym.
    def self.for(capability, profile)
      new(capability: SERVER_MODEL_PREFIXES.fetch(capability.to_s, capability.to_sym), profile: profile)
    end

    def name
      "#{capability}-#{profile.name}"
    end

    # The profile's settings hash for this capability.
    # @raise [ArgumentError] for a capability other than :embedding or :reranker.
    def settings
      case capability
      when :embedding then profile.embedding
      when :reranker then profile.reranker
      else raise ArgumentError, "unknown server capability: #{capability.inspect}"
      end
    end

    # config.embedding_port for :embedding; config.reranker_port for anything else.
    def default_port(config)
      capability == :embedding ? config.embedding_port : config.reranker_port
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module EmbeddingUtil
  # Released version of the embedding_util gem.
  VERSION = "0.1.0"
end