onnxruntime 0.3.1 → 0.5.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +18 -19
- data/README.md +10 -3
- data/lib/onnxruntime.rb +1 -1
- data/lib/onnxruntime/ffi.rb +24 -11
- data/lib/onnxruntime/inference_session.rb +194 -73
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/ThirdPartyNotices.txt +1029 -3
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +7 -7
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 316527780be2781a474d0813aff47d840654423823ad83b8b52b13752caf6814
|
4
|
+
data.tar.gz: 5954ba2dc4223b8330fb52e8474be78542360c2f4c1cb5946529d062c0b1b864
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c127874dd75a10b8cb9d9d033e607f9d73069bb5709d7b9ecee04c1d59f969f61db6ec88fd569188b0a35c8276f7611619f81662f3735fbd53e61938f15c6822
|
7
|
+
data.tar.gz: 77a7c9f8c98b25fd82ee8818c63176d2005f846db489b5973330220344be8fd47d8e5279517a716f8c1222f729a248fc8ac80cb610633dbdc35c49418a42ea1c
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.5.1 (2020-11-01)
|
2
|
+
|
3
|
+
- Updated ONNX Runtime to 1.5.2
|
4
|
+
- Added support for string output
|
5
|
+
- Added `output_type` option
|
6
|
+
- Improved performance for Numo array inputs
|
7
|
+
|
8
|
+
## 0.5.0 (2020-10-01)
|
9
|
+
|
10
|
+
- Updated ONNX Runtime to 1.5.1
|
11
|
+
- OpenMP is now required on Mac
|
12
|
+
- Fixed `mul_1.onnx` example
|
13
|
+
|
14
|
+
## 0.4.0 (2020-07-20)
|
15
|
+
|
16
|
+
- Updated ONNX Runtime to 1.4.0
|
17
|
+
- Added `providers` method
|
18
|
+
- Fixed errors on Windows
|
19
|
+
|
20
|
+
## 0.3.3 (2020-06-17)
|
21
|
+
|
22
|
+
- Fixed segmentation fault on exit on Linux
|
23
|
+
|
24
|
+
## 0.3.2 (2020-06-16)
|
25
|
+
|
26
|
+
- Fixed error with FFI 1.13.0+
|
27
|
+
- Added friendly graph optimization levels
|
28
|
+
|
1
29
|
## 0.3.1 (2020-05-18)
|
2
30
|
|
3
31
|
- Updated ONNX Runtime to 1.3.0
|
data/LICENSE.txt
CHANGED
@@ -1,23 +1,22 @@
|
|
1
|
-
Copyright (c) 2019-2020 Andrew Kane
|
2
|
-
Datasets Copyright (c) Microsoft Corporation
|
3
|
-
|
4
1
|
MIT License
|
5
2
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
Copyright (c) 2018 Microsoft Corporation
|
4
|
+
Copyright (c) 2019-2020 Andrew Kane
|
5
|
+
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
8
|
+
in the Software without restriction, including without limitation the rights
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
11
|
+
furnished to do so, subject to the following conditions:
|
13
12
|
|
14
|
-
The above copyright notice and this permission notice shall be
|
15
|
-
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
14
|
+
copies or substantial portions of the Software.
|
16
15
|
|
17
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
OF
|
23
|
-
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
|
|
14
14
|
gem 'onnxruntime'
|
15
15
|
```
|
16
16
|
|
17
|
+
On Mac, also install OpenMP:
|
18
|
+
|
19
|
+
```sh
|
20
|
+
brew install libomp
|
21
|
+
```
|
22
|
+
|
17
23
|
## Getting Started
|
18
24
|
|
19
25
|
Load a model and make predictions
|
@@ -63,8 +69,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
|
|
63
69
|
enable_cpu_mem_arena: true,
|
64
70
|
enable_mem_pattern: true,
|
65
71
|
enable_profiling: false,
|
66
|
-
execution_mode: :sequential,
|
67
|
-
graph_optimization_level: nil,
|
72
|
+
execution_mode: :sequential, # :sequential or :parallel
|
73
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
68
74
|
inter_op_num_threads: nil,
|
69
75
|
intra_op_num_threads: nil,
|
70
76
|
log_severity_level: 2,
|
@@ -81,7 +87,8 @@ model.predict(input_feed, {
|
|
81
87
|
log_severity_level: 2,
|
82
88
|
log_verbosity_level: 0,
|
83
89
|
logid: nil,
|
84
|
-
terminate: false
|
90
|
+
terminate: false,
|
91
|
+
output_type: :ruby # :ruby or :numo
|
85
92
|
})
|
86
93
|
```
|
87
94
|
|
data/lib/onnxruntime.rb
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,10 +3,13 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib
|
6
|
+
ffi_lib OnnxRuntime.ffi_lib
|
7
7
|
rescue LoadError => e
|
8
|
-
|
9
|
-
|
8
|
+
if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
|
9
|
+
raise LoadError, "OpenMP not found. Run `brew install libomp`"
|
10
|
+
else
|
11
|
+
raise e
|
12
|
+
end
|
10
13
|
end
|
11
14
|
|
12
15
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
@@ -20,19 +23,19 @@ module OnnxRuntime
|
|
20
23
|
layout \
|
21
24
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
25
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :
|
26
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
27
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
28
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
29
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
30
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
31
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
32
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
33
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
34
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
35
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
36
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
37
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
38
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
39
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
40
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
41
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -71,8 +74,8 @@ module OnnxRuntime
|
|
71
74
|
:IsTensor, callback(%i[], :pointer),
|
72
75
|
:GetTensorMutableData, callback(%i[pointer pointer], :pointer),
|
73
76
|
:FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
|
74
|
-
:GetStringTensorDataLength, callback(%i[], :pointer),
|
75
|
-
:GetStringTensorContent, callback(%i[], :pointer),
|
77
|
+
:GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
|
78
|
+
:GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
|
76
79
|
:CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
|
77
80
|
:GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
|
78
81
|
:CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
|
@@ -142,7 +145,9 @@ module OnnxRuntime
|
|
142
145
|
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
146
|
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
147
|
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
-
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
148
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
149
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
150
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
146
151
|
end
|
147
152
|
|
148
153
|
class ApiBase < ::FFI::Struct
|
@@ -150,9 +155,17 @@ module OnnxRuntime
|
|
150
155
|
# to prevent "unable to resolve type" error on Ubuntu
|
151
156
|
layout \
|
152
157
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
153
|
-
:GetVersionString, callback(%i[], :
|
158
|
+
:GetVersionString, callback(%i[], :pointer)
|
154
159
|
end
|
155
160
|
|
156
161
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
162
|
+
|
163
|
+
# Windows only: bind mbstowcs from the C library so ort_string can build
# wide-character strings for the runtime.
if Gem.win_platform?
  class Libc
    extend ::FFI::Library
    ffi_lib ::FFI::Library::LIBC
    # size_t mbstowcs(dest, src, max)
    attach_function :mbstowcs, %i[pointer string size_t], :size_t
  end
end
|
157
170
|
end
|
158
171
|
end
|
@@ -8,26 +8,25 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
0
|
17
|
-
when :parallel
|
18
|
-
1
|
19
|
-
else
|
20
|
-
raise ArgumentError, "Invalid execution mode"
|
21
|
-
end
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
22
16
|
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
23
17
|
end
|
24
|
-
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
25
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
26
25
|
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
27
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
28
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
29
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
30
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
31
30
|
|
32
31
|
# session
|
33
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -40,17 +39,12 @@ module OnnxRuntime
|
|
40
39
|
path_or_bytes.encoding == Encoding::BINARY
|
41
40
|
end
|
42
41
|
|
43
|
-
# fix for Windows "File doesn't exist"
|
44
|
-
if Gem.win_platform? && !from_memory
|
45
|
-
path_or_bytes = File.binread(path_or_bytes)
|
46
|
-
from_memory = true
|
47
|
-
end
|
48
|
-
|
49
42
|
if from_memory
|
50
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
51
44
|
else
|
52
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
53
46
|
end
|
47
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
55
49
|
# input info
|
56
50
|
allocator = ::FFI::MemoryPointer.new(:pointer)
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,17 +68,19 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
81
75
|
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
|
82
76
|
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
|
83
77
|
end
|
78
|
+
ensure
|
79
|
+
# release :SessionOptions, session_options
|
84
80
|
end
|
85
81
|
|
86
82
|
# TODO support logid
|
87
|
-
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
|
83
|
+
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
|
88
84
|
input_tensor = create_input_tensor(input_feed)
|
89
85
|
|
90
86
|
output_names ||= @outputs.map { |v| v[:name] }
|
@@ -104,7 +100,14 @@ module OnnxRuntime
|
|
104
100
|
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
|
105
101
|
|
106
102
|
output_names.size.times.map do |i|
|
107
|
-
create_from_onnx_value(output_tensor[i].read_pointer)
|
103
|
+
create_from_onnx_value(output_tensor[i].read_pointer, output_type)
|
104
|
+
end
|
105
|
+
ensure
|
106
|
+
release :RunOptions, run_options
|
107
|
+
if input_tensor
|
108
|
+
input_feed.size.times do |i|
|
109
|
+
release :Value, input_tensor[i]
|
110
|
+
end
|
108
111
|
end
|
109
112
|
end
|
110
113
|
|
@@ -121,7 +124,7 @@ module OnnxRuntime
|
|
121
124
|
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
122
125
|
|
123
126
|
custom_metadata_map = {}
|
124
|
-
check_status
|
127
|
+
check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
125
128
|
num_keys.read(:int64_t).times do |i|
|
126
129
|
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
127
130
|
value = ::FFI::MemoryPointer.new(:string)
|
@@ -134,7 +137,6 @@ module OnnxRuntime
|
|
134
137
|
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
|
135
138
|
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
|
136
139
|
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
|
137
|
-
api[:ReleaseModelMetadata].call(metadata.read_pointer)
|
138
140
|
|
139
141
|
{
|
140
142
|
custom_metadata_map: custom_metadata_map,
|
@@ -144,58 +146,90 @@ module OnnxRuntime
|
|
144
146
|
producer_name: producer_name.read_pointer.read_string,
|
145
147
|
version: version.read(:int64_t)
|
146
148
|
}
|
149
|
+
ensure
|
150
|
+
release :ModelMetadata, metadata
|
147
151
|
end
|
148
152
|
|
153
|
+
# return value has double underscore like Python
#
# Calls SessionEndProfiling for this session and returns the string the
# runtime wrote to the out pointer.
# Raises via check_status when the call reports a failure status.
def end_profiling
  out = ::FFI::MemoryPointer.new(:string)
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
  out.read_pointer.read_string
end
|
154
159
|
|
160
|
+
# no way to set providers with C API yet
# so we can return all available providers
#
# Returns an Array of provider name strings obtained from
# GetAvailableProviders. Raises via check_status if either API call
# reports a failure status.
def providers
  out_ptr = ::FFI::MemoryPointer.new(:pointer)
  length_ptr = ::FFI::MemoryPointer.new(:int)
  check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)

  length = length_ptr.read_int
  list = out_ptr.read_pointer
  begin
    # step through the list one pointer-sized slot at a time
    length.times.map do |i|
      list[i * ::FFI::Pointer.size].read_pointer.read_string
    end
  ensure
    # give the list back even when reading a name raised, and surface
    # a failed release instead of dropping its status
    check_status api[:ReleaseAvailableProviders].call(list, length)
  end
