onnxruntime 0.3.1 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +18 -19
- data/README.md +10 -3
- data/lib/onnxruntime.rb +1 -1
- data/lib/onnxruntime/ffi.rb +24 -11
- data/lib/onnxruntime/inference_session.rb +194 -73
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/ThirdPartyNotices.txt +1029 -3
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +7 -7
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 316527780be2781a474d0813aff47d840654423823ad83b8b52b13752caf6814
|
4
|
+
data.tar.gz: 5954ba2dc4223b8330fb52e8474be78542360c2f4c1cb5946529d062c0b1b864
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c127874dd75a10b8cb9d9d033e607f9d73069bb5709d7b9ecee04c1d59f969f61db6ec88fd569188b0a35c8276f7611619f81662f3735fbd53e61938f15c6822
|
7
|
+
data.tar.gz: 77a7c9f8c98b25fd82ee8818c63176d2005f846db489b5973330220344be8fd47d8e5279517a716f8c1222f729a248fc8ac80cb610633dbdc35c49418a42ea1c
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.5.1 (2020-11-01)
|
2
|
+
|
3
|
+
- Updated ONNX Runtime to 1.5.2
|
4
|
+
- Added support for string output
|
5
|
+
- Added `output_type` option
|
6
|
+
- Improved performance for Numo array inputs
|
7
|
+
|
8
|
+
## 0.5.0 (2020-10-01)
|
9
|
+
|
10
|
+
- Updated ONNX Runtime to 1.5.1
|
11
|
+
- OpenMP is now required on Mac
|
12
|
+
- Fixed `mul_1.onnx` example
|
13
|
+
|
14
|
+
## 0.4.0 (2020-07-20)
|
15
|
+
|
16
|
+
- Updated ONNX Runtime to 1.4.0
|
17
|
+
- Added `providers` method
|
18
|
+
- Fixed errors on Windows
|
19
|
+
|
20
|
+
## 0.3.3 (2020-06-17)
|
21
|
+
|
22
|
+
- Fixed segmentation fault on exit on Linux
|
23
|
+
|
24
|
+
## 0.3.2 (2020-06-16)
|
25
|
+
|
26
|
+
- Fixed error with FFI 1.13.0+
|
27
|
+
- Added friendly graph optimization levels
|
28
|
+
|
1
29
|
## 0.3.1 (2020-05-18)
|
2
30
|
|
3
31
|
- Updated ONNX Runtime to 1.3.0
|
data/LICENSE.txt
CHANGED
@@ -1,23 +1,22 @@
|
|
1
|
-
Copyright (c) 2019-2020 Andrew Kane
|
2
|
-
Datasets Copyright (c) Microsoft Corporation
|
3
|
-
|
4
1
|
MIT License
|
5
2
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
Copyright (c) 2018 Microsoft Corporation
|
4
|
+
Copyright (c) 2019-2020 Andrew Kane
|
5
|
+
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
8
|
+
in the Software without restriction, including without limitation the rights
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
11
|
+
furnished to do so, subject to the following conditions:
|
13
12
|
|
14
|
-
The above copyright notice and this permission notice shall be
|
15
|
-
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
14
|
+
copies or substantial portions of the Software.
|
16
15
|
|
17
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
OF
|
23
|
-
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
|
|
14
14
|
gem 'onnxruntime'
|
15
15
|
```
|
16
16
|
|
17
|
+
On Mac, also install OpenMP:
|
18
|
+
|
19
|
+
```sh
|
20
|
+
brew install libomp
|
21
|
+
```
|
22
|
+
|
17
23
|
## Getting Started
|
18
24
|
|
19
25
|
Load a model and make predictions
|
@@ -63,8 +69,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
|
|
63
69
|
enable_cpu_mem_arena: true,
|
64
70
|
enable_mem_pattern: true,
|
65
71
|
enable_profiling: false,
|
66
|
-
execution_mode: :sequential,
|
67
|
-
graph_optimization_level: nil,
|
72
|
+
execution_mode: :sequential, # :sequential or :parallel
|
73
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
68
74
|
inter_op_num_threads: nil,
|
69
75
|
intra_op_num_threads: nil,
|
70
76
|
log_severity_level: 2,
|
@@ -81,7 +87,8 @@ model.predict(input_feed, {
|
|
81
87
|
log_severity_level: 2,
|
82
88
|
log_verbosity_level: 0,
|
83
89
|
logid: nil,
|
84
|
-
terminate: false
|
90
|
+
terminate: false,
|
91
|
+
output_type: :ruby # :ruby or :numo
|
85
92
|
})
|
86
93
|
```
|
87
94
|
|
data/lib/onnxruntime.rb
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,10 +3,13 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib
|
6
|
+
ffi_lib OnnxRuntime.ffi_lib
|
7
7
|
rescue LoadError => e
|
8
|
-
|
9
|
-
|
8
|
+
if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
|
9
|
+
raise LoadError, "OpenMP not found. Run `brew install libomp`"
|
10
|
+
else
|
11
|
+
raise e
|
12
|
+
end
|
10
13
|
end
|
11
14
|
|
12
15
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
@@ -20,19 +23,19 @@ module OnnxRuntime
|
|
20
23
|
layout \
|
21
24
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
25
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :
|
26
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
27
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
28
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
29
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
30
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
31
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
32
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
33
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
34
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
35
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
36
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
37
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
38
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
39
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
40
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
41
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -71,8 +74,8 @@ module OnnxRuntime
|
|
71
74
|
:IsTensor, callback(%i[], :pointer),
|
72
75
|
:GetTensorMutableData, callback(%i[pointer pointer], :pointer),
|
73
76
|
:FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
|
74
|
-
:GetStringTensorDataLength, callback(%i[], :pointer),
|
75
|
-
:GetStringTensorContent, callback(%i[], :pointer),
|
77
|
+
:GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
|
78
|
+
:GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
|
76
79
|
:CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
|
77
80
|
:GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
|
78
81
|
:CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
|
@@ -142,7 +145,9 @@ module OnnxRuntime
|
|
142
145
|
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
146
|
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
147
|
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
-
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
148
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
149
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
150
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
146
151
|
end
|
147
152
|
|
148
153
|
class ApiBase < ::FFI::Struct
|
@@ -150,9 +155,17 @@ module OnnxRuntime
|
|
150
155
|
# to prevent "unable to resolve type" error on Ubuntu
|
151
156
|
layout \
|
152
157
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
153
|
-
:GetVersionString, callback(%i[], :
|
158
|
+
:GetVersionString, callback(%i[], :pointer)
|
154
159
|
end
|
155
160
|
|
156
161
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
162
|
+
|
163
|
+
if Gem.win_platform?
|
164
|
+
class Libc
|
165
|
+
extend ::FFI::Library
|
166
|
+
ffi_lib ::FFI::Library::LIBC
|
167
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
168
|
+
end
|
169
|
+
end
|
157
170
|
end
|
158
171
|
end
|
@@ -8,26 +8,25 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
0
|
17
|
-
when :parallel
|
18
|
-
1
|
19
|
-
else
|
20
|
-
raise ArgumentError, "Invalid execution mode"
|
21
|
-
end
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
22
16
|
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
23
17
|
end
|
24
|
-
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
25
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
26
25
|
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
27
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
28
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
29
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
30
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
31
30
|
|
32
31
|
# session
|
33
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -40,17 +39,12 @@ module OnnxRuntime
|
|
40
39
|
path_or_bytes.encoding == Encoding::BINARY
|
41
40
|
end
|
42
41
|
|
43
|
-
# fix for Windows "File doesn't exist"
|
44
|
-
if Gem.win_platform? && !from_memory
|
45
|
-
path_or_bytes = File.binread(path_or_bytes)
|
46
|
-
from_memory = true
|
47
|
-
end
|
48
|
-
|
49
42
|
if from_memory
|
50
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
51
44
|
else
|
52
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
53
46
|
end
|
47
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
55
49
|
# input info
|
56
50
|
allocator = ::FFI::MemoryPointer.new(:pointer)
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,17 +68,19 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
81
75
|
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
|
82
76
|
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
|
83
77
|
end
|
78
|
+
ensure
|
79
|
+
# release :SessionOptions, session_options
|
84
80
|
end
|
85
81
|
|
86
82
|
# TODO support logid
|
87
|
-
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
|
83
|
+
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
|
88
84
|
input_tensor = create_input_tensor(input_feed)
|
89
85
|
|
90
86
|
output_names ||= @outputs.map { |v| v[:name] }
|
@@ -104,7 +100,14 @@ module OnnxRuntime
|
|
104
100
|
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
|
105
101
|
|
106
102
|
output_names.size.times.map do |i|
|
107
|
-
create_from_onnx_value(output_tensor[i].read_pointer)
|
103
|
+
create_from_onnx_value(output_tensor[i].read_pointer, output_type)
|
104
|
+
end
|
105
|
+
ensure
|
106
|
+
release :RunOptions, run_options
|
107
|
+
if input_tensor
|
108
|
+
input_feed.size.times do |i|
|
109
|
+
release :Value, input_tensor[i]
|
110
|
+
end
|
108
111
|
end
|
109
112
|
end
|
110
113
|
|
@@ -121,7 +124,7 @@ module OnnxRuntime
|
|
121
124
|
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
122
125
|
|
123
126
|
custom_metadata_map = {}
|
124
|
-
check_status
|
127
|
+
check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
125
128
|
num_keys.read(:int64_t).times do |i|
|
126
129
|
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
127
130
|
value = ::FFI::MemoryPointer.new(:string)
|
@@ -134,7 +137,6 @@ module OnnxRuntime
|
|
134
137
|
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
|
135
138
|
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
|
136
139
|
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
|
137
|
-
api[:ReleaseModelMetadata].call(metadata.read_pointer)
|
138
140
|
|
139
141
|
{
|
140
142
|
custom_metadata_map: custom_metadata_map,
|
@@ -144,58 +146,90 @@ module OnnxRuntime
|
|
144
146
|
producer_name: producer_name.read_pointer.read_string,
|
145
147
|
version: version.read(:int64_t)
|
146
148
|
}
|
149
|
+
ensure
|
150
|
+
release :ModelMetadata, metadata
|
147
151
|
end
|
148
152
|
|
153
|
+
# return value has double underscore like Python
|
149
154
|
def end_profiling
|
150
155
|
out = ::FFI::MemoryPointer.new(:string)
|
151
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
152
157
|
out.read_pointer.read_string
|
153
158
|
end
|
154
159
|
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
173
|
+
end
|
174
|
+
|
155
175
|
private
|
156
176
|
|
157
177
|
def create_input_tensor(input_feed)
|
158
178
|
allocator_info = ::FFI::MemoryPointer.new(:pointer)
|
159
|
-
check_status
|
179
|
+
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
|
160
180
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
161
181
|
|
162
182
|
input_feed.each_with_index do |(input_name, input), idx|
|
163
|
-
|
183
|
+
if numo_array?(input)
|
184
|
+
shape = input.shape
|
185
|
+
else
|
186
|
+
input = input.to_a unless input.is_a?(Array)
|
164
187
|
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
188
|
+
shape = []
|
189
|
+
s = input
|
190
|
+
while s.is_a?(Array)
|
191
|
+
shape << s.size
|
192
|
+
s = s.first
|
193
|
+
end
|
170
194
|
end
|
171
195
|
|
172
|
-
flat_input = input.flatten
|
173
|
-
input_tensor_size = flat_input.size
|
174
|
-
|
175
196
|
# TODO support more types
|
176
197
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
177
|
-
raise "Unknown input: #{input_name}" unless inp
|
198
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
178
199
|
|
179
200
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
180
201
|
input_node_dims.write_array_of_int64(shape)
|
181
202
|
|
182
203
|
if inp[:type] == "tensor(string)"
|
183
|
-
|
184
|
-
|
204
|
+
if numo_array?(input)
|
205
|
+
input_tensor_size = input.size
|
206
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
|
207
|
+
input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
|
208
|
+
else
|
209
|
+
flat_input = input.flatten.to_a
|
210
|
+
input_tensor_size = flat_input.size
|
211
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
|
212
|
+
input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
|
213
|
+
end
|
185
214
|
type_enum = FFI::TensorElementDataType[:string]
|
186
215
|
check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
|
187
|
-
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values,
|
216
|
+
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
|
188
217
|
else
|
189
|
-
tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
190
218
|
tensor_type = tensor_types[inp[:type]]
|
191
219
|
|
192
220
|
if tensor_type
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
flat_input =
|
221
|
+
if numo_array?(input)
|
222
|
+
input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
|
223
|
+
else
|
224
|
+
flat_input = input.flatten.to_a
|
225
|
+
input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
|
226
|
+
if tensor_type == :bool
|
227
|
+
tensor_type = :uchar
|
228
|
+
flat_input = flat_input.map { |v| v ? 1 : 0 }
|
229
|
+
end
|
230
|
+
input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
|
197
231
|
end
|
198
|
-
|
232
|
+
|
199
233
|
type_enum = FFI::TensorElementDataType[tensor_type]
|
200
234
|
else
|
201
235
|
unsupported_type("input", inp[:type])
|
@@ -214,9 +248,9 @@ module OnnxRuntime
|
|
214
248
|
ptr
|
215
249
|
end
|
216
250
|
|
217
|
-
def create_from_onnx_value(out_ptr)
|
251
|
+
def create_from_onnx_value(out_ptr, output_type)
|
218
252
|
out_type = ::FFI::MemoryPointer.new(:int)
|
219
|
-
check_status
|
253
|
+
check_status api[:GetValueType].call(out_ptr, out_type)
|
220
254
|
type = FFI::OnnxType[out_type.read_int]
|
221
255
|
|
222
256
|
case type
|
@@ -231,29 +265,50 @@ module OnnxRuntime
|
|
231
265
|
|
232
266
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
233
267
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
234
|
-
output_tensor_size =
|
268
|
+
output_tensor_size = out_size.read(:size_t)
|
269
|
+
|
270
|
+
release :TensorTypeAndShapeInfo, typeinfo
|
235
271
|
|
236
272
|
# TODO support more types
|
237
273
|
type = FFI::TensorElementDataType[type]
|
238
|
-
|
274
|
+
|
275
|
+
case output_type
|
276
|
+
when :numo
|
239
277
|
case type
|
240
|
-
when :
|
241
|
-
|
242
|
-
|
243
|
-
|
278
|
+
when :string
|
279
|
+
result = Numo::RObject.new(shape)
|
280
|
+
result.allocate
|
281
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
244
282
|
else
|
245
|
-
|
283
|
+
numo_type = numo_types[type]
|
284
|
+
unsupported_type("element", type) unless numo_type
|
285
|
+
numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
|
246
286
|
end
|
287
|
+
when :ruby
|
288
|
+
arr =
|
289
|
+
case type
|
290
|
+
when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
|
291
|
+
tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
|
292
|
+
when :bool
|
293
|
+
tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
|
294
|
+
when :string
|
295
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
|
296
|
+
else
|
297
|
+
unsupported_type("element", type)
|
298
|
+
end
|
247
299
|
|
248
|
-
|
300
|
+
Utils.reshape(arr, shape)
|
301
|
+
else
|
302
|
+
raise ArgumentError, "Invalid output type: #{output_type}"
|
303
|
+
end
|
249
304
|
when :sequence
|
250
305
|
out = ::FFI::MemoryPointer.new(:size_t)
|
251
306
|
check_status api[:GetValueCount].call(out_ptr, out)
|
252
307
|
|
253
|
-
|
308
|
+
out.read(:size_t).times.map do |i|
|
254
309
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
255
310
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
256
|
-
create_from_onnx_value(seq.read_pointer)
|
311
|
+
create_from_onnx_value(seq.read_pointer, output_type)
|
257
312
|
end
|
258
313
|
when :map
|
259
314
|
type_shape = ::FFI::MemoryPointer.new(:pointer)
|
@@ -265,14 +320,15 @@ module OnnxRuntime
|
|
265
320
|
check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
|
266
321
|
check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
|
267
322
|
check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
|
323
|
+
release :TensorTypeAndShapeInfo, type_shape
|
268
324
|
|
269
325
|
# TODO support more types
|
270
326
|
elem_type = FFI::TensorElementDataType[elem_type.read_int]
|
271
327
|
case elem_type
|
272
328
|
when :int64
|
273
329
|
ret = {}
|
274
|
-
keys = create_from_onnx_value(map_keys.read_pointer)
|
275
|
-
values = create_from_onnx_value(map_values.read_pointer)
|
330
|
+
keys = create_from_onnx_value(map_keys.read_pointer, output_type)
|
331
|
+
values = create_from_onnx_value(map_values.read_pointer, output_type)
|
276
332
|
keys.zip(values).each do |k, v|
|
277
333
|
ret[k] = v
|
278
334
|
end
|
@@ -285,15 +341,32 @@ module OnnxRuntime
|
|
285
341
|
end
|
286
342
|
end
|
287
343
|
|
344
|
+
def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
345
|
+
len = ::FFI::MemoryPointer.new(:size_t)
|
346
|
+
check_status api[:GetStringTensorDataLength].call(out_ptr, len)
|
347
|
+
|
348
|
+
s_len = len.read(:size_t)
|
349
|
+
s = ::FFI::MemoryPointer.new(:uchar, s_len)
|
350
|
+
offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
|
351
|
+
check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
|
352
|
+
|
353
|
+
offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
|
354
|
+
offsets << s_len
|
355
|
+
output_tensor_size.times do |i|
|
356
|
+
result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
|
357
|
+
end
|
358
|
+
result
|
359
|
+
end
|
360
|
+
|
288
361
|
def read_pointer
|
289
362
|
@session.read_pointer
|
290
363
|
end
|
291
364
|
|
292
365
|
def check_status(status)
|
293
366
|
unless status.null?
|
294
|
-
message = api[:GetErrorMessage].call(status)
|
367
|
+
message = api[:GetErrorMessage].call(status).read_string
|
295
368
|
api[:ReleaseStatus].call(status)
|
296
|
-
raise
|
369
|
+
raise Error, message
|
297
370
|
end
|
298
371
|
end
|
299
372
|
|
@@ -305,6 +378,7 @@ module OnnxRuntime
|
|
305
378
|
case type
|
306
379
|
when :tensor
|
307
380
|
tensor_info = ::FFI::MemoryPointer.new(:pointer)
|
381
|
+
# don't free tensor_info
|
308
382
|
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
|
309
383
|
|
310
384
|
type, shape = tensor_type_and_shape(tensor_info)
|
@@ -345,7 +419,7 @@ module OnnxRuntime
|
|
345
419
|
unsupported_type("ONNX", type)
|
346
420
|
end
|
347
421
|
ensure
|
348
|
-
|
422
|
+
release :TypeInfo, typeinfo
|
349
423
|
end
|
350
424
|
|
351
425
|
def tensor_type_and_shape(tensor_info)
|
@@ -354,7 +428,7 @@ module OnnxRuntime
|
|
354
428
|
|
355
429
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
356
430
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
357
|
-
num_dims =
|
431
|
+
num_dims = num_dims_ptr.read(:size_t)
|
358
432
|
|
359
433
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
360
434
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -363,20 +437,67 @@ module OnnxRuntime
|
|
363
437
|
end
|
364
438
|
|
365
439
|
def unsupported_type(name, type)
|
366
|
-
raise "Unsupported #{name} type: #{type}"
|
440
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
367
441
|
end
|
368
442
|
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
443
|
+
def tensor_types
|
444
|
+
@tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
445
|
+
end
|
446
|
+
|
447
|
+
def numo_array?(obj)
|
448
|
+
defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
|
449
|
+
end
|
450
|
+
|
451
|
+
def numo_types
|
452
|
+
@numo_types ||= {
|
453
|
+
float: Numo::SFloat,
|
454
|
+
uint8: Numo::UInt8,
|
455
|
+
int8: Numo::Int8,
|
456
|
+
uint16: Numo::UInt16,
|
457
|
+
int16: Numo::Int16,
|
458
|
+
int32: Numo::Int32,
|
459
|
+
int64: Numo::Int64,
|
460
|
+
bool: Numo::UInt8,
|
461
|
+
double: Numo::DFloat,
|
462
|
+
uint32: Numo::UInt32,
|
463
|
+
uint64: Numo::UInt64
|
464
|
+
}
|
376
465
|
end
|
377
466
|
|
378
467
|
def api
|
379
|
-
|
468
|
+
self.class.api
|
469
|
+
end
|
470
|
+
|
471
|
+
def release(*args)
|
472
|
+
self.class.release(*args)
|
473
|
+
end
|
474
|
+
|
475
|
+
def self.api
|
476
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
477
|
+
end
|
478
|
+
|
479
|
+
def self.release(type, pointer)
|
480
|
+
api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
|
481
|
+
end
|
482
|
+
|
483
|
+
def self.finalize(session)
|
484
|
+
# must use proc instead of stabby lambda
|
485
|
+
proc { release :Session, session }
|
486
|
+
end
|
487
|
+
|
488
|
+
# wide string on Windows
|
489
|
+
# char string on Linux
|
490
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
491
|
+
def ort_string(str)
|
492
|
+
if Gem.win_platform?
|
493
|
+
max = str.size + 1 # for null byte
|
494
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
495
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
496
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
497
|
+
dest
|
498
|
+
else
|
499
|
+
str
|
500
|
+
end
|
380
501
|
end
|
381
502
|
|
382
503
|
def env
|
@@ -385,7 +506,7 @@ module OnnxRuntime
|
|
385
506
|
@@env ||= begin
|
386
507
|
env = ::FFI::MemoryPointer.new(:pointer)
|
387
508
|
check_status api[:CreateEnv].call(3, "Default", env)
|
388
|
-
at_exit {
|
509
|
+
at_exit { release :Env, env }
|
389
510
|
# disable telemetry
|
390
511
|
# https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
|
391
512
|
check_status api[:DisableTelemetryEvents].call(env)
|