onnx-ruby 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,73 @@
1
# frozen_string_literal: true

require "fileutils"
require "net/http"
require "uri"
require "json"

module OnnxRuby
  # Downloads and caches ONNX model files from the Hugging Face Hub.
  module Hub
    DEFAULT_CACHE_DIR = File.join(Dir.home, ".cache", "onnx_ruby", "models")

    # Download a model from Hugging Face Hub, caching it locally.
    # @param repo_id [String] e.g. "sentence-transformers/all-MiniLM-L6-v2"
    # @param filename [String] ONNX file to download (default: "model.onnx")
    # @param cache_dir [String] local cache directory
    # @param revision [String] git revision (branch, tag, or commit SHA)
    # @return [String] path to the downloaded model file
    # @raise [ModelError] if the download fails or redirects too many times
    def self.download(repo_id, filename: "model.onnx", cache_dir: DEFAULT_CACHE_DIR, revision: "main")
      model_dir = File.join(cache_dir, repo_id.tr("/", "--"), revision)
      model_path = File.join(model_dir, filename)

      return model_path if File.exist?(model_path)

      FileUtils.mkdir_p(model_dir)

      # BUG FIX: the filename interpolation was missing ("#(unknown)" literal),
      # so every generated URL pointed at a nonexistent resource.
      url = "https://huggingface.co/#{repo_id}/resolve/#{revision}/#{filename}"
      download_file(url, model_path)

      model_path
    end

    # List cached models
    # @param cache_dir [String] cache directory to search
    # @return [Array<String>] list of cached model paths
    def self.cached_models(cache_dir: DEFAULT_CACHE_DIR)
      return [] unless Dir.exist?(cache_dir)

      Dir.glob(File.join(cache_dir, "**", "*.onnx"))
    end

    # Clear the model cache
    # @param cache_dir [String] cache directory to clear
    def self.clear_cache(cache_dir: DEFAULT_CACHE_DIR)
      # rm_rf is a no-op for a missing path, so no existence guard is needed.
      FileUtils.rm_rf(cache_dir)
    end

    class << self
      private

      # Fetch +url+ into +dest+, following up to 5 redirects.
      # The body is written to a temp file and renamed, so an interrupted
      # download never leaves a partial file that #download would later
      # treat as a cache hit.
      def download_file(url, dest)
        uri = URI(url)
        max_redirects = 5

        max_redirects.times do
          response = Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            http.request(Net::HTTP::Get.new(uri))
          end

          case response
          when Net::HTTPSuccess
            tmp = "#{dest}.tmp"
            File.binwrite(tmp, response.body)
            FileUtils.mv(tmp, dest)
            return
          when Net::HTTPRedirection
            uri = URI(response["location"])
          else
            raise ModelError, "failed to download #{url}: #{response.code} #{response.message}"
          end
        end

        raise ModelError, "too many redirects downloading #{url}"
      end
    end
  end
end
@@ -0,0 +1,38 @@
1
# frozen_string_literal: true

module OnnxRuby
  # Defers Session construction until the first call that needs it.
  # Safe to share across threads: creation is guarded by a mutex.
  class LazySession
    # @param model_path [String] path handed to Session.new on first use
    # @param opts [Hash] keyword options forwarded to Session.new
    def initialize(model_path, **opts)
      @model_path = model_path
      @opts = opts
      @session = nil
      @mutex = Mutex.new
    end

    # Input metadata of the underlying session (forces a load).
    def inputs
      session.inputs
    end

    # Output metadata of the underlying session (forces a load).
    def outputs
      session.outputs
    end

    # Run inference, creating the session on first use.
    def run(inputs, **kwargs)
      session.run(inputs, **kwargs)
    end

    # @return [Boolean] whether the underlying session exists yet
    def loaded?
      !@session.nil?
    end

    private

    # Double-checked lazy creation: the unsynchronized fast path is fine on
    # MRI (ivar reads are atomic); the mutex makes the slow path race-free.
    def session
      @session || @mutex.synchronize { @session ||= Session.new(@model_path, **@opts) }
    end
  end