end
|
174
|
+
|
155
175
|
private
|
156
176
|
|
157
177
|
def create_input_tensor(input_feed)
|
158
178
|
allocator_info = ::FFI::MemoryPointer.new(:pointer)
|
159
|
-
check_status
|
179
|
+
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
|
160
180
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
161
181
|
|
162
182
|
input_feed.each_with_index do |(input_name, input), idx|
|
163
|
-
|
183
|
+
if numo_array?(input)
|
184
|
+
shape = input.shape
|
185
|
+
else
|
186
|
+
input = input.to_a unless input.is_a?(Array)
|
164
187
|
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
188
|
+
shape = []
|
189
|
+
s = input
|
190
|
+
while s.is_a?(Array)
|
191
|
+
shape << s.size
|
192
|
+
s = s.first
|
193
|
+
end
|
170
194
|
end
|
171
195
|
|
172
|
-
flat_input = input.flatten
|
173
|
-
input_tensor_size = flat_input.size
|
174
|
-
|
175
196
|
# TODO support more types
|
176
197
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
177
|
-
raise "Unknown input: #{input_name}" unless inp
|
198
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
178
199
|
|
179
200
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
180
201
|
input_node_dims.write_array_of_int64(shape)
|
181
202
|
|
182
203
|
if inp[:type] == "tensor(string)"
|
183
|
-
|
184
|
-
|
204
|
+
if numo_array?(input)
|
205
|
+
input_tensor_size = input.size
|
206
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
|
207
|
+
input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
|
208
|
+
else
|
209
|
+
flat_input = input.flatten.to_a
|
210
|
+
input_tensor_size = flat_input.size
|
211
|
+
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
|
212
|
+
input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
|
213
|
+
end
|
185
214
|
type_enum = FFI::TensorElementDataType[:string]
|
186
215
|
check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
|
187
|
-
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values,
|
216
|
+
check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
|
188
217
|
else
|
189
|
-
tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
|
190
218
|
tensor_type = tensor_types[inp[:type]]
|
191
219
|
|
192
220
|
if tensor_type
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
flat_input =
|
221
|
+
if numo_array?(input)
|
222
|
+
input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
|
223
|
+
else
|
224
|
+
flat_input = input.flatten.to_a
|
225
|
+
input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
|
226
|
+
if tensor_type == :bool
|
227
|
+
tensor_type = :uchar
|
228
|
+
flat_input = flat_input.map { |v| v ? 1 : 0 }
|
229
|
+
end
|
230
|
+
input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
|
197
231
|
end
|
198
|
-
|
232
|
+
|
199
233
|
type_enum = FFI::TensorElementDataType[tensor_type]
|
200
234
|
else
|
201
235
|
unsupported_type("input", inp[:type])
|
@@ -214,9 +248,9 @@ module OnnxRuntime
|
|
214
248
|
ptr
|
215
249
|
end
|
216
250
|
|
217
|
-
def create_from_onnx_value(out_ptr)
|
251
|
+
def create_from_onnx_value(out_ptr, output_type)
|
218
252
|
out_type = ::FFI::MemoryPointer.new(:int)
|
219
|
-
check_status
|
253
|
+
check_status api[:GetValueType].call(out_ptr, out_type)
|
220
254
|
type = FFI::OnnxType[out_type.read_int]
|
221
255
|
|
222
256
|
case type
|
@@ -231,29 +265,50 @@ module OnnxRuntime
|
|
231
265
|
|
232
266
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
233
267
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
234
|
-
output_tensor_size =
|
268
|
+
output_tensor_size = out_size.read(:size_t)
|
269
|
+
|
270
|
+
release :TensorTypeAndShapeInfo, typeinfo
|
235
271
|
|
236
272
|
# TODO support more types
|
237
273
|
type = FFI::TensorElementDataType[type]
|
238
|
-
|
274
|
+
|
275
|
+
case output_type
|
276
|
+
when :numo
|
239
277
|
case type
|
240
|
-
when :
|
241
|
-
|
242
|
-
|
243
|
-
|
278
|
+
when :string
|
279
|
+
result = Numo::RObject.new(shape)
|
280
|
+
result.allocate
|
281
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
|
244
282
|
else
|
245
|
-
|
283
|
+
numo_type = numo_types[type]
|
284
|
+
unsupported_type("element", type) unless numo_type
|
285
|
+
numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
|
246
286
|
end
|
287
|
+
when :ruby
|
288
|
+
arr =
|
289
|
+
case type
|
290
|
+
when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
|
291
|
+
tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
|
292
|
+
when :bool
|
293
|
+
tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
|
294
|
+
when :string
|
295
|
+
create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
|
296
|
+
else
|
297
|
+
unsupported_type("element", type)
|
298
|
+
end
|
247
299
|
|
248
|
-
|
300
|
+
Utils.reshape(arr, shape)
|
301
|
+
else
|
302
|
+
raise ArgumentError, "Invalid output type: #{output_type}"
|
303
|
+
end
|
249
304
|
when :sequence
|
250
305
|
out = ::FFI::MemoryPointer.new(:size_t)
|
251
306
|
check_status api[:GetValueCount].call(out_ptr, out)
|
252
307
|
|
253
|
-
|
308
|
+
out.read(:size_t).times.map do |i|
|
254
309
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
255
310
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
256
|
-
create_from_onnx_value(seq.read_pointer)
|
311
|
+
create_from_onnx_value(seq.read_pointer, output_type)
|
257
312
|
end
|
258
313
|
when :map
|
259
314
|
type_shape = ::FFI::MemoryPointer.new(:pointer)
|
@@ -265,14 +320,15 @@ module OnnxRuntime
|
|
265
320
|
check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
|
266
321
|
check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
|
267
322
|
check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
|
323
|
+
release :TensorTypeAndShapeInfo, type_shape
|
268
324
|
|
269
325
|
# TODO support more types
|
270
326
|
elem_type = FFI::TensorElementDataType[elem_type.read_int]
|
271
327
|
case elem_type
|
272
328
|
when :int64
|
273
329
|
ret = {}
|
274
|
-
keys = create_from_onnx_value(map_keys.read_pointer)
|
275
|
-
values = create_from_onnx_value(map_values.read_pointer)
|
330
|
+
keys = create_from_onnx_value(map_keys.read_pointer, output_type)
|
331
|
+
values = create_from_onnx_value(map_values.read_pointer, output_type)
|
276
332
|
keys.zip(values).each do |k, v|
|
277
333
|
ret[k] = v
|
278
334
|
end
|
@@ -285,15 +341,32 @@ module OnnxRuntime
|
|
285
341
|
end
|
286
342
|
end
|
287
343
|
|
344
|
+
# Reads every string of a string tensor into +result+.
#
# out_ptr            - pointer to the value handed to the string tensor API calls
# output_tensor_size - number of strings in the tensor
# result             - indexable container filled via result[i] = string
#
# Returns +result+. Raises via check_status when an API call fails.
def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
  len_ptr = ::FFI::MemoryPointer.new(:size_t)
  check_status api[:GetStringTensorDataLength].call(out_ptr, len_ptr)
  total_bytes = len_ptr.read(:size_t)

  # one contiguous byte buffer plus the start offset of each string in it
  buffer = ::FFI::MemoryPointer.new(:uchar, total_bytes)
  offsets_ptr = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
  check_status api[:GetStringTensorContent].call(out_ptr, buffer, total_bytes, offsets_ptr, output_tensor_size)

  # append the total length so each string is bounds[i]...bounds[i + 1]
  bounds = Array.new(output_tensor_size) { |i| offsets_ptr[i].read(:size_t) }
  bounds << total_bytes
  bounds.each_cons(2).with_index do |(start, stop), i|
    result[i] = buffer.get_bytes(start, stop - start)
  end
  result
