onnx-ruby 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CLAUDE.md +334 -0
- data/Gemfile +5 -0
- data/LICENSE +21 -0
- data/README.md +301 -0
- data/Rakefile +17 -0
- data/examples/classification.rb +35 -0
- data/examples/embedding.rb +35 -0
- data/examples/real_world_demo.rb +170 -0
- data/examples/with_zvec.rb +54 -0
- data/ext/onnx_ruby/extconf.rb +75 -0
- data/ext/onnx_ruby/onnx_ruby_ext.cpp +436 -0
- data/lib/onnx_ruby/classifier.rb +107 -0
- data/lib/onnx_ruby/configuration.rb +16 -0
- data/lib/onnx_ruby/embedder.rb +147 -0
- data/lib/onnx_ruby/hub.rb +73 -0
- data/lib/onnx_ruby/lazy_session.rb +38 -0
- data/lib/onnx_ruby/model.rb +71 -0
- data/lib/onnx_ruby/reranker.rb +91 -0
- data/lib/onnx_ruby/session.rb +89 -0
- data/lib/onnx_ruby/session_pool.rb +75 -0
- data/lib/onnx_ruby/tensor.rb +92 -0
- data/lib/onnx_ruby/version.rb +5 -0
- data/lib/onnx_ruby.rb +45 -0
- data/onnx-ruby.gemspec +37 -0
- metadata +125 -0
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "fileutils"
|
|
4
|
+
require "net/http"
|
|
5
|
+
require "uri"
|
|
6
|
+
require "json"
|
|
7
|
+
|
|
8
|
+
module OnnxRuby
  # Downloads ONNX model files from the Hugging Face Hub into a local cache.
  module Hub
    DEFAULT_CACHE_DIR = File.join(Dir.home, ".cache", "onnx_ruby", "models")

    # Download a model from Hugging Face Hub
    # @param repo_id [String] e.g. "sentence-transformers/all-MiniLM-L6-v2"
    # @param filename [String] ONNX file to download (default: "model.onnx");
    #   may contain subdirectories such as "onnx/model.onnx"
    # @param cache_dir [String] local cache directory
    # @param revision [String] branch, tag or commit to resolve (default: "main")
    # @return [String] path to the downloaded (or already cached) model file
    # @raise [ModelError] on a non-success HTTP response or too many redirects
    def self.download(repo_id, filename: "model.onnx", cache_dir: DEFAULT_CACHE_DIR, revision: "main")
      # gsub, not tr: tr("/", "--") maps "/" to a single "-", so "a-b/c" and
      # "a/b-c" would collide on the same cache directory.
      model_dir = File.join(cache_dir, repo_id.gsub("/", "--"), revision)
      model_path = File.join(model_dir, filename)

      return model_path if File.exist?(model_path)

      # Create the file's own parent so filenames with subdirectories work.
      FileUtils.mkdir_p(File.dirname(model_path))

      url = "https://huggingface.co/#{repo_id}/resolve/#{revision}/#{filename}"
      download_file(url, model_path)

      model_path
    end

    # List cached models
    # @param cache_dir [String] cache directory to search
    # @return [Array<String>] list of cached model paths
    def self.cached_models(cache_dir: DEFAULT_CACHE_DIR)
      return [] unless Dir.exist?(cache_dir)

      Dir.glob(File.join(cache_dir, "**", "*.onnx"))
    end

    # Clear the model cache
    # @param cache_dir [String] cache directory to clear
    def self.clear_cache(cache_dir: DEFAULT_CACHE_DIR)
      FileUtils.rm_rf(cache_dir) if Dir.exist?(cache_dir)
    end

    class << self
      private

      # Fetches +url+ into +dest+, following at most +max_redirects+ redirects.
      # @raise [ModelError] on a non-success response or too many redirects
      def download_file(url, dest, max_redirects: 5)
        uri = URI(url)

        max_redirects.times do
          redirect = nil

          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            # Block form: the body is not read yet, so it can be streamed.
            http.request(Net::HTTP::Get.new(uri)) do |response|
              case response
              when Net::HTTPSuccess
                stream_to_file(response, dest)
              when Net::HTTPRedirection
                location = response["location"]
                raise ModelError, "failed to download #{url}: #{response.code} #{response.message}" unless location

                redirect = resolve_redirect(uri, location)
              else
                raise ModelError, "failed to download #{url}: #{response.code} #{response.message}"
              end
            end
          end

          return unless redirect

          uri = redirect
        end

        raise ModelError, "too many redirects downloading #{url}"
      end

      # Resolves a Location header against the request URI. The header may be
      # a relative path, which URI() alone would turn into a host-less URI.
      def resolve_redirect(base, location)
        base.merge(location)
      end

      # Streams the response body into "<dest>.part", then renames it onto
      # +dest+. An interrupted download therefore never leaves a truncated
      # file at +dest+ that #download would treat as a cache hit.
      def stream_to_file(response, dest)
        tmp = "#{dest}.part"
        begin
          File.open(tmp, "wb") do |file|
            response.read_body { |chunk| file.write(chunk) }
          end
          File.rename(tmp, dest)
        rescue StandardError
          FileUtils.rm_f(tmp)
          raise
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # Stand-in for a Session that builds the real Session only on the first
  # call to #inputs, #outputs or #run. The build happens under a mutex, so
  # concurrent first callers construct a single Session.
  class LazySession
    # @param model_path [String] forwarded to Session.new on first use
    # @param opts [Hash] keyword options forwarded to Session.new
    def initialize(model_path, **opts)
      @mutex = Mutex.new
      @session = nil
      @opts = opts
      @model_path = model_path
    end

    # Input metadata of the underlying Session (built on demand).
    def inputs
      real = load_session
      real.inputs
    end

    # Output metadata of the underlying Session (built on demand).
    def outputs
      real = load_session
      real.outputs
    end

    # Runs inference on the underlying Session (built on demand).
    def run(inputs, **kwargs)
      real = load_session
      real.run(inputs, **kwargs)
    end

    # @return [Boolean] true once the underlying Session has been built
    def loaded?
      !@session.nil?
    end

    private

    # Returns the already-built Session, or builds it while holding the lock.
    def load_session
      @session || @mutex.synchronize { @session ||= Session.new(@model_path, **@opts) }
    end
  end
