onnxruntime 0.3.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/LICENSE.txt +18 -19
- data/README.md +3 -2
- data/lib/onnxruntime/ffi.rb +23 -10
- data/lib/onnxruntime/inference_session.rb +151 -61
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/LICENSE +1 -1
- data/vendor/ThirdPartyNotices.txt +1118 -493
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +8 -50
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 4b0bef4682d8d44fb3cfd3fa619c1e7e620943014a22f259e14bfa1692dbe2ef
|
|
4
|
+
data.tar.gz: bfe59759c177613806ab0ddd97c4d8a68a7e3915299d81aab3e9d369ce672918
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: f0dfc6c0a91621b5637e56061b40ba26648a7e538d84cb2c135a06c1832c9f71aba2e7fb7ab849750756540d02632726b3f45732a8c50649ca4ace03befbe086
|
|
7
|
+
data.tar.gz: 59a01f8d955f9d31f35864a9217fcf79eb35876229e749c5c50b4572903d9dc3e312a4a7a55612e9c36e3095a797da58637442ac0bb191adcb743b62fefb0905
|
data/CHANGELOG.md
CHANGED
|
@@ -1,3 +1,33 @@
|
|
|
1
|
+
## 0.6.0 (2021-03-14)
|
|
2
|
+
|
|
3
|
+
- Updated ONNX Runtime to 1.7.0
|
|
4
|
+
- OpenMP is no longer required
|
|
5
|
+
|
|
6
|
+
## 0.5.2 (2020-12-27)
|
|
7
|
+
|
|
8
|
+
- Updated ONNX Runtime to 1.6.0
|
|
9
|
+
- Fixed error with `execution_mode` option
|
|
10
|
+
- Fixed error with `bool` input
|
|
11
|
+
|
|
12
|
+
## 0.5.1 (2020-11-01)
|
|
13
|
+
|
|
14
|
+
- Updated ONNX Runtime to 1.5.2
|
|
15
|
+
- Added support for string output
|
|
16
|
+
- Added `output_type` option
|
|
17
|
+
- Improved performance for Numo array inputs
|
|
18
|
+
|
|
19
|
+
## 0.5.0 (2020-10-01)
|
|
20
|
+
|
|
21
|
+
- Updated ONNX Runtime to 1.5.1
|
|
22
|
+
- OpenMP is now required on Mac
|
|
23
|
+
- Fixed `mul_1.onnx` example
|
|
24
|
+
|
|
25
|
+
## 0.4.0 (2020-07-20)
|
|
26
|
+
|
|
27
|
+
- Updated ONNX Runtime to 1.4.0
|
|
28
|
+
- Added `providers` method
|
|
29
|
+
- Fixed errors on Windows
|
|
30
|
+
|
|
1
31
|
## 0.3.3 (2020-06-17)
|
|
2
32
|
|
|
3
33
|
- Fixed segmentation fault on exit on Linux
|
data/LICENSE.txt
CHANGED
|
@@ -1,23 +1,22 @@
|
|
|
1
|
-
Copyright (c) 2019-2020 Andrew Kane
|
|
2
|
-
Datasets Copyright (c) Microsoft Corporation
|
|
3
|
-
|
|
4
1
|
MIT License
|
|
5
2
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
3
|
+
Copyright (c) Microsoft Corporation
|
|
4
|
+
Copyright (c) 2019-2021 Andrew Kane
|
|
5
|
+
|
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
in the Software without restriction, including without limitation the rights
|
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
furnished to do so, subject to the following conditions:
|
|
13
12
|
|
|
14
|
-
The above copyright notice and this permission notice shall be
|
|
15
|
-
|
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
copies or substantial portions of the Software.
|
|
16
15
|
|
|
17
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
OF
|
|
23
|
-
|
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
Check out [an example](https://ankane.org/tensorflow-ruby)
|
|
6
6
|
|
|
7
|
-
[](https://github.com/ankane/onnxruntime/actions)
|
|
8
8
|
|
|
9
9
|
## Installation
|
|
10
10
|
|
|
@@ -81,7 +81,8 @@ model.predict(input_feed, {
|
|
|
81
81
|
log_severity_level: 2,
|
|
82
82
|
log_verbosity_level: 0,
|
|
83
83
|
logid: nil,
|
|
84
|
-
terminate: false
|
|
84
|
+
terminate: false,
|
|
85
|
+
output_type: :ruby # :ruby or :numo
|
|
85
86
|
})
|
|
86
87
|
```
|
|
87
88
|
|
data/lib/onnxruntime/ffi.rb
CHANGED
|
@@ -3,10 +3,13 @@ module OnnxRuntime
|
|
|
3
3
|
extend ::FFI::Library
|
|
4
4
|
|
|
5
5
|
begin
|
|
6
|
-
ffi_lib
|
|
6
|
+
ffi_lib OnnxRuntime.ffi_lib
|
|
7
7
|
rescue LoadError => e
|
|
8
|
-
|
|
9
|
-
|
|
8
|
+
if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
|
|
9
|
+
raise LoadError, "OpenMP not found. Run `brew install libomp`"
|
|
10
|
+
else
|
|
11
|
+
raise e
|
|
12
|
+
end
|
|
10
13
|
end
|
|
11
14
|
|
|
12
15
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
|
@@ -25,14 +28,14 @@ module OnnxRuntime
|
|
|
25
28
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
|
26
29
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
|
27
30
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
|
28
|
-
:CreateSession, callback(%i[pointer
|
|
31
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
|
29
32
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
|
30
33
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
|
31
34
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
|
35
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
|
33
36
|
:CloneSessionOptions, callback(%i[], :pointer),
|
|
34
|
-
:SetSessionExecutionMode, callback(%i[], :pointer),
|
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
|
37
|
+
:SetSessionExecutionMode, callback(%i[pointer int], :pointer),
|
|
38
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
|
36
39
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
|
37
40
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
|
38
41
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
|
@@ -71,8 +74,8 @@ module OnnxRuntime
|
|
|
71
74
|
:IsTensor, callback(%i[], :pointer),
|
|
72
75
|
:GetTensorMutableData, callback(%i[pointer pointer], :pointer),
|
|
73
76
|
:FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
|
|
74
|
-
:GetStringTensorDataLength, callback(%i[], :pointer),
|
|
75
|
-
:GetStringTensorContent, callback(%i[], :pointer),
|
|
77
|
+
:GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
|
|
78
|
+
:GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
|
|
76
79
|
:CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
|
|
77
80
|
:GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
|
|
78
81
|
:CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
|
|
@@ -142,7 +145,9 @@ module OnnxRuntime
|
|
|
142
145
|
:CreateThreadingOptions, callback(%i[], :pointer),
|
|
143
146
|
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
|
144
147
|
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
|
145
|
-
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
|
148
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
|
149
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
|
150
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
|
146
151
|
end
|
|
147
152
|
|
|
148
153
|
class ApiBase < ::FFI::Struct
|
|
@@ -154,5 +159,13 @@ module OnnxRuntime
|
|
|
154
159
|
end
|
|
155
160
|
|
|
156
161
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
|
162
|
+
|
|
163
|
+
if Gem.win_platform?
|
|
164
|
+
class Libc
|
|
165
|
+
extend ::FFI::Library
|
|
166
|
+
ffi_lib ::FFI::Library::LIBC
|
|
167
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
|
168
|
+
end
|
|
169
|
+
end
|
|
157
170
|
end
|
|
158
171
|
end
|
|
@@ -8,7 +8,7 @@ module OnnxRuntime
|
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
|
12
12
|
if execution_mode
|
|
13
13
|
execution_modes = {sequential: 0, parallel: 1}
|
|
14
14
|
mode = execution_modes[execution_mode]
|
|
@@ -17,8 +17,8 @@ module OnnxRuntime
|
|
|
17
17
|
end
|
|
18
18
|
if graph_optimization_level
|
|
19
19
|
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
|
20
|
-
|
|
21
|
-
|
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
|
22
22
|
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
|
23
23
|
end
|
|
24
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
|
@@ -26,7 +26,7 @@ module OnnxRuntime
|
|
|
26
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
|
27
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
|
28
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
|
29
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
|
30
30
|
|
|
31
31
|
# session
|
|
32
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
|
@@ -39,16 +39,10 @@ module OnnxRuntime
|
|
|
39
39
|
path_or_bytes.encoding == Encoding::BINARY
|
|
40
40
|
end
|
|
41
41
|
|
|
42
|
-
# fix for Windows "File doesn't exist"
|
|
43
|
-
if Gem.win_platform? && !from_memory
|
|
44
|
-
path_or_bytes = File.binread(path_or_bytes)
|
|
45
|
-
from_memory = true
|
|
46
|
-
end
|
|
47
|
-
|
|
48
42
|
if from_memory
|
|
49
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
|
50
44
|
else
|
|
51
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
|
52
46
|
end
|
|
53
47
|
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
|
54
48
|
|
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
|
63
57
|
# input
|
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
|
66
|
-
|
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
|
@@ -74,7 +68,7 @@ module OnnxRuntime
|
|
|
74
68
|
# output
|
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
|
77
|
-
|
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
|
@@ -86,7 +80,7 @@ module OnnxRuntime
|
|
|
86
80
|
end
|
|
87
81
|
|
|
88
82
|
# TODO support logid
|
|
89
|
-
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
|
|
83
|
+
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
|
|
90
84
|
input_tensor = create_input_tensor(input_feed)
|
|
91
85
|
|
|
92
86
|
output_names ||= @outputs.map { |v| v[:name] }
|
|
@@ -106,7 +100,7 @@ module OnnxRuntime
|
|
|
106
100
|
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
|
|
107
101
|
|
|
108
102
|
output_names.size.times.map do |i|
|
|
109
|
-
create_from_onnx_value(output_tensor[i].read_pointer)
|
|
103
|
+
create_from_onnx_value(output_tensor[i].read_pointer, output_type)
|
|
110
104
|
end
|
|
111
105
|
ensure
|
|
112
106
|
release :RunOptions, run_options
|
|
@@ -130,7 +124,7 @@ module OnnxRuntime
|
|
|
130
124
|
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
|
131
125
|
|
|
132
126
|
custom_metadata_map = {}
|
|
133
|
-
check_status
|
|
127
|
+
check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
|
134
128
|
num_keys.read(:int64_t).times do |i|
|
|
135
129
|
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
|
136
130
|
value = ::FFI::MemoryPointer.new(:string)
|
|
@@ -156,56 +150,86 @@ module OnnxRuntime
|
|
|
156
150
|
release :ModelMetadata, metadata
|
|
157
151
|
end
|
|
158
152
|
|
|
153
|
+
# return value has double underscore like Python
|
|
159
154
|
def end_profiling
|
|
160
155
|
out = ::FFI::MemoryPointer.new(:string)
|
|
161
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
|
162
157
|
out.read_pointer.read_string
|
|
163
158
|
end
|
|
164
159
|
|
|
160
|
+
# no way to set providers with C API yet
|
|
161
|
+
# so we can return all available providers
|
|
162
|
+
def providers
|
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
|
166
|
+
length = length_ptr.read_int
|
|
167
|
+
providers = []
|
|
168
|
+
length.times do |i|
|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
|
170
|
+
end
|
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
|
172
|
+
providers
|
|
173
|
+
end
|
|
174
|
+
|
|
165
175
|
private
|
|
166
176
|
|
|
167
177
|
def create_input_tensor(input_feed)
|
|
168
178
|
allocator_info = ::FFI::MemoryPointer.new(:pointer)
|
|
169
|
-
check_status
|
|
179
|
+
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
|
|
170
180
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
|
171
181
|
|
|
172
182
|
input_feed.each_with_index do |(input_name, input), idx|
|
|
173
|
-
|
|
183
|
+
if numo_array?(input)
|
|
184
|
+
shape = input.shape
|
|
185
|
+
else
|
|
186
|
+
input = input.to_a unless input.is_a?(Array)
|
|
174
187
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
188
|
+
shape = []
|
|
189
|
+
s = input
|
|
190
|
+
while s.is_a?(Array)
|
|
191
|
+
shape << s.size
|
|
192
|
+
s = s.first
|
|
193
|
+
end
|
|
180
194
|
end
|
|
181
195
|
|
|
182
|
-
flat_input = input.flatten
|
|
183
|
-
input_tensor_size = flat_input.size
|
|
184
|
-
|
|
185
196
|
# TODO support more types
|
|
186
197
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
|
187
|
-
raise "Unknown input: #{input_name}" unless inp
|
|
198
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
|
188
199
|
|
|
189
200
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
|
190
201
|
input_node_dims.write_array_of_int64(shape)
|
|
191
202
|
|
|
192
203
|
if inp[:type] == "tensor(string)"
|
|
193
|
-
|
|
194
|
-
|
|
204
|
+
if numo_array?(input)
|
|
205
|
+
input_tensor_size = input.size
|
|
206
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
|
|
207
|
+
input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
|
|
208
|
+
else
|
|
209
|
+
flat_input = input.flatten.to_a
|
|
210
|
+
input_tensor_size = flat_input.size
|
|
211
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
|
|
212
|
+
input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
|
|
213
|
+
end
|
|
195
214
|
type_enum = FFI::TensorElementDataType[:string]
|
|
196
215
|
check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
|
|
197
|
-
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values,
|
|
216
|
+
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
|
|
198
217
|
else
|
|
199
|
-
tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
|
200
218
|
tensor_type = tensor_types[inp[:type]]
|
|
201
219
|
|
|
202
220
|
if tensor_type
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
flat_input =
|
|
221
|
+
if numo_array?(input)
|
|
222
|
+
input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
|
|
223
|
+
else
|
|
224
|
+
flat_input = input.flatten.to_a
|
|
225
|
+
input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
|
|
226
|
+
if tensor_type == :bool
|
|
227
|
+
input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
|
|
228
|
+
else
|
|
229
|
+
input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
|
|
230
|
+
end
|
|
207
231
|
end
|
|
208
|
-
|
|
232
|
+
|
|
209
233
|
type_enum = FFI::TensorElementDataType[tensor_type]
|
|
210
234
|
else
|
|
211
235
|
unsupported_type("input", inp[:type])
|
|
@@ -224,9 +248,9 @@ module OnnxRuntime
|
|
|
224
248
|
ptr
|
|
225
249
|
end
|
|
226
250
|
|
|
227
|
-
def create_from_onnx_value(out_ptr)
|
|
251
|
+
def create_from_onnx_value(out_ptr, output_type)
|
|
228
252
|
out_type = ::FFI::MemoryPointer.new(:int)
|
|
229
|
-
check_status
|
|
253
|
+
check_status api[:GetValueType].call(out_ptr, out_type)
|
|
230
254
|
type = FFI::OnnxType[out_type.read_int]
|
|
231
255
|
|
|
232
256
|
case type
|
|
@@ -241,31 +265,50 @@ module OnnxRuntime
|
|
|
241
265
|
|
|
242
266
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
|
243
267
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
|
244
|
-
output_tensor_size =
|
|
268
|
+
output_tensor_size = out_size.read(:size_t)
|
|
245
269
|
|
|
246
270
|
release :TensorTypeAndShapeInfo, typeinfo
|
|
247
271
|
|
|
248
272
|
# TODO support more types
|
|
249
273
|
type = FFI::TensorElementDataType[type]
|
|
250
|
-
|
|
274
|
+
|
|
275
|
+
case output_type
|
|
276
|
+
when :numo
|
|
251
277
|
case type
|
|
252
|
-
when :
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
278
|
+
when :string
|
|
279
|
+
result = Numo::RObject.new(shape)
|
|
280
|
+
result.allocate
|
|
281
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
|
256
282
|
else
|
|
257
|
-
|
|
283
|
+
numo_type = numo_types[type]
|
|
284
|
+
unsupported_type("element", type) unless numo_type
|
|
285
|
+
numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
|
|
258
286
|
end
|
|
287
|
+
when :ruby
|
|
288
|
+
arr =
|
|
289
|
+
case type
|
|
290
|
+
when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
|
|
291
|
+
tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
|
|
292
|
+
when :bool
|
|
293
|
+
tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
|
|
294
|
+
when :string
|
|
295
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
|
|
296
|
+
else
|
|
297
|
+
unsupported_type("element", type)
|
|
298
|
+
end
|
|
259
299
|
|
|
260
|
-
|
|
300
|
+
Utils.reshape(arr, shape)
|
|
301
|
+
else
|
|
302
|
+
raise ArgumentError, "Invalid output type: #{output_type}"
|
|
303
|
+
end
|
|
261
304
|
when :sequence
|
|
262
305
|
out = ::FFI::MemoryPointer.new(:size_t)
|
|
263
306
|
check_status api[:GetValueCount].call(out_ptr, out)
|
|
264
307
|
|
|
265
|
-
|
|
308
|
+
out.read(:size_t).times.map do |i|
|
|
266
309
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
|
267
310
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
|
268
|
-
create_from_onnx_value(seq.read_pointer)
|
|
311
|
+
create_from_onnx_value(seq.read_pointer, output_type)
|
|
269
312
|
end
|
|
270
313
|
when :map
|
|
271
314
|
type_shape = ::FFI::MemoryPointer.new(:pointer)
|
|
@@ -284,8 +327,8 @@ module OnnxRuntime
|
|
|
284
327
|
case elem_type
|
|
285
328
|
when :int64
|
|
286
329
|
ret = {}
|
|
287
|
-
keys = create_from_onnx_value(map_keys.read_pointer)
|
|
288
|
-
values = create_from_onnx_value(map_values.read_pointer)
|
|
330
|
+
keys = create_from_onnx_value(map_keys.read_pointer, output_type)
|
|
331
|
+
values = create_from_onnx_value(map_values.read_pointer, output_type)
|
|
289
332
|
keys.zip(values).each do |k, v|
|
|
290
333
|
ret[k] = v
|
|
291
334
|
end
|
|
@@ -298,6 +341,23 @@ module OnnxRuntime
|
|
|
298
341
|
end
|
|
299
342
|
end
|
|
300
343
|
|
|
344
|
+
def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
|
345
|
+
len = ::FFI::MemoryPointer.new(:size_t)
|
|
346
|
+
check_status api[:GetStringTensorDataLength].call(out_ptr, len)
|
|
347
|
+
|
|
348
|
+
s_len = len.read(:size_t)
|
|
349
|
+
s = ::FFI::MemoryPointer.new(:uchar, s_len)
|
|
350
|
+
offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
|
|
351
|
+
check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
|
|
352
|
+
|
|
353
|
+
offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
|
|
354
|
+
offsets << s_len
|
|
355
|
+
output_tensor_size.times do |i|
|
|
356
|
+
result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
|
|
357
|
+
end
|
|
358
|
+
result
|
|
359
|
+
end
|
|
360
|
+
|
|
301
361
|
def read_pointer
|
|
302
362
|
@session.read_pointer
|
|
303
363
|
end
|
|
@@ -306,7 +366,7 @@ module OnnxRuntime
|
|
|
306
366
|
unless status.null?
|
|
307
367
|
message = api[:GetErrorMessage].call(status).read_string
|
|
308
368
|
api[:ReleaseStatus].call(status)
|
|
309
|
-
raise
|
|
369
|
+
raise Error, message
|
|
310
370
|
end
|
|
311
371
|
end
|
|
312
372
|
|
|
@@ -368,7 +428,7 @@ module OnnxRuntime
|
|
|
368
428
|
|
|
369
429
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
|
370
430
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
|
371
|
-
num_dims =
|
|
431
|
+
num_dims = num_dims_ptr.read(:size_t)
|
|
372
432
|
|
|
373
433
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
|
374
434
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
|
@@ -377,16 +437,31 @@ module OnnxRuntime
|
|
|
377
437
|
end
|
|
378
438
|
|
|
379
439
|
def unsupported_type(name, type)
|
|
380
|
-
raise "Unsupported #{name} type: #{type}"
|
|
440
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
|
381
441
|
end
|
|
382
442
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
443
|
+
def tensor_types
|
|
444
|
+
@tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
|
445
|
+
end
|
|
446
|
+
|
|
447
|
+
def numo_array?(obj)
|
|
448
|
+
defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
|
|
449
|
+
end
|
|
450
|
+
|
|
451
|
+
def numo_types
|
|
452
|
+
@numo_types ||= {
|
|
453
|
+
float: Numo::SFloat,
|
|
454
|
+
uint8: Numo::UInt8,
|
|
455
|
+
int8: Numo::Int8,
|
|
456
|
+
uint16: Numo::UInt16,
|
|
457
|
+
int16: Numo::Int16,
|
|
458
|
+
int32: Numo::Int32,
|
|
459
|
+
int64: Numo::Int64,
|
|
460
|
+
bool: Numo::UInt8,
|
|
461
|
+
double: Numo::DFloat,
|
|
462
|
+
uint32: Numo::UInt32,
|
|
463
|
+
uint64: Numo::UInt64
|
|
464
|
+
}
|
|
390
465
|
end
|
|
391
466
|
|
|
392
467
|
def api
|
|
@@ -398,7 +473,7 @@ module OnnxRuntime
|
|
|
398
473
|
end
|
|
399
474
|
|
|
400
475
|
def self.api
|
|
401
|
-
@api ||= FFI.OrtGetApiBase[:GetApi].call(
|
|
476
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
|
402
477
|
end
|
|
403
478
|
|
|
404
479
|
def self.release(type, pointer)
|
|
@@ -410,6 +485,21 @@ module OnnxRuntime
|
|
|
410
485
|
proc { release :Session, session }
|
|
411
486
|
end
|
|
412
487
|
|
|
488
|
+
# wide string on Windows
|
|
489
|
+
# char string on Linux
|
|
490
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
|
491
|
+
def ort_string(str)
|
|
492
|
+
if Gem.win_platform?
|
|
493
|
+
max = str.size + 1 # for null byte
|
|
494
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
|
495
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
|
496
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
|
497
|
+
dest
|
|
498
|
+
else
|
|
499
|
+
str
|
|
500
|
+
end
|
|
501
|
+
end
|
|
502
|
+
|
|
413
503
|
def env
|
|
414
504
|
# use mutex for thread-safety
|
|
415
505
|
Utils.mutex.synchronize do
|