end
|
360
|
+
|
288
361
|
# Returns the pointer stored in @session (what the session API calls take).
def read_pointer
  @session.read_pointer
end
|
291
364
|
|
292
365
|
# Turns a failure status returned by an API call into a Ruby exception.
#
# status - pointer returned by the API call; a null pointer means success
#
# Returns nil on success. On failure, reads the error message, releases the
# status, then raises Error with that message.
def check_status(status)
  return if status.null?

  # read the message before releasing the status that owns it
  message = api[:GetErrorMessage].call(status).read_string
  api[:ReleaseStatus].call(status)
  raise Error, message
end
|
299
372
|
|
@@ -305,6 +378,7 @@ module OnnxRuntime
|
|
305
378
|
case type
|
306
379
|
when :tensor
|
307
380
|
tensor_info = ::FFI::MemoryPointer.new(:pointer)
|
381
|
+
# don't free tensor_info
|
308
382
|
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
|
309
383
|
|
310
384
|
type, shape = tensor_type_and_shape(tensor_info)
|
@@ -345,7 +419,7 @@ module OnnxRuntime
|
|
345
419
|
unsupported_type("ONNX", type)
|
346
420
|
end
|
347
421
|
ensure
|
348
|
-
|
422
|
+
release :TypeInfo, typeinfo
|
349
423
|
end
|
350
424
|
|
351
425
|
def tensor_type_and_shape(tensor_info)
|
@@ -354,7 +428,7 @@ module OnnxRuntime
|
|
354
428
|
|
355
429
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
356
430
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
357
|
-
num_dims =
|
431
|
+
num_dims = num_dims_ptr.read(:size_t)
|
358
432
|
|
359
433
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
360
434
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -363,20 +437,67 @@ module OnnxRuntime
|
|
363
437
|
end
|
364
438
|
|
365
439
|
# Raises Error reporting that +type+ is not supported for the given +name+
# (e.g. "input", "element", "ONNX"). Never returns.
def unsupported_type(name, type)
  raise Error, "Unsupported #{name} type: #{type}"