end
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # ActiveModel-style mixin for embedding generation.
  #
  # Usage:
  #   class Document
  #     include OnnxRuby::Model
  #
  #     onnx_model "embeddings.onnx"
  #     onnx_input ->(doc) { { "input_ids" => doc.token_ids, "attention_mask" => doc.mask } }
  #     onnx_output "embeddings"
  #   end
  #
  #   doc = Document.new
  #   doc.onnx_predict # => [0.123, -0.456, ...]
  module Model
    def self.included(base)
      base.extend(ClassMethods)
    end

    # Class-level DSL added to the including class.
    module ClassMethods
      # Sets (when +path+ is given) and returns the model path.
      # Setting a new path discards any session built for the previous model.
      # @param path [String, nil] absolute path, or path relative to
      #   OnnxRuby.configuration.models_path
      # @param opts [Hash] options forwarded to LazySession.new
      # @return [String, nil] the configured model path
      def onnx_model(path = nil, **opts)
        if path
          @onnx_model_path = path
          @onnx_session_opts = opts
          @onnx_session = nil
        end
        @onnx_model_path
      end

      # Sets (when given) and returns the callable mapping an instance to the
      # model's input hash.
      def onnx_input(callable = nil, &block)
        fn = callable || block
        @onnx_input_fn = fn if fn
        @onnx_input_fn
      end

      # Sets (when given) and returns the output name picked by #onnx_predict.
      def onnx_output(name = nil)
        @onnx_output_name = name if name
        @onnx_output_name
      end

      # Memoized LazySession for the configured model.
      # @raise [Error] if onnx_model has not been called
      def onnx_session
        @onnx_session ||= begin
          raise Error, "onnx_model not defined on #{self}" unless @onnx_model_path

          path = resolve_model_path(@onnx_model_path)
          LazySession.new(path, **(@onnx_session_opts || {}))
        end
      end

      private

      # An existing absolute path is used as-is. Otherwise the path is tried
      # under configuration.models_path. Failing both, +path+ is returned
      # unchanged.
      def resolve_model_path(path)
        return path if File.absolute_path?(path) && File.exist?(path)

        full = File.join(OnnxRuby.configuration.models_path, path)
        return full if File.exist?(full)

        path
      end
    end

    # Runs the class's model on this instance.
    # @param run_opts [Hash] forwarded to the session's #run
    # @return the configured output, or the first output when none is configured
    # @raise [Error] if onnx_input (or onnx_model) has not been defined
    def onnx_predict(**run_opts)
      input_fn = self.class.onnx_input
      raise Error, "onnx_input not defined on #{self.class}" unless input_fn

      inputs = input_fn.call(self)
      result = self.class.onnx_session.run(inputs, **run_opts)

      output_name = self.class.onnx_output
      output_name ? result[output_name] : result.values.first
    end
  end
