onnxruntime-torch-tensor 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 0da68dbc7b2e54cbf8a21027cdf21ed73c56f00f10136ef17298ce8155aa1c9c
4
+ data.tar.gz: 5aac4781db8fc0fc2bd11e244c986a7f1eda517cf4c34dd745cca6fea553057e
5
+ SHA512:
6
+ metadata.gz: ba9012f3615f0d4cdb14dd8f62574c3fea422bb42e1f382653a282e2456181a840f60255568b4708886a3fd74df390a0b810d5d4cd06441644654ddd6ce2f1da
7
+ data.tar.gz: 66654fea58aceb0cb64d5b78774732af56773d6985166b7c7acfee2c151345862a2a86a25ab2a7b6dc43dd494e05fb359a617061dd9573ee8fa66b2b66f585fe
data/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ Gemfile.lock
2
+ pkg/
data/Gemfile ADDED
@@ -0,0 +1,3 @@
1
+ source "https://rubygems.org"
2
+
3
+ gemspec
data/LICENSE.txt ADDED
@@ -0,0 +1,32 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2026 Kitaiti Makoto
4
+
5
+ All rights reserved.
6
+
7
+ Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ 1. Redistributions of source code must retain the above copyright
11
+ notice, this list of conditions and the following disclaimer.
12
+
13
+ 2. Redistributions in binary form must reproduce the above copyright
14
+ notice, this list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
18
+ and IDIAP Research Institute nor the names of its contributors may be
19
+ used to endorse or promote products derived from this software without
20
+ specific prior written permission.
21
+
22
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
25
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
26
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
29
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
30
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
31
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,80 @@
1
+ OnnxRuntime::Torch::Tensor
2
+ ==========================
3
+
4
+ [Torch][Torch.rb]::Tensor support for [ONNX Runtime Ruby][].
5
+
6
+ This gem provides conversion between [OnnxRuntime][ONNX Runtime Ruby]::OrtValue and [Torch][Torch.rb]::Tensor, so that you can pass Torch::Tensor inputs to OnnxRuntime and work with outputs as Torch::Tensor.
7
+
8
+ Conversion is zero-copy in most cases. Zero-copy is available when the receiver is a row-major, contiguous tensor on the CPU.
9
+
10
+ SYNOPSIS
11
+ --------
12
+
13
+ ```ruby
14
+ require "torchaudio"
15
+ require "onnxruntime"
16
+ require "onnxruntime/torch/tensor"
17
+
18
+ session = OnnxRuntime::InferenceSession.new("path/to/model")
19
+
20
+ waveform, sample_rate = TorchAudio.load("path/to/file") # => [Torch::Tensor, Integer]
21
+ input_torch_tensor = pre_process(waveform) # => Torch::Tensor
22
+ input = input_torch_tensor.to_ort_value # => OnnxRuntime::OrtValue, zero-copy
23
+ outputs = session.run(
24
+ [:output_name],
25
+ {input_name: input},
26
+ output_type: :ort_value # required to get Torch::Tensor at the next step
27
+ ) # => Array[OnnxRuntime::OrtValue]
28
+ output = outputs[0] # => OnnxRuntime::OrtValue
29
+ output_torch_tensor = output.to_torch_tensor # => Torch::Tensor, zero-copy
30
+ output_waveform = post_process(output_torch_tensor)
31
+ TorchAudio.save("path/to/output", output_waveform, sample_rate)
32
+ ```
33
+
34
+ INSTALLATION
35
+ ------------
36
+
37
+ % gem install onnxruntime-torch-tensor
38
+
39
+ or, add to Gemfile:
40
+
41
+ ```ruby
42
+ gem "onnxruntime-torch-tensor"
43
+ ```
44
+
45
+ API
46
+ ---
47
+
48
+ ```ruby
49
+ OnnxRuntime::OrtValue#to_torch_tensor # => Torch::Tensor
50
+ Torch::Tensor#to_ort_value # => OnnxRuntime::OrtValue
51
+
52
+ OnnxRuntime::OrtValue.from_torch_tensor(torch_tensor) # => OnnxRuntime::OrtValue
53
+ # Not Torch::Tensor.from_ort_value; modeled after Torch.from_numo
54
+ Torch.from_ort_value(ort_value) # => Torch::Tensor
55
+ ```
56
+
57
+ This gem uses [NDAV][] internally, so it also provides `OrtValue()` and `TorchTensor()` converter methods based on it.
58
+
59
+ ```ruby
60
+ NDAV::Converter::OrtValue(torch_tensor) # => OnnxRuntime::OrtValue
61
+ NDAV::Converter::TorchTensor(ort_value) # => Torch::Tensor
62
+
63
+ module YourApp
64
+ include NDAV::Converter
65
+
66
+ def your_method
67
+ OrtValue(torch_tensor) # => OnnxRuntime::OrtValue
68
+ TorchTensor(ort_value) # => Torch::Tensor
69
+ end
70
+ end
71
+ ```
72
+
73
+ LICENSE
74
+ -------
75
+
76
+ BSD-3-Clause. See LICENSE.txt file.
77
+
78
+ [ONNX Runtime Ruby]: https://github.com/ankane/onnxruntime-ruby
79
+ [Torch.rb]: https://github.com/ankane/torch.rb
80
+ [NDAV]: https://gitlab.com/KitaitiMakoto/ndav
data/Rakefile ADDED
@@ -0,0 +1,7 @@
1
+ require "rake/testtask"
2
+ require "rubygems/tasks"
3
+
4
+ task default: :test # a bare `rake` invocation runs the :test task
5
+
6
+ Gem::Tasks.new # instantiates the tasks provided by the required "rubygems/tasks"
7
+ Rake::TestTask.new # instantiates a Rake::TestTask with its default settings
data/lib/onnxruntime/torch/tensor.rb ADDED
@@ -0,0 +1,87 @@
1
+ require "ndav/ort_value"
2
+ require "ndav/torch/tensor"
3
+
4
+ class NDAV
5
+ module OnnxRuntime
6
+ module Torch
7
+ module Tensor
8
+ module FromTorchTensor
9
+ def from_torch_tensor(torch)
10
+ from_ndav(torch.to_ndav)
11
+ end
12
+ end
13
+
14
+ module ToOrtValue
15
+ def to_ort_value
16
+ ::OnnxRuntime::OrtValue.from_ndav(to_ndav)
17
+ end
18
+ end
19
+
20
+ module FromOrtValue
21
+ def from_ort_value(ort_value)
22
+ from_ndav(ort_value.to_ndav)
23
+ end
24
+ end
25
+
26
+ module ToTorchTensor
27
+ def to_torch_tensor
28
+ ::Torch.from_ort_value(self)
29
+ end
30
+ end
31
+
32
+ module Converter
33
+ if defined? ::NDAV::Converter::OrtValue
34
+ def OrtValue(tensor)
35
+ case tensor
36
+ when ::OnnxRuntime::OrtValue
37
+ tensor
38
+ when ::Torch::Tensor
39
+ ::OnnxRuntime::OrtValue.from_torch_tensor(tensor)
40
+ else
41
+ super
42
+ end
43
+ end
44
+ else
45
+ def OrtValue(tensor)
46
+ case tensor
47
+ when ::OnnxRuntime::OrtValue
48
+ tensor
49
+ when ::Torch::Tensor
50
+ ::OnnxRuntime::OrtValue.from_torch_tensor(tensor)
51
+ end
52
+ end
53
+ end
54
+
55
+ if defined? ::NDAV::Converter::TorchTensor
56
+ def TorchTensor(tensor)
57
+ case tensor
58
+ when ::OnnxRuntime::OrtValue
59
+ ::Torch.from_ort_value(tensor)
60
+ when ::Torch::Tensor
61
+ tensor
62
+ else
63
+ super
64
+ end
65
+ end
66
+ else
67
+ def TorchTensor(tensor)
68
+ case tensor
69
+ when ::OnnxRuntime::OrtValue
70
+ ::Torch.from_ort_value(tensor)
71
+ when ::Torch::Tensor
72
+ tensor
73
+ end
74
+ end
75
+ end
76
+ end
77
+
78
+ ::OnnxRuntime::OrtValue.extend FromTorchTensor
79
+ ::OnnxRuntime::OrtValue.include ToTorchTensor
80
+ ::Torch.extend FromOrtValue
81
+ ::Torch::Tensor.include ToOrtValue
82
+ ::NDAV::Converter.singleton_class.prepend Converter
83
+ ::NDAV::Converter.prepend Converter
84
+ end
85
+ end
86
+ end
87
+ end
data/onnxruntime-torch-tensor.gemspec ADDED
@@ -0,0 +1,22 @@
1
+ Gem::Specification.new do |s|
2
+ s.name = "onnxruntime-torch-tensor"
3
+ s.version = "0.0.1"
4
+ s.authors = ["Kitaiti Makoto"]
5
+ s.summary = "Torch::Tensor support for ONNX Runtime"
6
+ s.licenses = ["BSD-3-Clause"]
7
+ s.homepage = "https://gitlab.com/KitaitiMakoto/onnxruntime-torch-tensor"
8
+
9
+ s.files = Dir.chdir(__dir__) {`git ls-files -z`.split("\x0")}
10
+
11
+ s.add_runtime_dependency "ndav-ort_value"
12
+ s.add_runtime_dependency "ndav-torch-tensor"
13
+
14
+ s.add_development_dependency "rake"
15
+ s.add_development_dependency "test-unit"
16
+ s.add_development_dependency "test-unit-notify"
17
+ s.add_development_dependency "test-unit-rr"
18
+ s.add_development_dependency "terminal-notifier" if RUBY_PLATFORM.match?(/darwin/)
19
+ s.add_development_dependency "rubygems-tasks"
20
+ s.add_development_dependency "numo-narray-alt"
21
+ s.add_development_dependency "get_process_mem"
22
+ end
data/test/test_converter.rb ADDED
@@ -0,0 +1,43 @@
1
+ require "test/unit"
2
+ require "test/unit/notify"
3
+ require "onnxruntime/torch/tensor"
4
+
5
+ class TestConverter < Test::Unit::TestCase
6
+ def setup
7
+ @ort_value = ::OnnxRuntime::OrtValue.from_array([1, 2, 3], element_type: :int16)
8
+ @torch_tensor = ::Torch.tensor([1, 2, 3], dtype: :int16)
9
+ end
10
+
11
+ def test_torch_tensor_from_ort_value
12
+ assert_kind_of ::Torch::Tensor, ::Torch.from_ort_value(@ort_value)
13
+ end
14
+
15
+ def test_ort_value_from_torch_tensor
16
+ assert_kind_of ::OnnxRuntime::OrtValue, ::OnnxRuntime::OrtValue.from_torch_tensor(@torch_tensor)
17
+ end
18
+
19
+ def test_torch_tensor_to_ort_value
20
+ assert_kind_of ::OnnxRuntime::OrtValue, @torch_tensor.to_ort_value
21
+ end
22
+
23
+ def test_ort_value_to_torch_tensor
24
+ assert_kind_of ::Torch::Tensor, @ort_value.to_torch_tensor
25
+ end
26
+
27
+ def test_converter
28
+ c = Class.new {
29
+ prepend ::NDAV::Converter
30
+
31
+ def to_torch_tensor(array)
32
+ TorchTensor(array)
33
+ end
34
+
35
+ def to_ort_value(array)
36
+ OrtValue(array)
37
+ end
38
+ }
39
+
40
+ assert_kind_of ::Torch::Tensor, c.new.to_torch_tensor(@ort_value)
41
+ assert_kind_of ::OnnxRuntime::OrtValue, c.new.to_ort_value(@torch_tensor)
42
+ end
43
+ end
metadata ADDED
@@ -0,0 +1,185 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: onnxruntime-torch-tensor
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.0.1
5
+ platform: ruby
6
+ authors:
7
+ - Kitaiti Makoto
8
+ bindir: bin
9
+ cert_chain: []
10
+ date: 1980-01-02 00:00:00.000000000 Z
11
+ dependencies:
12
+ - !ruby/object:Gem::Dependency
13
+ name: ndav-ort_value
14
+ requirement: !ruby/object:Gem::Requirement
15
+ requirements:
16
+ - - ">="
17
+ - !ruby/object:Gem::Version
18
+ version: '0'
19
+ type: :runtime
20
+ prerelease: false
21
+ version_requirements: !ruby/object:Gem::Requirement
22
+ requirements:
23
+ - - ">="
24
+ - !ruby/object:Gem::Version
25
+ version: '0'
26
+ - !ruby/object:Gem::Dependency
27
+ name: ndav-torch-tensor
28
+ requirement: !ruby/object:Gem::Requirement
29
+ requirements:
30
+ - - ">="
31
+ - !ruby/object:Gem::Version
32
+ version: '0'
33
+ type: :runtime
34
+ prerelease: false
35
+ version_requirements: !ruby/object:Gem::Requirement
36
+ requirements:
37
+ - - ">="
38
+ - !ruby/object:Gem::Version
39
+ version: '0'
40
+ - !ruby/object:Gem::Dependency
41
+ name: rake
42
+ requirement: !ruby/object:Gem::Requirement
43
+ requirements:
44
+ - - ">="
45
+ - !ruby/object:Gem::Version
46
+ version: '0'
47
+ type: :development
48
+ prerelease: false
49
+ version_requirements: !ruby/object:Gem::Requirement
50
+ requirements:
51
+ - - ">="
52
+ - !ruby/object:Gem::Version
53
+ version: '0'
54
+ - !ruby/object:Gem::Dependency
55
+ name: test-unit
56
+ requirement: !ruby/object:Gem::Requirement
57
+ requirements:
58
+ - - ">="
59
+ - !ruby/object:Gem::Version
60
+ version: '0'
61
+ type: :development
62
+ prerelease: false
63
+ version_requirements: !ruby/object:Gem::Requirement
64
+ requirements:
65
+ - - ">="
66
+ - !ruby/object:Gem::Version
67
+ version: '0'
68
+ - !ruby/object:Gem::Dependency
69
+ name: test-unit-notify
70
+ requirement: !ruby/object:Gem::Requirement
71
+ requirements:
72
+ - - ">="
73
+ - !ruby/object:Gem::Version
74
+ version: '0'
75
+ type: :development
76
+ prerelease: false
77
+ version_requirements: !ruby/object:Gem::Requirement
78
+ requirements:
79
+ - - ">="
80
+ - !ruby/object:Gem::Version
81
+ version: '0'
82
+ - !ruby/object:Gem::Dependency
83
+ name: test-unit-rr
84
+ requirement: !ruby/object:Gem::Requirement
85
+ requirements:
86
+ - - ">="
87
+ - !ruby/object:Gem::Version
88
+ version: '0'
89
+ type: :development
90
+ prerelease: false
91
+ version_requirements: !ruby/object:Gem::Requirement
92
+ requirements:
93
+ - - ">="
94
+ - !ruby/object:Gem::Version
95
+ version: '0'
96
+ - !ruby/object:Gem::Dependency
97
+ name: terminal-notifier
98
+ requirement: !ruby/object:Gem::Requirement
99
+ requirements:
100
+ - - ">="
101
+ - !ruby/object:Gem::Version
102
+ version: '0'
103
+ type: :development
104
+ prerelease: false
105
+ version_requirements: !ruby/object:Gem::Requirement
106
+ requirements:
107
+ - - ">="
108
+ - !ruby/object:Gem::Version
109
+ version: '0'
110
+ - !ruby/object:Gem::Dependency
111
+ name: rubygems-tasks
112
+ requirement: !ruby/object:Gem::Requirement
113
+ requirements:
114
+ - - ">="
115
+ - !ruby/object:Gem::Version
116
+ version: '0'
117
+ type: :development
118
+ prerelease: false
119
+ version_requirements: !ruby/object:Gem::Requirement
120
+ requirements:
121
+ - - ">="
122
+ - !ruby/object:Gem::Version
123
+ version: '0'
124
+ - !ruby/object:Gem::Dependency
125
+ name: numo-narray-alt
126
+ requirement: !ruby/object:Gem::Requirement
127
+ requirements:
128
+ - - ">="
129
+ - !ruby/object:Gem::Version
130
+ version: '0'
131
+ type: :development
132
+ prerelease: false
133
+ version_requirements: !ruby/object:Gem::Requirement
134
+ requirements:
135
+ - - ">="
136
+ - !ruby/object:Gem::Version
137
+ version: '0'
138
+ - !ruby/object:Gem::Dependency
139
+ name: get_process_mem
140
+ requirement: !ruby/object:Gem::Requirement
141
+ requirements:
142
+ - - ">="
143
+ - !ruby/object:Gem::Version
144
+ version: '0'
145
+ type: :development
146
+ prerelease: false
147
+ version_requirements: !ruby/object:Gem::Requirement
148
+ requirements:
149
+ - - ">="
150
+ - !ruby/object:Gem::Version
151
+ version: '0'
152
+ executables: []
153
+ extensions: []
154
+ extra_rdoc_files: []
155
+ files:
156
+ - ".gitignore"
157
+ - Gemfile
158
+ - LICENSE.txt
159
+ - README.md
160
+ - Rakefile
161
+ - lib/onnxruntime/torch/tensor.rb
162
+ - onnxruntime-torch-tensor.gemspec
163
+ - test/test_converter.rb
164
+ homepage: https://gitlab.com/KitaitiMakoto/onnxruntime-torch-tensor
165
+ licenses:
166
+ - BSD-3-Clause
167
+ metadata: {}
168
+ rdoc_options: []
169
+ require_paths:
170
+ - lib
171
+ required_ruby_version: !ruby/object:Gem::Requirement
172
+ requirements:
173
+ - - ">="
174
+ - !ruby/object:Gem::Version
175
+ version: '0'
176
+ required_rubygems_version: !ruby/object:Gem::Requirement
177
+ requirements:
178
+ - - ">="
179
+ - !ruby/object:Gem::Version
180
+ version: '0'
181
+ requirements: []
182
+ rubygems_version: 4.0.6
183
+ specification_version: 4
184
+ summary: Torch::Tensor support for ONNX Runtime
185
+ test_files: []