end
@@ -0,0 +1,71 @@
1
# frozen_string_literal: true

module OnnxRuby
  # ActiveModel-style mixin for embedding generation.
  #
  # Usage:
  #   class Document
  #     include OnnxRuby::Model
  #
  #     onnx_model "embeddings.onnx"
  #     onnx_input ->(doc) { { "input_ids" => doc.token_ids, "attention_mask" => doc.mask } }
  #     onnx_output "embeddings"
  #   end
  #
  #   doc = Document.new
  #   doc.onnx_predict # => [0.123, -0.456, ...]
  module Model
    def self.included(base)
      base.extend(ClassMethods)
    end

    # Class-level DSL. Each macro doubles as a setter (with an argument)
    # and a getter (without one).
    module ClassMethods
      # Declare (or read) the model path plus session options for this class.
      def onnx_model(path = nil, **opts)
        return @onnx_model_path unless path

        @onnx_model_path = path
        @onnx_session_opts = opts
        @onnx_model_path
      end

      # Declare (or read) the callable mapping an instance to a feed hash.
      def onnx_input(callable = nil, &block)
        fn = callable || block
        @onnx_input_fn = fn if fn
        @onnx_input_fn
      end

      # Declare (or read) which named output onnx_predict should return.
      def onnx_output(name = nil)
        name ? (@onnx_output_name = name) : @onnx_output_name
      end

      # Shared lazy session for this class, built on first access.
      def onnx_session
        return @onnx_session if @onnx_session

        resolved = resolve_model_path(@onnx_model_path)
        @onnx_session = LazySession.new(resolved, **(@onnx_session_opts || {}))
      end

      private

      # Absolute existing paths win; otherwise try the configured models_path,
      # falling back to the raw path so Session can raise a clear error.
      def resolve_model_path(path)
        if File.absolute_path?(path) && File.exist?(path)
          path
        else
          candidate = File.join(OnnxRuby.configuration.models_path, path)
          File.exist?(candidate) ? candidate : path
        end
      end
    end

    # Run the model on this instance and return the configured output
    # (or the first output when none was configured).
    def onnx_predict(**run_opts)
      input_fn = self.class.onnx_input
      raise Error, "onnx_input not defined on #{self.class}" unless input_fn

      result = self.class.onnx_session.run(input_fn.call(self), **run_opts)
      chosen = self.class.onnx_output
      chosen ? result[chosen] : result.values.first
    end
  end
end
@@ -0,0 +1,91 @@
1
# frozen_string_literal: true