end
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # Scores (query, document) pairs with an ONNX model and orders the
  # documents by score.
  class Reranker
    # Output names checked, in order, for the score tensor. When none
    # matches, the first output is used.
    SCORE_OUTPUT_NAMES = %w[scores logits output].freeze

    attr_reader :session

    # @param model_path [String] path to the ONNX model
    # @param tokenizer [Object, String, nil] an object responding to #encode, or
    #   an identifier passed to Tokenizers::Tokenizer.from_pretrained
    # @param session_opts [Hash] options forwarded to Session.new
    def initialize(model_path, tokenizer: nil, **session_opts)
      @session = Session.new(model_path, **session_opts)
      @tokenizer = resolve_tokenizer(tokenizer)
    end

    # Rerank documents by relevance to a query
    # @param query [String] the query text (requires tokenizer)
    # @param documents [Array<String>] documents to rerank
    # @return [Array<Hash>] { document:, score:, index: } hashes sorted by
    #   descending score; equal scores keep their original order
    # @raise [Error] if no tokenizer was configured
    def rerank(query, documents)
      raise Error, "tokenizer is required for reranking" unless @tokenizer
      # Nothing to score; avoids running the model on an empty batch.
      return [] if documents.empty?

      scores = score_pairs(documents.map { |doc| [query, doc] })

      documents.each_with_index
               .map { |doc, i| { document: doc, score: scores[i], index: i } }
               .sort_by { |r| [-r[:score], r[:index]] }
    end

    # Score query-document pairs with pre-tokenized inputs
    # @param input_ids [Array<Array<Integer>>] batch of token ID sequences
    # @param attention_mask [Array<Array<Integer>>] batch of attention masks
    # @param token_type_ids [Array<Array<Integer>>, nil] fed only when the model
    #   declares a "token_type" input; zeros are used when omitted
    # @return [Array<Float>] relevance scores
    def score(input_ids:, attention_mask:, token_type_ids: nil)
      result = @session.run(build_feed(input_ids, attention_mask, token_type_ids))
      raw_scores = find_output(result, SCORE_OUTPUT_NAMES)
      # NOTE(review): for [batch, n] outputs only column 0 is taken — confirm
      # this is the relevance logit for multi-column models.
      raw_scores.map { |row| row.is_a?(Array) ? row.first : row }
    end

    private

    def resolve_tokenizer(tokenizer)
      return nil if tokenizer.nil?

      if tokenizer.respond_to?(:encode)
        tokenizer
      else
        begin
          require "tokenizers"
          Tokenizers::Tokenizer.from_pretrained(tokenizer.to_s)
        rescue LoadError
          raise Error, "tokenizer-ruby gem is required for text tokenization. " \
                       "Install with: gem install tokenizers"
        end
      end
    end

    # Tokenizes the pairs (batched when supported), pads every sequence to the
    # longest one, and scores them.
    def score_pairs(pairs)
      encodings =
        if @tokenizer.respond_to?(:encode_batch)
          @tokenizer.encode_batch(pairs)
        else
          pairs.map { |pair| @tokenizer.encode(*pair) }
        end

      ids = encodings.map(&:ids)
      masks = encodings.map(&:attention_mask)
      # Use the encoding's type ids when it exposes them, otherwise zeros.
      type_ids = encodings.map { |enc| enc.respond_to?(:type_ids) ? enc.type_ids : Array.new(enc.ids.length, 0) }

      max_len = ids.map(&:length).max
      score(
        input_ids: pad_rows(ids, max_len),
        attention_mask: pad_rows(masks, max_len),
        token_type_ids: pad_rows(type_ids, max_len)
      )
    end

    # Right-pads every row with 0 up to +length+.
    def pad_rows(rows, length)
      rows.map { |row| row + Array.new(length - row.length, 0) }
    end

    # Maps ids, masks and type ids onto the model's input names by substring
    # match. Type ids are added only when the model has a "token_type" input.
    def build_feed(ids, masks, type_ids = nil)
      input_names = @session.inputs.map { |i| i[:name] }
      feed = {}
      feed[input_names.find { |n| n.include?("input_id") } || input_names[0]] = ids
      mask_name = input_names.find { |n| n.include?("mask") || n.include?("attention") }
      feed[mask_name] = masks if mask_name
      type_name = input_names.find { |n| n.include?("token_type") }
      feed[type_name] = type_ids || ids.map { |row| Array.new(row.length, 0) } if type_name
      feed
    end

    def find_output(result, candidate_names)
      candidate_names.each { |name| return result[name] if result.key?(name) }
      result.values.first
    end
  end