end
|
368
442
|
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
443
|
+
# Maps type strings such as "tensor(float)" to their element type symbol
# (:float). Built once and memoized in @tensor_types.
def tensor_types
  @tensor_types ||= begin
    element_types = %i[float uint8 int8 uint16 int16 int32 int64 bool double uint32 uint64]
    element_types.map { |v| ["tensor(#{v})", v] }.to_h
  end
end
|
446
|
+
|
447
|
+
# True when Numo::NArray is loaded and +obj+ is one; false otherwise.
# The defined? check keeps this safe when Numo has not been required.
def numo_array?(obj)
  return false unless defined?(Numo::NArray)

  obj.is_a?(Numo::NArray)
end
|
450
|
+
|
451
|
+
# Maps tensor element type symbols to Numo classes. Memoized, so the Numo
# constants are only looked up the first time this is called.
# Note that :bool maps to Numo::UInt8.
def numo_types
  @numo_types ||= {
    float: Numo::SFloat,
    uint8: Numo::UInt8,
    int8: Numo::Int8,
    uint16: Numo::UInt16,
    int16: Numo::Int16,
    int32: Numo::Int32,
    int64: Numo::Int64,
    bool: Numo::UInt8,
    double: Numo::DFloat,
    uint32: Numo::UInt32,
    uint64: Numo::UInt64
  }