module OnnxRuby
  # Cross-encoder style reranker: scores (query, document) pairs with an
  # ONNX model and sorts documents by descending relevance.
  class Reranker
    attr_reader :session

    # @param model_path [String] path to the ONNX cross-encoder model
    # @param tokenizer [#encode, String, nil] a tokenizer object, a Hugging
    #   Face repo id to load via the `tokenizers` gem, or nil when only
    #   pre-tokenized scoring (#score) is needed
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, tokenizer: nil, **session_opts)
      @session = Session.new(model_path, **session_opts)
      @tokenizer = resolve_tokenizer(tokenizer)
    end

    # Rerank documents by relevance to a query
    # @param query [String] the query text (requires tokenizer)
    # @param documents [Array<String>] documents to rerank
    # @return [Array<Hash>] sorted array of { document:, score:, index: }
    def rerank(query, documents)
      raise Error, "tokenizer is required for reranking" unless @tokenizer

      pairs = documents.map { |doc| [query, doc] }
      scores = score_pairs(pairs)

      documents.each_with_index.map do |doc, i|
        { document: doc, score: scores[i], index: i }
      end.sort_by { |r| -r[:score] }
    end

    # Score query-document pairs with pre-tokenized inputs.
    # NOTE(review): unlike #rerank, no padding is applied here — assumes all
    # rows already share one length; confirm with callers.
    # @param input_ids [Array<Array<Integer>>] batch of token ID sequences
    # @param attention_mask [Array<Array<Integer>>] batch of attention masks
    # @return [Array<Float>] relevance scores
    def score(input_ids:, attention_mask:)
      run_scores(input_ids, attention_mask)
    end

    private

    # Resolve the tokenizer argument into an object responding to #encode,
    # loading the `tokenizers` gem for string/symbol repo ids.
    def resolve_tokenizer(tokenizer)
      return nil if tokenizer.nil?
      return tokenizer if tokenizer.respond_to?(:encode)

      begin
        require "tokenizers"
      rescue LoadError
        # Fixed message: the gem is named "tokenizers", not "tokenizer-ruby"
        # (the old text contradicted its own install command).
        raise Error, "tokenizers gem is required for text tokenization. " \
                     "Install with: gem install tokenizers"
      end
      Tokenizers::Tokenizer.from_pretrained(tokenizer.to_s)
    end

    # Tokenize, pad, and score a batch of [query, document] pairs.
    def score_pairs(pairs)
      ids, masks = encode_pairs(pairs)

      # Right-pad every row to the longest sequence in the batch.
      max_len = ids.map(&:length).max
      ids = ids.map { |row| row + Array.new(max_len - row.length, 0) }
      masks = masks.map { |row| row + Array.new(max_len - row.length, 0) }

      run_scores(ids, masks)
    end

    # Encode pairs with whichever API the tokenizer supports.
    # @return [Array(Array<Array<Integer>>, Array<Array<Integer>>)] ids, masks
    def encode_pairs(pairs)
      encodings =
        if @tokenizer.respond_to?(:encode_batch)
          @tokenizer.encode_batch(pairs)
        else
          pairs.map { |pair| @tokenizer.encode(*pair) }
        end
      [encodings.map(&:ids), encodings.map(&:attention_mask)]
    end

    # Shared pipeline (previously duplicated in #score and #score_pairs):
    # build the feed, run the session, flatten raw output to one Float per row.
    def run_scores(ids, masks)
      result = @session.run(build_feed(ids, masks))
      raw_scores = find_output(result, %w[scores logits output])
      raw_scores.map { |row| row.is_a?(Array) ? row.first : row }
    end

    # Map ids/masks onto the model's actual input names (name heuristics).
    def build_feed(ids, masks)
      input_names = @session.inputs.map { |i| i[:name] }
      feed = {}
      feed[input_names.find { |n| n.include?("input_id") } || input_names[0]] = ids
      mask_name = input_names.find { |n| n.include?("mask") || n.include?("attention") }
      feed[mask_name] = masks if mask_name
      feed
    end

    # Pick the first matching output name, falling back to the first output.
    def find_output(result, candidate_names)
      candidate_names.each { |name| return result[name] if result.key?(name) }
      result.values.first
    end
  end
end
@@ -0,0 +1,89 @@
1
# frozen_string_literal: true