end
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # Ruby-side wrapper around the native Ext::SessionWrapper. It validates
  # options, converts Ruby inputs into the name/data/shape/dtype hashes passed
  # to the extension, and exposes input/output metadata.
  class Session
    VALID_PROVIDERS = %i[cpu coreml cuda tensorrt].freeze

    # @param model_path [String] path to the .onnx file (expanded before use)
    # @param providers [Array<Symbol, String>, Symbol] each must be in VALID_PROVIDERS
    # @param inter_threads [Integer, nil] passed as 0 when nil
    # @param intra_threads [Integer, nil] passed as 0 when nil
    # @param log_level [Symbol] :verbose, :info, :warning, :error or :fatal
    #   (any other value is treated like :warning)
    # @param optimization_level [Symbol, String] forwarded as a String
    # @param memory_pattern [Boolean]
    # @param cpu_mem_arena [Boolean]
    # @param execution_mode [Symbol, String] forwarded as a String
    # @raise [ModelError] if the model file does not exist
    # @raise [Error] if a provider is not in VALID_PROVIDERS
    def initialize(model_path, providers: [:cpu], inter_threads: nil, intra_threads: nil,
                   log_level: :warning, optimization_level: :all, memory_pattern: true,
                   cpu_mem_arena: true, execution_mode: :sequential)
      model_path = File.expand_path(model_path)
      raise ModelError, "model file not found: #{model_path}" unless File.exist?(model_path)

      provider_strs = Array(providers).map do |p|
        p = p.to_sym
        raise Error, "unknown provider: #{p}. Valid: #{VALID_PROVIDERS.join(", ")}" unless VALID_PROVIDERS.include?(p)

        p.to_s
      end

      @session = Ext::SessionWrapper.new(
        model_path,
        log_level_to_int(log_level),
        intra_threads || 0,
        inter_threads || 0,
        optimization_level.to_s,
        memory_pattern,
        cpu_mem_arena,
        execution_mode.to_s,
        provider_strs
      )
    end

    def inputs
      @session.input_info
    end

    def outputs
      @session.output_info
    end

    # Runs inference.
    # @param inputs [Hash{String, Symbol => Tensor, Array}] input name => Tensor
    #   or (nested) array
    # @param output_names [Array<String>, nil] outputs to request; nil is passed
    #   to the extension as []
    # @return the native wrapper's result
    # @raise [TensorError] if a nested-array input is ragged
    def run(inputs, output_names: nil)
      input_values = inputs.map do |name, data|
        # Normalize so Symbol keys behave like String keys.
        input_name = name.to_s
        if data.is_a?(Tensor)
          { name: input_name, data: data.flat_data, shape: data.shape, dtype: data.dtype.to_s }
        else
          shape = infer_shape(data)
          unless rectangular?(data, shape)
            raise TensorError, "input #{input_name} is a ragged array; nested arrays must have equal lengths"
          end

          flat = data.flatten
          { name: input_name, data: flat, shape: shape, dtype: infer_dtype(flat) }
        end
      end

      @session.run(input_values, output_names || [])
    end

    private

    def log_level_to_int(level)
      case level
      when :verbose then 0
      when :info then 1
      when :warning then 2
      when :error then 3
      when :fatal then 4
      else 2
      end
    end

    # Shape read by following the first element at each nesting level.
    def infer_shape(data)
      shape = []
      current = data
      while current.is_a?(Array)
        shape << current.length
        current = current.first
      end
      shape
    end

    # True when every array at each depth has the length recorded in +shape+
    # and non-array values appear only at depth shape.length. Without this
    # check, a ragged input would be sent with a shape that does not match
    # its flattened data.
    def rectangular?(data, shape)
      return !data.is_a?(Array) if shape.empty?
      return false unless data.is_a?(Array) && data.length == shape.first

      inner = shape.drop(1)
      data.all? { |item| rectangular?(item, inner) }
    end

    def infer_dtype(flat)
      sample = flat.find { |v| !v.nil? }
      case sample
      when Float then "float"
      # A single Float among Integers promotes the input to float, so
      # fractional values are not typed as int64.
      when Integer then flat.any?(Float) ? "float" : "int64"
      when String then "string"
      when true, false then "bool"
      else "float"
      end
    end
  end
