onnxruntime 0.3.0 → 0.5.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/LICENSE.txt +18 -19
- data/README.md +8 -2
- data/lib/onnxruntime.rb +1 -1
- data/lib/onnxruntime/ffi.rb +28 -9
- data/lib/onnxruntime/inference_session.rb +102 -45
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/ThirdPartyNotices.txt +1029 -3
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dc2115e3b32ec7d1d0ab2da64edfb5bf2ebf24ecc6ad705d4a53ae4793e60412
|
4
|
+
data.tar.gz: 5e53efd5de12ed5ac06a7e4472c8ee043ffbbba9c82b5f49e340f6a70a4dfcdf
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 93f83646d23298213b971ea7462ce4e5cc970c49d7572b646a6f8003b05689a5a232a3acc05b9cacfbabf65534c046d687eecc80fbf5583b2691391d30b54872
|
7
|
+
data.tar.gz: fc0349312d70944a2169302e8b9f1f70f255e29b60c523f0ba1a688ac162cc230636f9ef333d6ac032d855d9387d88234181a69edade6407bde027050972a2da
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,29 @@
|
|
1
|
+
## 0.5.0 (2020-10-01)
|
2
|
+
|
3
|
+
- Updated ONNX Runtime to 1.5.1
|
4
|
+
- OpenMP is now required on Mac
|
5
|
+
- Fixed `mul_1.onnx` example
|
6
|
+
|
7
|
+
## 0.4.0 (2020-07-20)
|
8
|
+
|
9
|
+
- Updated ONNX Runtime to 1.4.0
|
10
|
+
- Added `providers` method
|
11
|
+
- Fixed errors on Windows
|
12
|
+
|
13
|
+
## 0.3.3 (2020-06-17)
|
14
|
+
|
15
|
+
- Fixed segmentation fault on exit on Linux
|
16
|
+
|
17
|
+
## 0.3.2 (2020-06-16)
|
18
|
+
|
19
|
+
- Fixed error with FFI 1.13.0+
|
20
|
+
- Added friendly graph optimization levels
|
21
|
+
|
22
|
+
## 0.3.1 (2020-05-18)
|
23
|
+
|
24
|
+
- Updated ONNX Runtime to 1.3.0
|
25
|
+
- Added `custom_metadata_map` to model metadata
|
26
|
+
|
1
27
|
## 0.3.0 (2020-03-11)
|
2
28
|
|
3
29
|
- Updated ONNX Runtime to 1.2.0
|
data/LICENSE.txt
CHANGED
@@ -1,23 +1,22 @@
|
|
1
|
-
Copyright (c) 2019-2020 Andrew Kane
|
2
|
-
Datasets Copyright (c) Microsoft Corporation
|
3
|
-
|
4
1
|
MIT License
|
5
2
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
Copyright (c) 2018 Microsoft Corporation
|
4
|
+
Copyright (c) 2019-2020 Andrew Kane
|
5
|
+
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
8
|
+
in the Software without restriction, including without limitation the rights
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
11
|
+
furnished to do so, subject to the following conditions:
|
13
12
|
|
14
|
-
The above copyright notice and this permission notice shall be
|
15
|
-
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
14
|
+
copies or substantial portions of the Software.
|
16
15
|
|
17
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
OF
|
23
|
-
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
|
|
14
14
|
gem 'onnxruntime'
|
15
15
|
```
|
16
16
|
|
17
|
+
On Mac, also install OpenMP:
|
18
|
+
|
19
|
+
```sh
|
20
|
+
brew install libomp
|
21
|
+
```
|
22
|
+
|
17
23
|
## Getting Started
|
18
24
|
|
19
25
|
Load a model and make predictions
|
@@ -63,8 +69,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
|
|
63
69
|
enable_cpu_mem_arena: true,
|
64
70
|
enable_mem_pattern: true,
|
65
71
|
enable_profiling: false,
|
66
|
-
execution_mode: :sequential,
|
67
|
-
graph_optimization_level: nil,
|
72
|
+
execution_mode: :sequential, # :sequential or :parallel
|
73
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
68
74
|
inter_op_num_threads: nil,
|
69
75
|
intra_op_num_threads: nil,
|
70
76
|
log_severity_level: 2,
|
data/lib/onnxruntime.rb
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,10 +3,13 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib
|
6
|
+
ffi_lib OnnxRuntime.ffi_lib
|
7
7
|
rescue LoadError => e
|
8
|
-
|
9
|
-
|
8
|
+
if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
|
9
|
+
raise LoadError, "OpenMP not found. Run `brew install libomp`"
|
10
|
+
else
|
11
|
+
raise e
|
12
|
+
end
|
10
13
|
end
|
11
14
|
|
12
15
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
@@ -20,19 +23,19 @@ module OnnxRuntime
|
|
20
23
|
layout \
|
21
24
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
25
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :
|
26
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
27
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
28
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
29
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
30
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
31
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
32
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
33
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
34
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
35
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
36
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
37
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
38
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
39
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
40
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
41
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -136,7 +139,15 @@ module OnnxRuntime
|
|
136
139
|
:ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
|
137
140
|
:ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
|
138
141
|
:ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
|
139
|
-
:ReleaseModelMetadata, callback(%i[pointer], :void)
|
142
|
+
:ReleaseModelMetadata, callback(%i[pointer], :void),
|
143
|
+
:CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
|
144
|
+
:DisablePerSessionThreads, callback(%i[], :pointer),
|
145
|
+
:CreateThreadingOptions, callback(%i[], :pointer),
|
146
|
+
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
147
|
+
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
148
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
149
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
150
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
140
151
|
end
|
141
152
|
|
142
153
|
class ApiBase < ::FFI::Struct
|
@@ -144,9 +155,17 @@ module OnnxRuntime
|
|
144
155
|
# to prevent "unable to resolve type" error on Ubuntu
|
145
156
|
layout \
|
146
157
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
147
|
-
:GetVersionString, callback(%i[], :
|
158
|
+
:GetVersionString, callback(%i[], :pointer)
|
148
159
|
end
|
149
160
|
|
150
161
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
162
|
+
|
163
|
+
if Gem.win_platform?
|
164
|
+
class Libc
|
165
|
+
extend ::FFI::Library
|
166
|
+
ffi_lib ::FFI::Library::LIBC
|
167
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
168
|
+
end
|
169
|
+
end
|
151
170
|
end
|
152
171
|
end
|
@@ -8,26 +8,25 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
0
|
17
|
-
when :parallel
|
18
|
-
1
|
19
|
-
else
|
20
|
-
raise ArgumentError, "Invalid execution mode"
|
21
|
-
end
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
22
16
|
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
23
17
|
end
|
24
|
-
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
25
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
26
25
|
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
27
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
28
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
29
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
30
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
31
30
|
|
32
31
|
# session
|
33
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -40,17 +39,12 @@ module OnnxRuntime
|
|
40
39
|
path_or_bytes.encoding == Encoding::BINARY
|
41
40
|
end
|
42
41
|
|
43
|
-
# fix for Windows "File doesn't exist"
|
44
|
-
if Gem.win_platform? && !from_memory
|
45
|
-
path_or_bytes = File.binread(path_or_bytes)
|
46
|
-
from_memory = true
|
47
|
-
end
|
48
|
-
|
49
42
|
if from_memory
|
50
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
51
44
|
else
|
52
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
53
46
|
end
|
47
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
55
49
|
# input info
|
56
50
|
allocator = ::FFI::MemoryPointer.new(:pointer)
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,13 +68,15 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
81
75
|
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
|
82
76
|
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
|
83
77
|
end
|
78
|
+
ensure
|
79
|
+
# release :SessionOptions, session_options
|
84
80
|
end
|
85
81
|
|
86
82
|
# TODO support logid
|
@@ -106,9 +102,18 @@ module OnnxRuntime
|
|
106
102
|
output_names.size.times.map do |i|
|
107
103
|
create_from_onnx_value(output_tensor[i].read_pointer)
|
108
104
|
end
|
105
|
+
ensure
|
106
|
+
release :RunOptions, run_options
|
107
|
+
if input_tensor
|
108
|
+
input_feed.size.times do |i|
|
109
|
+
release :Value, input_tensor[i]
|
110
|
+
end
|
111
|
+
end
|
109
112
|
end
|
110
113
|
|
111
114
|
def modelmeta
|
115
|
+
keys = ::FFI::MemoryPointer.new(:pointer)
|
116
|
+
num_keys = ::FFI::MemoryPointer.new(:int64_t)
|
112
117
|
description = ::FFI::MemoryPointer.new(:string)
|
113
118
|
domain = ::FFI::MemoryPointer.new(:string)
|
114
119
|
graph_name = ::FFI::MemoryPointer.new(:string)
|
@@ -117,36 +122,61 @@ module OnnxRuntime
|
|
117
122
|
|
118
123
|
metadata = ::FFI::MemoryPointer.new(:pointer)
|
119
124
|
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
125
|
+
|
126
|
+
custom_metadata_map = {}
|
127
|
+
check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
128
|
+
num_keys.read(:int64_t).times do |i|
|
129
|
+
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
130
|
+
value = ::FFI::MemoryPointer.new(:string)
|
131
|
+
check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
|
132
|
+
custom_metadata_map[key] = value.read_pointer.read_string
|
133
|
+
end
|
134
|
+
|
120
135
|
check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
|
121
136
|
check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
|
122
137
|
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
|
123
138
|
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
|
124
139
|
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
|
125
|
-
api[:ReleaseModelMetadata].call(metadata.read_pointer)
|
126
|
-
|
127
|
-
# TODO add custom_metadata_map
|
128
|
-
# need a way to get keys
|
129
140
|
|
130
141
|
{
|
142
|
+
custom_metadata_map: custom_metadata_map,
|
131
143
|
description: description.read_pointer.read_string,
|
132
144
|
domain: domain.read_pointer.read_string,
|
133
145
|
graph_name: graph_name.read_pointer.read_string,
|
134
146
|
producer_name: producer_name.read_pointer.read_string,
|
135
147
|
version: version.read(:int64_t)
|
136
148
|
}
|
149
|
+
ensure
|
150
|
+
release :ModelMetadata, metadata
|
137
151
|
end
|
138
152
|
|
153
|
+
# return value has double underscore like Python
|
139
154
|
def end_profiling
|
140
155
|
out = ::FFI::MemoryPointer.new(:string)
|
141
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
142
157
|
out.read_pointer.read_string
|
143
158
|
end
|
144
159
|
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
173
|
+
end
|
174
|
+
|
145
175
|
private
|
146
176
|
|
147
177
|
def create_input_tensor(input_feed)
|
148
178
|
allocator_info = ::FFI::MemoryPointer.new(:pointer)
|
149
|
-
check_status
|
179
|
+
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
|
150
180
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
151
181
|
|
152
182
|
input_feed.each_with_index do |(input_name, input), idx|
|
@@ -164,7 +194,7 @@ module OnnxRuntime
|
|
164
194
|
|
165
195
|
# TODO support more types
|
166
196
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
167
|
-
raise "Unknown input: #{input_name}" unless inp
|
197
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
168
198
|
|
169
199
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
170
200
|
input_node_dims.write_array_of_int64(shape)
|
@@ -206,7 +236,7 @@ module OnnxRuntime
|
|
206
236
|
|
207
237
|
def create_from_onnx_value(out_ptr)
|
208
238
|
out_type = ::FFI::MemoryPointer.new(:int)
|
209
|
-
check_status
|
239
|
+
check_status api[:GetValueType].call(out_ptr, out_type)
|
210
240
|
type = FFI::OnnxType[out_type.read_int]
|
211
241
|
|
212
242
|
case type
|
@@ -221,7 +251,9 @@ module OnnxRuntime
|
|
221
251
|
|
222
252
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
223
253
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
224
|
-
output_tensor_size =
|
254
|
+
output_tensor_size = out_size.read(:size_t)
|
255
|
+
|
256
|
+
release :TensorTypeAndShapeInfo, typeinfo
|
225
257
|
|
226
258
|
# TODO support more types
|
227
259
|
type = FFI::TensorElementDataType[type]
|
@@ -240,7 +272,7 @@ module OnnxRuntime
|
|
240
272
|
out = ::FFI::MemoryPointer.new(:size_t)
|
241
273
|
check_status api[:GetValueCount].call(out_ptr, out)
|
242
274
|
|
243
|
-
|
275
|
+
out.read(:size_t).times.map do |i|
|
244
276
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
245
277
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
246
278
|
create_from_onnx_value(seq.read_pointer)
|
@@ -255,6 +287,7 @@ module OnnxRuntime
|
|
255
287
|
check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
|
256
288
|
check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
|
257
289
|
check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
|
290
|
+
release :TensorTypeAndShapeInfo, type_shape
|
258
291
|
|
259
292
|
# TODO support more types
|
260
293
|
elem_type = FFI::TensorElementDataType[elem_type.read_int]
|
@@ -281,9 +314,9 @@ module OnnxRuntime
|
|
281
314
|
|
282
315
|
def check_status(status)
|
283
316
|
unless status.null?
|
284
|
-
message = api[:GetErrorMessage].call(status)
|
317
|
+
message = api[:GetErrorMessage].call(status).read_string
|
285
318
|
api[:ReleaseStatus].call(status)
|
286
|
-
raise
|
319
|
+
raise Error, message
|
287
320
|
end
|
288
321
|
end
|
289
322
|
|
@@ -295,6 +328,7 @@ module OnnxRuntime
|
|
295
328
|
case type
|
296
329
|
when :tensor
|
297
330
|
tensor_info = ::FFI::MemoryPointer.new(:pointer)
|
331
|
+
# don't free tensor_info
|
298
332
|
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
|
299
333
|
|
300
334
|
type, shape = tensor_type_and_shape(tensor_info)
|
@@ -335,7 +369,7 @@ module OnnxRuntime
|
|
335
369
|
unsupported_type("ONNX", type)
|
336
370
|
end
|
337
371
|
ensure
|
338
|
-
|
372
|
+
release :TypeInfo, typeinfo
|
339
373
|
end
|
340
374
|
|
341
375
|
def tensor_type_and_shape(tensor_info)
|
@@ -344,7 +378,7 @@ module OnnxRuntime
|
|
344
378
|
|
345
379
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
346
380
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
347
|
-
num_dims =
|
381
|
+
num_dims = num_dims_ptr.read(:size_t)
|
348
382
|
|
349
383
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
350
384
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -353,20 +387,43 @@ module OnnxRuntime
|
|
353
387
|
end
|
354
388
|
|
355
389
|
def unsupported_type(name, type)
|
356
|
-
raise "Unsupported #{name} type: #{type}"
|
390
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
357
391
|
end
|
358
392
|
|
359
|
-
|
360
|
-
|
361
|
-
if RUBY_PLATFORM == "java"
|
362
|
-
ptr.read_long
|
363
|
-
else
|
364
|
-
ptr.read(:size_t)
|
365
|
-
end
|
393
|
+
def api
|
394
|
+
self.class.api
|
366
395
|
end
|
367
396
|
|
368
|
-
def
|
369
|
-
|
397
|
+
def release(*args)
|
398
|
+
self.class.release(*args)
|
399
|
+
end
|
400
|
+
|
401
|
+
def self.api
|
402
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
403
|
+
end
|
404
|
+
|
405
|
+
def self.release(type, pointer)
|
406
|
+
api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
|
407
|
+
end
|
408
|
+
|
409
|
+
def self.finalize(session)
|
410
|
+
# must use proc instead of stabby lambda
|
411
|
+
proc { release :Session, session }
|
412
|
+
end
|
413
|
+
|
414
|
+
# wide string on Windows
|
415
|
+
# char string on Linux
|
416
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
417
|
+
def ort_string(str)
|
418
|
+
if Gem.win_platform?
|
419
|
+
max = str.size + 1 # for null byte
|
420
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
421
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
422
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
423
|
+
dest
|
424
|
+
else
|
425
|
+
str
|
426
|
+
end
|
370
427
|
end
|
371
428
|
|
372
429
|
def env
|
@@ -375,7 +432,7 @@ module OnnxRuntime
|
|
375
432
|
@@env ||= begin
|
376
433
|
env = ::FFI::MemoryPointer.new(:pointer)
|
377
434
|
check_status api[:CreateEnv].call(3, "Default", env)
|
378
|
-
at_exit {
|
435
|
+
at_exit { release :Env, env }
|
379
436
|
# disable telemetry
|
380
437
|
# https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
|
381
438
|
check_status api[:DisableTelemetryEvents].call(env)
|