module OnnxRuby
  # Ruby-facing wrapper around the native Ext::SessionWrapper: validates
  # options, normalizes inputs into flat {name, data, shape, dtype} records,
  # and delegates execution to the extension.
  class Session
    VALID_PROVIDERS = %i[cpu coreml cuda tensorrt].freeze

    # @param model_path [String] path to an .onnx file (expanded; must exist)
    # @param providers [Array<Symbol>] execution providers, tried in order
    # @raise [ModelError] when the model file is missing
    # @raise [Error] when an unknown provider is requested
    def initialize(model_path, providers: [:cpu], inter_threads: nil, intra_threads: nil,
                   log_level: :warning, optimization_level: :all, memory_pattern: true,
                   cpu_mem_arena: true, execution_mode: :sequential)
      model_path = File.expand_path(model_path)
      raise ModelError, "model file not found: #{model_path}" unless File.exist?(model_path)

      provider_strs = Array(providers).map { |provider| validate_provider(provider) }

      # NOTE: positional argument order must match the native wrapper exactly.
      @session = Ext::SessionWrapper.new(
        model_path,
        log_level_to_int(log_level),
        intra_threads || 0,
        inter_threads || 0,
        optimization_level.to_s,
        memory_pattern,
        cpu_mem_arena,
        execution_mode.to_s,
        provider_strs
      )
    end

    # @return [Array<Hash>] input metadata reported by the native session
    def inputs
      @session.input_info
    end

    # @return [Array<Hash>] output metadata reported by the native session
    def outputs
      @session.output_info
    end

    # Run inference.
    # @param inputs [Hash{String=>Tensor,Array}] feed of name => data
    # @param output_names [Array<String>, nil] subset of outputs to fetch
    def run(inputs, output_names: nil)
      input_values = inputs.map { |name, data| feed_entry(name, data) }
      @session.run(input_values, output_names || [])
    end

    private

    # Normalize one feed entry to the record shape the extension expects.
    def feed_entry(name, data)
      if data.is_a?(Tensor)
        { name: name, data: data.flat_data, shape: data.shape, dtype: data.dtype.to_s }
      else
        flat = data.flatten
        { name: name, data: flat, shape: infer_shape(data), dtype: infer_dtype(flat) }
      end
    end

    # Ensure the provider symbol is supported; return its string form.
    def validate_provider(provider)
      provider = provider.to_sym
      unless VALID_PROVIDERS.include?(provider)
        raise Error, "unknown provider: #{provider}. Valid: #{VALID_PROVIDERS.join(", ")}"
      end
      provider.to_s
    end

    # Map a log-level symbol to the integer the native layer expects
    # (unknown levels fall back to warning = 2).
    def log_level_to_int(level)
      { verbose: 0, info: 1, warning: 2, error: 3, fatal: 4 }.fetch(level, 2)
    end

    # Nested-array dimensions, outermost first.
    def infer_shape(data)
      return [] unless data.is_a?(Array)

      [data.length] + infer_shape(data.first)
    end

    # Guess an ONNX dtype string from the first non-nil element
    # (defaults to "float").
    def infer_dtype(flat)
      sample = flat.find { |v| !v.nil? }
      if sample.is_a?(Float) then "float"
      elsif sample.is_a?(Integer) then "int64"
      elsif sample.is_a?(String) then "string"
      elsif sample.equal?(true) || sample.equal?(false) then "bool"
      else "float"
      end
    end
  end
end
@@ -0,0 +1,75 @@
1
# frozen_string_literal: true

module OnnxRuby
  # Fixed-capacity, thread-safe pool of Session objects. Sessions are created
  # on demand up to +size+; callers block (up to +timeout+ seconds) waiting
  # for a free session when the pool is exhausted.
  class SessionPool
    class TimeoutError < Error; end

    # @param model_path [String] path forwarded to Session.new
    # @param size [Integer, nil] max sessions (default: configuration.pool_size)
    # @param timeout [Numeric, nil] checkout timeout in seconds; nil waits
    #   forever (default: configuration.pool_timeout)
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, size: nil, timeout: nil, **session_opts)
      @model_path = model_path
      @session_opts = session_opts
      @size = size || OnnxRuby.configuration.pool_size
      @timeout = timeout || OnnxRuby.configuration.pool_timeout
      @pool = []
      @mutex = Mutex.new
      @condition = ConditionVariable.new
      @created = 0
    end

    # Check out a session, yield it, then check it back in
    def with_session
      session = checkout
      begin
        yield session
      ensure
        checkin(session)
      end
    end

    # Run inference using a pooled session
    def run(inputs, **kwargs)
      with_session { |s| s.run(inputs, **kwargs) }
    end

    # @return [Integer] number of sessions created so far
    def size
      @mutex.synchronize { @created }
    end

    # @return [Integer] number of idle sessions currently in the pool
    def available
      @mutex.synchronize { @pool.size }
    end

    private

    # Acquire a session or raise TimeoutError after @timeout seconds.
    def checkout
      # BUG FIX: ConditionVariable#wait returns the condition variable itself
      # (always truthy), never a timed-out indicator, so the previous
      # `unless @condition.wait(@mutex, @timeout)` check could never raise.
      # Track an absolute monotonic deadline instead — this also handles
      # spurious wakeups and repeated waits correctly.
      deadline = @timeout && monotonic_now + @timeout

      @mutex.synchronize do
        loop do
          # Return an available session
          return @pool.pop unless @pool.empty?

          # Create a new one if under limit
          if @created < @size
            @created += 1
            begin
              return create_session
            rescue StandardError
              # Roll back the slot reservation so a failed construction
              # doesn't permanently shrink the pool's capacity.
              @created -= 1
              @condition.signal
              raise
            end
          end

          # Wait for one to be returned (nil deadline waits indefinitely,
          # matching the previous behavior for a nil timeout).
          remaining = deadline && deadline - monotonic_now
          if remaining && remaining <= 0
            raise TimeoutError, "timed out waiting for session (pool size: #{@size})"
          end

          @condition.wait(@mutex, remaining)
        end
      end
    end

    # Return a session to the pool and wake one waiter.
    def checkin(session)
      @mutex.synchronize do
        @pool.push(session)
        @condition.signal
      end
    end

    # Construct a fresh Session (called while holding @mutex).
    def create_session
      Session.new(@model_path, **@session_opts)
    end

    # Monotonic clock reading — immune to wall-clock adjustments.
    def monotonic_now
      Process.clock_gettime(Process::CLOCK_MONOTONIC)
    end
  end