end
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # Bounded pool of Session objects. Sessions are created on demand, up to
  # +size+, and reused after being checked back in.
  class SessionPool
    class TimeoutError < Error; end

    # @param model_path [String] model loaded by every pooled session
    # @param size [Integer, nil] maximum sessions (default: configuration.pool_size)
    # @param timeout [Numeric, nil] total seconds a checkout may wait
    #   (default: configuration.pool_timeout)
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, size: nil, timeout: nil, **session_opts)
      @model_path = model_path
      @session_opts = session_opts
      @size = size || OnnxRuby.configuration.pool_size
      @timeout = timeout || OnnxRuby.configuration.pool_timeout
      @pool = []
      @mutex = Mutex.new
      @condition = ConditionVariable.new
      @created = 0
    end

    # Check out a session, yield it, then check it back in
    def with_session(&block)
      session = checkout
      begin
        yield session
      ensure
        checkin(session)
      end
    end

    # Run inference using a pooled session
    def run(inputs, **kwargs)
      with_session { |s| s.run(inputs, **kwargs) }
    end

    # Current pool stats
    def size
      @mutex.synchronize { @created }
    end

    def available
      @mutex.synchronize { @pool.size }
    end

    private

    # Returns an idle session, or reserves a slot and builds a new session, or
    # waits for one to be checked in. The total wait is bounded by @timeout
    # (a deadline), not restarted on every wakeup.
    # @raise [TimeoutError] if no session became available in time
    def checkout
      deadline = @timeout && (monotonic_now + @timeout)

      @mutex.synchronize do
        loop do
          return @pool.pop unless @pool.empty?

          # Reserve a slot; the session itself is built outside the lock.
          if @created < @size
            @created += 1
            break
          end

          remaining = deadline && (deadline - monotonic_now)
          raise TimeoutError, "timed out waiting for session (pool size: #{@size})" if remaining && remaining <= 0

          # May return early (signal, or another thread took the session);
          # the loop re-checks state and the remaining time.
          @condition.wait(@mutex, remaining)
        end
      end

      build_reserved_session
    end

    # Builds the session for a slot reserved by #checkout. This runs without
    # holding the lock, so other threads can still check sessions in and out.
    # If construction fails, the slot is released so capacity is not lost.
    def build_reserved_session
      session = nil
      session = create_session
    ensure
      release_slot unless session
    end

    # Gives back a reserved slot whose session was never built, and wakes a
    # waiter that can now use it.
    def release_slot
      @mutex.synchronize do
        @created -= 1
        @condition.signal
      end
    end

    def checkin(session)
      @mutex.synchronize do
        @pool.push(session)
        @condition.signal
      end
    end

    def create_session
      Session.new(@model_path, **@session_opts)
    end

    def monotonic_now
      Process.clock_gettime(Process::CLOCK_MONOTONIC)
    end
  end