end
|
377
466
|
|
378
467
|
# Instance-level shortcut to the class-level api accessor.
def api
  self.class.api
end
|
470
|
+
|
471
|
+
# Instance-level shortcut that forwards all arguments to the class-level
# release method and returns its result.
def release(*args)
  self.class.release(*args)
end
|
474
|
+
|
475
|
+
# Returns the API struct obtained from OrtGetApiBase's GetApi, memoized in
# @api so GetApi is only called once.
def self.api
  # 4 is the argument passed to GetApi -- presumably the requested API
  # version; confirm against the bundled runtime
  @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
end
|
478
|
+
|
479
|
+
# Calls the "Release<type>" API function on the handle stored in +pointer+.
#
# type    - suffix of the release function, e.g. :Session, :Value, :Env
# pointer - object holding the handle (read with read_pointer); may be nil
#
# Does nothing (returns nil) when +pointer+ is nil or null.
def self.release(type, pointer)
  return unless pointer && !pointer.null?

  api[:"Release#{type}"].call(pointer.read_pointer)
end
|
482
|
+
|
483
|
+
# Builds the finalizer registered with ObjectSpace.define_finalizer: a proc
# that releases the given session pointer.
def self.finalize(session)
  # must use proc instead of stabby lambda
  # (a proc tolerates the argument the finalizer is invoked with; a
  # zero-arity lambda would raise ArgumentError)
  proc { release(:Session, session) }
end
|
487
|
+
|
488
|
+
# wide string on Windows
# char string on Linux
# see ORTCHAR_T in onnxruntime_c_api.h
#
# Returns +str+ unchanged off Windows. On Windows, returns a MemoryPointer
# holding the string converted with mbstowcs, and raises Error when the
# converted length does not equal str.size.
def ort_string(str)
  return str unless Gem.win_platform?

  length = str.size
  max = length + 1 # room for the terminator
  dest = ::FFI::MemoryPointer.new(:wchar_t, max)
  ret = FFI::Libc.mbstowcs(dest, str, max)
  raise Error, "Expected mbstowcs to return #{length}, got #{ret}" if ret != length
  dest
end
|
381
502
|
|
382
503
|
def env
|
@@ -385,7 +506,7 @@ module OnnxRuntime
|
|
385
506
|
@@env ||= begin
|
386
507
|
env = ::FFI::MemoryPointer.new(:pointer)
|
387
508
|
check_status api[:CreateEnv].call(3, "Default", env)
|
388
|
-
at_exit {
|
509
|
+
at_exit { release :Env, env }
|
389
510
|
# disable telemetry
|
390
511
|
# https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
|
391
512
|
check_status api[:DisableTelemetryEvents].call(env)
|