onnxruntime 0.3.3 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/LICENSE.txt +18 -19
- data/README.md +3 -2
- data/lib/onnxruntime/ffi.rb +23 -10
- data/lib/onnxruntime/inference_session.rb +151 -61
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/LICENSE +1 -1
- data/vendor/ThirdPartyNotices.txt +1118 -493
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +8 -50
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 4b0bef4682d8d44fb3cfd3fa619c1e7e620943014a22f259e14bfa1692dbe2ef
|
4
|
+
data.tar.gz: bfe59759c177613806ab0ddd97c4d8a68a7e3915299d81aab3e9d369ce672918
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f0dfc6c0a91621b5637e56061b40ba26648a7e538d84cb2c135a06c1832c9f71aba2e7fb7ab849750756540d02632726b3f45732a8c50649ca4ace03befbe086
|
7
|
+
data.tar.gz: 59a01f8d955f9d31f35864a9217fcf79eb35876229e749c5c50b4572903d9dc3e312a4a7a55612e9c36e3095a797da58637442ac0bb191adcb743b62fefb0905
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,33 @@
|
|
1
|
+
## 0.6.0 (2021-03-14)
|
2
|
+
|
3
|
+
- Updated ONNX Runtime to 1.7.0
|
4
|
+
- OpenMP is no longer required
|
5
|
+
|
6
|
+
## 0.5.2 (2020-12-27)
|
7
|
+
|
8
|
+
- Updated ONNX Runtime to 1.6.0
|
9
|
+
- Fixed error with `execution_mode` option
|
10
|
+
- Fixed error with `bool` input
|
11
|
+
|
12
|
+
## 0.5.1 (2020-11-01)
|
13
|
+
|
14
|
+
- Updated ONNX Runtime to 1.5.2
|
15
|
+
- Added support for string output
|
16
|
+
- Added `output_type` option
|
17
|
+
- Improved performance for Numo array inputs
|
18
|
+
|
19
|
+
## 0.5.0 (2020-10-01)
|
20
|
+
|
21
|
+
- Updated ONNX Runtime to 1.5.1
|
22
|
+
- OpenMP is now required on Mac
|
23
|
+
- Fixed `mul_1.onnx` example
|
24
|
+
|
25
|
+
## 0.4.0 (2020-07-20)
|
26
|
+
|
27
|
+
- Updated ONNX Runtime to 1.4.0
|
28
|
+
- Added `providers` method
|
29
|
+
- Fixed errors on Windows
|
30
|
+
|
1
31
|
## 0.3.3 (2020-06-17)
|
2
32
|
|
3
33
|
- Fixed segmentation fault on exit on Linux
|
data/LICENSE.txt
CHANGED
@@ -1,23 +1,22 @@
|
|
1
|
-
Copyright (c) 2019-2020 Andrew Kane
|
2
|
-
Datasets Copyright (c) Microsoft Corporation
|
3
|
-
|
4
1
|
MIT License
|
5
2
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
Copyright (c) Microsoft Corporation
|
4
|
+
Copyright (c) 2019-2021 Andrew Kane
|
5
|
+
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
8
|
+
in the Software without restriction, including without limitation the rights
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
11
|
+
furnished to do so, subject to the following conditions:
|
13
12
|
|
14
|
-
The above copyright notice and this permission notice shall be
|
15
|
-
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
14
|
+
copies or substantial portions of the Software.
|
16
15
|
|
17
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
OF
|
23
|
-
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
@@ -4,7 +4,7 @@
|
|
4
4
|
|
5
5
|
Check out [an example](https://ankane.org/tensorflow-ruby)
|
6
6
|
|
7
|
-
[![Build Status](https://
|
7
|
+
[![Build Status](https://github.com/ankane/onnxruntime/workflows/build/badge.svg?branch=master)](https://github.com/ankane/onnxruntime/actions)
|
8
8
|
|
9
9
|
## Installation
|
10
10
|
|
@@ -81,7 +81,8 @@ model.predict(input_feed, {
|
|
81
81
|
log_severity_level: 2,
|
82
82
|
log_verbosity_level: 0,
|
83
83
|
logid: nil,
|
84
|
-
terminate: false
|
84
|
+
terminate: false,
|
85
|
+
output_type: :ruby # :ruby or :numo
|
85
86
|
})
|
86
87
|
```
|
87
88
|
|
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,10 +3,13 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib
|
6
|
+
ffi_lib OnnxRuntime.ffi_lib
|
7
7
|
rescue LoadError => e
|
8
|
-
|
9
|
-
|
8
|
+
if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
|
9
|
+
raise LoadError, "OpenMP not found. Run `brew install libomp`"
|
10
|
+
else
|
11
|
+
raise e
|
12
|
+
end
|
10
13
|
end
|
11
14
|
|
12
15
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
@@ -25,14 +28,14 @@ module OnnxRuntime
|
|
25
28
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
29
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
30
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
31
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
32
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
33
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
34
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
35
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
36
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
|
-
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
37
|
+
:SetSessionExecutionMode, callback(%i[pointer int], :pointer),
|
38
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
39
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
40
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
41
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -71,8 +74,8 @@ module OnnxRuntime
|
|
71
74
|
:IsTensor, callback(%i[], :pointer),
|
72
75
|
:GetTensorMutableData, callback(%i[pointer pointer], :pointer),
|
73
76
|
:FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
|
74
|
-
:GetStringTensorDataLength, callback(%i[], :pointer),
|
75
|
-
:GetStringTensorContent, callback(%i[], :pointer),
|
77
|
+
:GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
|
78
|
+
:GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
|
76
79
|
:CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
|
77
80
|
:GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
|
78
81
|
:CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
|
@@ -142,7 +145,9 @@ module OnnxRuntime
|
|
142
145
|
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
146
|
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
147
|
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
-
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
148
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
149
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
150
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
146
151
|
end
|
147
152
|
|
148
153
|
class ApiBase < ::FFI::Struct
|
@@ -154,5 +159,13 @@ module OnnxRuntime
|
|
154
159
|
end
|
155
160
|
|
156
161
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
162
|
+
|
163
|
+
if Gem.win_platform?
|
164
|
+
class Libc
|
165
|
+
extend ::FFI::Library
|
166
|
+
ffi_lib ::FFI::Library::LIBC
|
167
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
168
|
+
end
|
169
|
+
end
|
157
170
|
end
|
158
171
|
end
|
@@ -8,7 +8,7 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
13
|
execution_modes = {sequential: 0, parallel: 1}
|
14
14
|
mode = execution_modes[execution_mode]
|
@@ -17,8 +17,8 @@ module OnnxRuntime
|
|
17
17
|
end
|
18
18
|
if graph_optimization_level
|
19
19
|
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
-
|
21
|
-
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
22
|
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
23
|
end
|
24
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
@@ -26,7 +26,7 @@ module OnnxRuntime
|
|
26
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
27
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
28
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
29
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
30
30
|
|
31
31
|
# session
|
32
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -39,16 +39,10 @@ module OnnxRuntime
|
|
39
39
|
path_or_bytes.encoding == Encoding::BINARY
|
40
40
|
end
|
41
41
|
|
42
|
-
# fix for Windows "File doesn't exist"
|
43
|
-
if Gem.win_platform? && !from_memory
|
44
|
-
path_or_bytes = File.binread(path_or_bytes)
|
45
|
-
from_memory = true
|
46
|
-
end
|
47
|
-
|
48
42
|
if from_memory
|
49
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
50
44
|
else
|
51
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
52
46
|
end
|
53
47
|
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,7 +68,7 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -86,7 +80,7 @@ module OnnxRuntime
|
|
86
80
|
end
|
87
81
|
|
88
82
|
# TODO support logid
|
89
|
-
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
|
83
|
+
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
|
90
84
|
input_tensor = create_input_tensor(input_feed)
|
91
85
|
|
92
86
|
output_names ||= @outputs.map { |v| v[:name] }
|
@@ -106,7 +100,7 @@ module OnnxRuntime
|
|
106
100
|
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
|
107
101
|
|
108
102
|
output_names.size.times.map do |i|
|
109
|
-
create_from_onnx_value(output_tensor[i].read_pointer)
|
103
|
+
create_from_onnx_value(output_tensor[i].read_pointer, output_type)
|
110
104
|
end
|
111
105
|
ensure
|
112
106
|
release :RunOptions, run_options
|
@@ -130,7 +124,7 @@ module OnnxRuntime
|
|
130
124
|
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
131
125
|
|
132
126
|
custom_metadata_map = {}
|
133
|
-
check_status
|
127
|
+
check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
134
128
|
num_keys.read(:int64_t).times do |i|
|
135
129
|
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
136
130
|
value = ::FFI::MemoryPointer.new(:string)
|
@@ -156,56 +150,86 @@ module OnnxRuntime
|
|
156
150
|
release :ModelMetadata, metadata
|
157
151
|
end
|
158
152
|
|
153
|
+
# return value has double underscore like Python
|
159
154
|
def end_profiling
|
160
155
|
out = ::FFI::MemoryPointer.new(:string)
|
161
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
162
157
|
out.read_pointer.read_string
|
163
158
|
end
|
164
159
|
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
173
|
+
end
|
174
|
+
|
165
175
|
private
|
166
176
|
|
167
177
|
def create_input_tensor(input_feed)
|
168
178
|
allocator_info = ::FFI::MemoryPointer.new(:pointer)
|
169
|
-
check_status
|
179
|
+
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
|
170
180
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
171
181
|
|
172
182
|
input_feed.each_with_index do |(input_name, input), idx|
|
173
|
-
|
183
|
+
if numo_array?(input)
|
184
|
+
shape = input.shape
|
185
|
+
else
|
186
|
+
input = input.to_a unless input.is_a?(Array)
|
174
187
|
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
188
|
+
shape = []
|
189
|
+
s = input
|
190
|
+
while s.is_a?(Array)
|
191
|
+
shape << s.size
|
192
|
+
s = s.first
|
193
|
+
end
|
180
194
|
end
|
181
195
|
|
182
|
-
flat_input = input.flatten
|
183
|
-
input_tensor_size = flat_input.size
|
184
|
-
|
185
196
|
# TODO support more types
|
186
197
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
187
|
-
raise "Unknown input: #{input_name}" unless inp
|
198
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
188
199
|
|
189
200
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
190
201
|
input_node_dims.write_array_of_int64(shape)
|
191
202
|
|
192
203
|
if inp[:type] == "tensor(string)"
|
193
|
-
|
194
|
-
|
204
|
+
if numo_array?(input)
|
205
|
+
input_tensor_size = input.size
|
206
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
|
207
|
+
input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
|
208
|
+
else
|
209
|
+
flat_input = input.flatten.to_a
|
210
|
+
input_tensor_size = flat_input.size
|
211
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
|
212
|
+
input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
|
213
|
+
end
|
195
214
|
type_enum = FFI::TensorElementDataType[:string]
|
196
215
|
check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
|
197
|
-
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values,
|
216
|
+
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
|
198
217
|
else
|
199
|
-
tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
200
218
|
tensor_type = tensor_types[inp[:type]]
|
201
219
|
|
202
220
|
if tensor_type
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
flat_input =
|
221
|
+
if numo_array?(input)
|
222
|
+
input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
|
223
|
+
else
|
224
|
+
flat_input = input.flatten.to_a
|
225
|
+
input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
|
226
|
+
if tensor_type == :bool
|
227
|
+
input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
|
228
|
+
else
|
229
|
+
input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
|
230
|
+
end
|
207
231
|
end
|
208
|
-
|
232
|
+
|
209
233
|
type_enum = FFI::TensorElementDataType[tensor_type]
|
210
234
|
else
|
211
235
|
unsupported_type("input", inp[:type])
|
@@ -224,9 +248,9 @@ module OnnxRuntime
|
|
224
248
|
ptr
|
225
249
|
end
|
226
250
|
|
227
|
-
def create_from_onnx_value(out_ptr)
|
251
|
+
def create_from_onnx_value(out_ptr, output_type)
|
228
252
|
out_type = ::FFI::MemoryPointer.new(:int)
|
229
|
-
check_status
|
253
|
+
check_status api[:GetValueType].call(out_ptr, out_type)
|
230
254
|
type = FFI::OnnxType[out_type.read_int]
|
231
255
|
|
232
256
|
case type
|
@@ -241,31 +265,50 @@ module OnnxRuntime
|
|
241
265
|
|
242
266
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
243
267
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
244
|
-
output_tensor_size =
|
268
|
+
output_tensor_size = out_size.read(:size_t)
|
245
269
|
|
246
270
|
release :TensorTypeAndShapeInfo, typeinfo
|
247
271
|
|
248
272
|
# TODO support more types
|
249
273
|
type = FFI::TensorElementDataType[type]
|
250
|
-
|
274
|
+
|
275
|
+
case output_type
|
276
|
+
when :numo
|
251
277
|
case type
|
252
|
-
when :
|
253
|
-
|
254
|
-
|
255
|
-
|
278
|
+
when :string
|
279
|
+
result = Numo::RObject.new(shape)
|
280
|
+
result.allocate
|
281
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
256
282
|
else
|
257
|
-
|
283
|
+
numo_type = numo_types[type]
|
284
|
+
unsupported_type("element", type) unless numo_type
|
285
|
+
numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
|
258
286
|
end
|
287
|
+
when :ruby
|
288
|
+
arr =
|
289
|
+
case type
|
290
|
+
when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
|
291
|
+
tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
|
292
|
+
when :bool
|
293
|
+
tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
|
294
|
+
when :string
|
295
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
|
296
|
+
else
|
297
|
+
unsupported_type("element", type)
|
298
|
+
end
|
259
299
|
|
260
|
-
|
300
|
+
Utils.reshape(arr, shape)
|
301
|
+
else
|
302
|
+
raise ArgumentError, "Invalid output type: #{output_type}"
|
303
|
+
end
|
261
304
|
when :sequence
|
262
305
|
out = ::FFI::MemoryPointer.new(:size_t)
|
263
306
|
check_status api[:GetValueCount].call(out_ptr, out)
|
264
307
|
|
265
|
-
|
308
|
+
out.read(:size_t).times.map do |i|
|
266
309
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
267
310
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
268
|
-
create_from_onnx_value(seq.read_pointer)
|
311
|
+
create_from_onnx_value(seq.read_pointer, output_type)
|
269
312
|
end
|
270
313
|
when :map
|
271
314
|
type_shape = ::FFI::MemoryPointer.new(:pointer)
|
@@ -284,8 +327,8 @@ module OnnxRuntime
|
|
284
327
|
case elem_type
|
285
328
|
when :int64
|
286
329
|
ret = {}
|
287
|
-
keys = create_from_onnx_value(map_keys.read_pointer)
|
288
|
-
values = create_from_onnx_value(map_values.read_pointer)
|
330
|
+
keys = create_from_onnx_value(map_keys.read_pointer, output_type)
|
331
|
+
values = create_from_onnx_value(map_values.read_pointer, output_type)
|
289
332
|
keys.zip(values).each do |k, v|
|
290
333
|
ret[k] = v
|
291
334
|
end
|
@@ -298,6 +341,23 @@ module OnnxRuntime
|
|
298
341
|
end
|
299
342
|
end
|
300
343
|
|
344
|
+
def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
345
|
+
len = ::FFI::MemoryPointer.new(:size_t)
|
346
|
+
check_status api[:GetStringTensorDataLength].call(out_ptr, len)
|
347
|
+
|
348
|
+
s_len = len.read(:size_t)
|
349
|
+
s = ::FFI::MemoryPointer.new(:uchar, s_len)
|
350
|
+
offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
|
351
|
+
check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
|
352
|
+
|
353
|
+
offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
|
354
|
+
offsets << s_len
|
355
|
+
output_tensor_size.times do |i|
|
356
|
+
result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
|
357
|
+
end
|
358
|
+
result
|
359
|
+
end
|
360
|
+
|
301
361
|
def read_pointer
|
302
362
|
@session.read_pointer
|
303
363
|
end
|
@@ -306,7 +366,7 @@ module OnnxRuntime
|
|
306
366
|
unless status.null?
|
307
367
|
message = api[:GetErrorMessage].call(status).read_string
|
308
368
|
api[:ReleaseStatus].call(status)
|
309
|
-
raise
|
369
|
+
raise Error, message
|
310
370
|
end
|
311
371
|
end
|
312
372
|
|
@@ -368,7 +428,7 @@ module OnnxRuntime
|
|
368
428
|
|
369
429
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
370
430
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
371
|
-
num_dims =
|
431
|
+
num_dims = num_dims_ptr.read(:size_t)
|
372
432
|
|
373
433
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
374
434
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -377,16 +437,31 @@ module OnnxRuntime
|
|
377
437
|
end
|
378
438
|
|
379
439
|
def unsupported_type(name, type)
|
380
|
-
raise "Unsupported #{name} type: #{type}"
|
440
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
381
441
|
end
|
382
442
|
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
443
|
+
def tensor_types
|
444
|
+
@tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
445
|
+
end
|
446
|
+
|
447
|
+
def numo_array?(obj)
|
448
|
+
defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
|
449
|
+
end
|
450
|
+
|
451
|
+
def numo_types
|
452
|
+
@numo_types ||= {
|
453
|
+
float: Numo::SFloat,
|
454
|
+
uint8: Numo::UInt8,
|
455
|
+
int8: Numo::Int8,
|
456
|
+
uint16: Numo::UInt16,
|
457
|
+
int16: Numo::Int16,
|
458
|
+
int32: Numo::Int32,
|
459
|
+
int64: Numo::Int64,
|
460
|
+
bool: Numo::UInt8,
|
461
|
+
double: Numo::DFloat,
|
462
|
+
uint32: Numo::UInt32,
|
463
|
+
uint64: Numo::UInt64
|
464
|
+
}
|
390
465
|
end
|
391
466
|
|
392
467
|
def api
|
@@ -398,7 +473,7 @@ module OnnxRuntime
|
|
398
473
|
end
|
399
474
|
|
400
475
|
def self.api
|
401
|
-
@api ||= FFI.OrtGetApiBase[:GetApi].call(
|
476
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
402
477
|
end
|
403
478
|
|
404
479
|
def self.release(type, pointer)
|
@@ -410,6 +485,21 @@ module OnnxRuntime
|
|
410
485
|
proc { release :Session, session }
|
411
486
|
end
|
412
487
|
|
488
|
+
# wide string on Windows
|
489
|
+
# char string on Linux
|
490
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
491
|
+
def ort_string(str)
|
492
|
+
if Gem.win_platform?
|
493
|
+
max = str.size + 1 # for null byte
|
494
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
495
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
496
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
497
|
+
dest
|
498
|
+
else
|
499
|
+
str
|
500
|
+
end
|
501
|
+
end
|
502
|
+
|
413
503
|
def env
|
414
504
|
# use mutex for thread-safety
|
415
505
|
Utils.mutex.synchronize do
|