end
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # A typed tensor. Values are stored flattened; #to_a rebuilds the nested
  # form by slicing along the trailing dimensions first (row-major).
  class Tensor
    DTYPE_MAP = {
      float32: :float,
      float: :float,
      float64: :double,
      double: :double,
      int32: :int32,
      int: :int32,
      int64: :int64,
      bool: :bool,
      string: :string
    }.freeze

    attr_reader :shape, :dtype

    # @param data [Array] nested or flat array of values
    # @param shape [Array<Integer>, nil] explicit dimensions; inferred from the
    #   nesting of +data+ when nil
    # @param dtype [Symbol, String, nil] a DTYPE_MAP key; inferred when nil
    # @raise [TensorError] on an unsupported dtype, ragged nesting, or a data
    #   size that does not match the shape
    def initialize(data, shape: nil, dtype: nil)
      @data = data.flatten
      @shape = shape || infer_shape(data)
      @dtype = normalize_dtype(dtype || infer_dtype(@data))

      # Only an inferred shape can be contradicted by the nesting itself.
      # An explicit shape is checked against the element count in validate!.
      if shape.nil? && !rectangular?(data, @shape)
        raise TensorError, "ragged nested array cannot form a tensor of shape #{@shape}"
      end

      validate!
    end

    # @return [Array] nested array matching #shape
    def to_a
      reshape(@data.dup, @shape)
    end

    def flat_data
      @data
    end

    def self.float(data, shape: nil)
      new(data, shape: shape, dtype: :float)
    end

    def self.int64(data, shape: nil)
      new(data, shape: shape, dtype: :int64)
    end

    def self.int32(data, shape: nil)
      new(data, shape: shape, dtype: :int32)
    end

    def self.double(data, shape: nil)
      new(data, shape: shape, dtype: :double)
    end

    private

    # Accepts Symbol or String dtype names (e.g. "float32").
    def normalize_dtype(dtype)
      key = dtype.respond_to?(:to_sym) ? dtype.to_sym : dtype
      DTYPE_MAP.fetch(key) { raise TensorError, "unsupported dtype: #{dtype}" }
    end

    # Shape read by following the first element at each nesting level.
    def infer_shape(data)
      shape = []
      current = data
      while current.is_a?(Array)
        shape << current.length
        current = current.first
      end
      shape
    end

    # True when every array at each depth has the length recorded in +shape+
    # and non-array values appear only at depth shape.length.
    def rectangular?(data, shape)
      return !data.is_a?(Array) if shape.empty?
      return false unless data.is_a?(Array) && data.length == shape.first

      inner = shape.drop(1)
      data.all? { |item| rectangular?(item, inner) }
    end

    def infer_dtype(flat)
      sample = flat.find { |v| !v.nil? }
      case sample
      when Float then :float
      # A single Float among Integers promotes the tensor to float.
      when Integer then flat.any?(Float) ? :float : :int64
      when String then :string
      when true, false then :bool
      else :float
      end
    end

    def validate!
      expected_size = @shape.reduce(1, :*)
      if @data.length != expected_size
        raise TensorError,
              "data size #{@data.length} does not match shape #{@shape} (expected #{expected_size})"
      end
    end

    # Rebuilds nested arrays from flat data. Zero-sized inner dimensions are
    # handled explicitly because each_slice(0) raises.
    def reshape(flat, dims)
      return flat if dims.length <= 1

      inner = dims.drop(1)
      stride = inner.reduce(1, :*)
      return Array.new(dims.first) { reshape([], inner) } if stride.zero?

      flat.each_slice(stride).map { |slice| reshape(slice, inner) }
    end
  end