end
@@ -0,0 +1,92 @@
1
# frozen_string_literal: true

module OnnxRuby
  # Lightweight n-dimensional container pairing flat row-major data with a
  # shape and a normalized dtype; the native layer consumes flat_data,
  # shape, and dtype.
  class Tensor
    # Accepted dtype aliases mapped to their canonical names.
    DTYPE_MAP = {
      float32: :float,
      float: :float,
      float64: :double,
      double: :double,
      int32: :int32,
      int: :int32,
      int64: :int64,
      bool: :bool,
      string: :string
    }.freeze

    attr_reader :shape, :dtype

    # @param data [Array] nested (or flat) array of values
    # @param shape [Array<Integer>, nil] explicit shape; inferred when nil
    # @param dtype [Symbol, nil] dtype alias; inferred from data when nil
    # @raise [TensorError] on unknown dtype or shape/size mismatch
    def initialize(data, shape: nil, dtype: nil)
      @data = data.flatten
      @shape = shape || infer_shape(data)
      @dtype = normalize_dtype(dtype || infer_dtype(@data))

      validate!
    end

    # @return [Array] the data re-nested according to shape
    def to_a
      reshape(@data.dup, @shape)
    end

    # @return [Array] the flat, row-major data
    def flat_data
      @data
    end

    # Convenience constructors: Tensor.float(data), Tensor.int64(data), ...
    %i[float int64 int32 double].each do |dt|
      define_singleton_method(dt) do |data, shape: nil|
        new(data, shape: shape, dtype: dt)
      end
    end

    private

    # Resolve a dtype alias via DTYPE_MAP, rejecting unknown names.
    def normalize_dtype(dtype)
      canonical = DTYPE_MAP[dtype]
      raise TensorError, "unsupported dtype: #{dtype}" if canonical.nil?

      canonical
    end

    # Nested-array dimensions, outermost first.
    def infer_shape(data)
      return [] unless data.is_a?(Array)

      [data.length] + infer_shape(data.first)
    end

    # Guess a dtype from the first non-nil element (defaults to :float).
    def infer_dtype(flat)
      sample = flat.find { |v| !v.nil? }
      if sample.is_a?(Float) then :float
      elsif sample.is_a?(Integer) then :int64
      elsif sample.is_a?(String) then :string
      elsif sample.equal?(true) || sample.equal?(false) then :bool
      else :float
      end
    end

    # Ensure the flat data length matches the product of the dimensions.
    def validate!
      expected_size = @shape.reduce(1, :*)
      return if @data.length == expected_size

      raise TensorError,
            "data size #{@data.length} does not match shape #{@shape} (expected #{expected_size})"
    end

    # Recursively slice flat data back into nested rows.
    def reshape(flat, dims)
      return flat if dims.length <= 1

      inner = dims.drop(1)
      chunk = inner.reduce(1, :*)
      flat.each_slice(chunk).map { |piece| reshape(piece, inner) }
    end
  end
end
@@ -0,0 +1,5 @@
1
# frozen_string_literal: true

module OnnxRuby
  # Gem version string, referenced by the gemspec.
  VERSION = "0.1.0"
end
data/lib/onnx_ruby.rb ADDED
@@ -0,0 +1,45 @@
1
# frozen_string_literal: true

require_relative "onnx_ruby/version"

module OnnxRuby
  # Base error class for the gem; rescue this to catch any OnnxRuby failure.
  class Error < StandardError; end
  # Raised when a model file is missing or cannot be loaded/downloaded.
  class ModelError < Error; end
  # Raised when a model run fails at inference time.
  class InferenceError < Error; end
  # Raised for invalid tensor data, shapes, or dtypes.
  class TensorError < Error; end
end

# Load order matters: the native extension must be available before the
# Ruby classes that wrap it (session, tensor, and the higher-level helpers).
require_relative "onnx_ruby/onnx_ruby_ext"
require_relative "onnx_ruby/tensor"
require_relative "onnx_ruby/session"
require_relative "onnx_ruby/embedder"
require_relative "onnx_ruby/classifier"
require_relative "onnx_ruby/reranker"
require_relative "onnx_ruby/hub"
require_relative "onnx_ruby/configuration"
require_relative "onnx_ruby/lazy_session"
require_relative "onnx_ruby/session_pool"
require_relative "onnx_ruby/model"

module OnnxRuby
  class << self
    # Lazily created global configuration singleton.
    def configuration
      @configuration ||= Configuration.new
    end

    # Block-style setup:
    #   OnnxRuby.configure { |c| c.models_path = "models" }
    def configure
      yield configuration
    end
  end

  # Optimize an ONNX model file on disk via the native extension.
  # @param input_path [String] existing .onnx file to optimize
  # @param output_path [String] where to write the optimized model
  # @param level [Symbol] optimization level (passed to the extension as a string)
  # @raise [ModelError] if the input file does not exist
  def self.optimize(input_path, output_path, level: :all)
    input_path = File.expand_path(input_path)
    raise ModelError, "model file not found: #{input_path}" unless File.exist?(input_path)

    Ext.optimize_model(input_path, output_path, level.to_s)
  end

  # @return [Array<String>] execution providers compiled into the extension
  def self.available_providers
    Ext.available_providers
  end
end
data/onnx-ruby.gemspec ADDED
@@ -0,0 +1,37 @@
1
# frozen_string_literal: true

require_relative "lib/onnx_ruby/version"

Gem::Specification.new do |spec|
  spec.name = "onnx-ruby"
  spec.version = OnnxRuby::VERSION
  spec.authors = ["Johannes Dwi Cahyo"]
  spec.email = ["johannesdwicahyo@gmail.com"]

  spec.summary = "Ruby bindings for ONNX Runtime"
  spec.description = "High-performance ONNX Runtime bindings for Ruby using Rice. " \
                     "Run ONNX models locally for embeddings, classification, NER, and more."
  spec.homepage = "https://github.com/johannesdwicahyo/onnx-ruby"
  spec.license = "MIT"
  spec.required_ruby_version = ">= 3.1.0"

  spec.metadata["homepage_uri"] = spec.homepage
  spec.metadata["source_code_uri"] = spec.homepage
  spec.metadata["changelog_uri"] = "#{spec.homepage}/blob/main/CHANGELOG.md"

  # Package every git-tracked file except tests, CI config, and scripts.
  spec.files = Dir.chdir(__dir__) do
    `git ls-files -z`.split("\x0").reject do |f|
      (File.expand_path(f) == __FILE__) ||
        f.start_with?("test/", "spec/", "features/", ".git", ".github", "script/")
    end
  end

  spec.require_paths = ["lib"]
  # Native extension (Rice/C++) compiled at gem install time.
  spec.extensions = ["ext/onnx_ruby/extconf.rb"]

  spec.add_dependency "rice", ">= 4.0"

  spec.add_development_dependency "rake", "~> 13.0"
  spec.add_development_dependency "rake-compiler", "~> 1.2"
  spec.add_development_dependency "minitest", "~> 5.0"
end