end
|
data/lib/onnx_ruby.rb
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "onnx_ruby/version"
|
|
4
|
+
|
|
5
|
+
module OnnxRuby
  # Root of the library's exception hierarchy; rescue it to catch any OnnxRuby error.
  class Error < StandardError
  end

  # Model-file problems, e.g. a missing file or a failed Hub download.
  class ModelError < Error
  end

  # Inference failures (not raised by the Ruby files shown; presumably used by
  # the native extension — confirm).
  class InferenceError < Error
  end

  # Invalid tensor data, dtype or shape.
  class TensorError < Error
  end
end
|
|
11
|
+
|
|
12
|
+
require_relative "onnx_ruby/onnx_ruby_ext"
|
|
13
|
+
require_relative "onnx_ruby/tensor"
|
|
14
|
+
require_relative "onnx_ruby/session"
|
|
15
|
+
require_relative "onnx_ruby/embedder"
|
|
16
|
+
require_relative "onnx_ruby/classifier"
|
|
17
|
+
require_relative "onnx_ruby/reranker"
|
|
18
|
+
require_relative "onnx_ruby/hub"
|
|
19
|
+
require_relative "onnx_ruby/configuration"
|
|
20
|
+
require_relative "onnx_ruby/lazy_session"
|
|
21
|
+
require_relative "onnx_ruby/session_pool"
|
|
22
|
+
require_relative "onnx_ruby/model"
|
|
23
|
+
|
|
24
|
+
module OnnxRuby
  class << self
    # Global configuration, created on first access and then reused.
    # @return [Configuration]
    def configuration
      @configuration ||= Configuration.new
    end

    # Yields the global configuration so it can be modified.
    # @yieldparam config [Configuration]
    # @return the block's value
    def configure
      yield(configuration)
    end

    # Delegates to the native Ext.optimize_model after checking that the
    # input file exists.
    # @param input_path [String] source model (expanded before use)
    # @param output_path [String] destination, passed through unchanged
    # @param level [Symbol, String] optimization level, forwarded as a String
    # @raise [ModelError] if +input_path+ does not exist
    def optimize(input_path, output_path, level: :all)
      source = File.expand_path(input_path)
      unless File.exist?(source)
        raise ModelError, "model file not found: #{source}"
      end

      Ext.optimize_model(source, output_path, level.to_s)
    end

    # Execution providers reported by the native extension.
    def available_providers
      Ext.available_providers
    end
  end
end
|
data/onnx-ruby.gemspec
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "lib/onnx_ruby/version"
|
|
4
|
+
|
|
5
|
+
Gem::Specification.new do |s|
  s.name    = "onnx-ruby"
  s.version = OnnxRuby::VERSION
  s.authors = ["Johannes Dwi Cahyo"]
  s.email   = ["johannesdwicahyo@gmail.com"]

  s.summary     = "Ruby bindings for ONNX Runtime"
  s.description = "High-performance ONNX Runtime bindings for Ruby using Rice. " \
                  "Run ONNX models locally for embeddings, classification, NER, and more."
  s.homepage = "https://github.com/johannesdwicahyo/onnx-ruby"
  s.license  = "MIT"
  s.required_ruby_version = ">= 3.1.0"

  s.metadata = {
    "homepage_uri" => s.homepage,
    "source_code_uri" => s.homepage,
    "changelog_uri" => "#{s.homepage}/blob/main/CHANGELOG.md"
  }

  # Ship every git-tracked file except tests, git/CI/tooling paths and this
  # gemspec itself.
  excluded_prefixes = %w[test/ spec/ features/ .git .github script/]
  s.files = Dir.chdir(__dir__) do
    `git ls-files -z`.split("\x0").reject do |path|
      File.expand_path(path) == __FILE__ || path.start_with?(*excluded_prefixes)
    end
  end

  s.require_paths = ["lib"]
  s.extensions    = ["ext/onnx_ruby/extconf.rb"]

  s.add_dependency "rice", ">= 4.0"

  s.add_development_dependency "rake", "~> 13.0"
  s.add_development_dependency "rake-compiler", "~> 1.2"
  s.add_development_dependency "minitest", "~> 5.0"
end
|