onnxruntime 0.3.0 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/LICENSE.txt +18 -19
- data/README.md +8 -2
- data/lib/onnxruntime.rb +1 -1
- data/lib/onnxruntime/ffi.rb +28 -9
- data/lib/onnxruntime/inference_session.rb +102 -45
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/ThirdPartyNotices.txt +1029 -3
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dc2115e3b32ec7d1d0ab2da64edfb5bf2ebf24ecc6ad705d4a53ae4793e60412
|
4
|
+
data.tar.gz: 5e53efd5de12ed5ac06a7e4472c8ee043ffbbba9c82b5f49e340f6a70a4dfcdf
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 93f83646d23298213b971ea7462ce4e5cc970c49d7572b646a6f8003b05689a5a232a3acc05b9cacfbabf65534c046d687eecc80fbf5583b2691391d30b54872
|
7
|
+
data.tar.gz: fc0349312d70944a2169302e8b9f1f70f255e29b60c523f0ba1a688ac162cc230636f9ef333d6ac032d855d9387d88234181a69edade6407bde027050972a2da
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,29 @@
|
|
1
|
+
## 0.5.0 (2020-10-01)
|
2
|
+
|
3
|
+
- Updated ONNX Runtime to 1.5.1
|
4
|
+
- OpenMP is now required on Mac
|
5
|
+
- Fixed `mul_1.onnx` example
|
6
|
+
|
7
|
+
## 0.4.0 (2020-07-20)
|
8
|
+
|
9
|
+
- Updated ONNX Runtime to 1.4.0
|
10
|
+
- Added `providers` method
|
11
|
+
- Fixed errors on Windows
|
12
|
+
|
13
|
+
## 0.3.3 (2020-06-17)
|
14
|
+
|
15
|
+
- Fixed segmentation fault on exit on Linux
|
16
|
+
|
17
|
+
## 0.3.2 (2020-06-16)
|
18
|
+
|
19
|
+
- Fixed error with FFI 1.13.0+
|
20
|
+
- Added friendly graph optimization levels
|
21
|
+
|
22
|
+
## 0.3.1 (2020-05-18)
|
23
|
+
|
24
|
+
- Updated ONNX Runtime to 1.3.0
|
25
|
+
- Added `custom_metadata_map` to model metadata
|
26
|
+
|
1
27
|
## 0.3.0 (2020-03-11)
|
2
28
|
|
3
29
|
- Updated ONNX Runtime to 1.2.0
|
data/LICENSE.txt
CHANGED
@@ -1,23 +1,22 @@
|
|
1
|
-
Copyright (c) 2019-2020 Andrew Kane
|
2
|
-
Datasets Copyright (c) Microsoft Corporation
|
3
|
-
|
4
1
|
MIT License
|
5
2
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
Copyright (c) 2018 Microsoft Corporation
|
4
|
+
Copyright (c) 2019-2020 Andrew Kane
|
5
|
+
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
8
|
+
in the Software without restriction, including without limitation the rights
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
11
|
+
furnished to do so, subject to the following conditions:
|
13
12
|
|
14
|
-
The above copyright notice and this permission notice shall be
|
15
|
-
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
14
|
+
copies or substantial portions of the Software.
|
16
15
|
|
17
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
OF
|
23
|
-
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
|
|
14
14
|
gem 'onnxruntime'
|
15
15
|
```
|
16
16
|
|
17
|
+
On Mac, also install OpenMP:
|
18
|
+
|
19
|
+
```sh
|
20
|
+
brew install libomp
|
21
|
+
```
|
22
|
+
|
17
23
|
## Getting Started
|
18
24
|
|
19
25
|
Load a model and make predictions
|
@@ -63,8 +69,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
|
|
63
69
|
enable_cpu_mem_arena: true,
|
64
70
|
enable_mem_pattern: true,
|
65
71
|
enable_profiling: false,
|
66
|
-
execution_mode: :sequential,
|
67
|
-
graph_optimization_level: nil,
|
72
|
+
execution_mode: :sequential, # :sequential or :parallel
|
73
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
68
74
|
inter_op_num_threads: nil,
|
69
75
|
intra_op_num_threads: nil,
|
70
76
|
log_severity_level: 2,
|
data/lib/onnxruntime.rb
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,10 +3,13 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib
|
6
|
+
ffi_lib OnnxRuntime.ffi_lib
|
7
7
|
rescue LoadError => e
|
8
|
-
|
9
|
-
|
8
|
+
if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
|
9
|
+
raise LoadError, "OpenMP not found. Run `brew install libomp`"
|
10
|
+
else
|
11
|
+
raise e
|
12
|
+
end
|
10
13
|
end
|
11
14
|
|
12
15
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
@@ -20,19 +23,19 @@ module OnnxRuntime
|
|
20
23
|
layout \
|
21
24
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
25
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :
|
26
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
27
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
28
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
29
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
30
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
31
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
32
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
33
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
34
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
35
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
36
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
37
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
38
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
39
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
40
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
41
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -136,7 +139,15 @@ module OnnxRuntime
|
|
136
139
|
:ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
|
137
140
|
:ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
|
138
141
|
:ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
|
139
|
-
:ReleaseModelMetadata, callback(%i[pointer], :void)
|
142
|
+
:ReleaseModelMetadata, callback(%i[pointer], :void),
|
143
|
+
:CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
|
144
|
+
:DisablePerSessionThreads, callback(%i[], :pointer),
|
145
|
+
:CreateThreadingOptions, callback(%i[], :pointer),
|
146
|
+
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
147
|
+
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
148
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
149
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
150
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
140
151
|
end
|
141
152
|
|
142
153
|
class ApiBase < ::FFI::Struct
|
@@ -144,9 +155,17 @@ module OnnxRuntime
|
|
144
155
|
# to prevent "unable to resolve type" error on Ubuntu
|
145
156
|
layout \
|
146
157
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
147
|
-
:GetVersionString, callback(%i[], :
|
158
|
+
:GetVersionString, callback(%i[], :pointer)
|
148
159
|
end
|
149
160
|
|
150
161
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
162
|
+
|
163
|
+
if Gem.win_platform?
|
164
|
+
class Libc
|
165
|
+
extend ::FFI::Library
|
166
|
+
ffi_lib ::FFI::Library::LIBC
|
167
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
168
|
+
end
|
169
|
+
end
|
151
170
|
end
|
152
171
|
end
|
@@ -8,26 +8,25 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
0
|
17
|
-
when :parallel
|
18
|
-
1
|
19
|
-
else
|
20
|
-
raise ArgumentError, "Invalid execution mode"
|
21
|
-
end
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
22
16
|
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
23
17
|
end
|
24
|
-
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
25
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
26
25
|
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
27
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
28
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
29
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
30
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
31
30
|
|
32
31
|
# session
|
33
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -40,17 +39,12 @@ module OnnxRuntime
|
|
40
39
|
path_or_bytes.encoding == Encoding::BINARY
|
41
40
|
end
|
42
41
|
|
43
|
-
# fix for Windows "File doesn't exist"
|
44
|
-
if Gem.win_platform? && !from_memory
|
45
|
-
path_or_bytes = File.binread(path_or_bytes)
|
46
|
-
from_memory = true
|
47
|
-
end
|
48
|
-
|
49
42
|
if from_memory
|
50
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
51
44
|
else
|
52
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
53
46
|
end
|
47
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
55
49
|
# input info
|
56
50
|
allocator = ::FFI::MemoryPointer.new(:pointer)
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,13 +68,15 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
81
75
|
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
|
82
76
|
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
|
83
77
|
end
|
78
|
+
ensure
|
79
|
+
# release :SessionOptions, session_options
|
84
80
|
end
|
85
81
|
|
86
82
|
# TODO support logid
|
@@ -106,9 +102,18 @@ module OnnxRuntime
|
|
106
102
|
output_names.size.times.map do |i|
|
107
103
|
create_from_onnx_value(output_tensor[i].read_pointer)
|
108
104
|
end
|
105
|
+
ensure
|
106
|
+
release :RunOptions, run_options
|
107
|
+
if input_tensor
|
108
|
+
input_feed.size.times do |i|
|
109
|
+
release :Value, input_tensor[i]
|
110
|
+
end
|
111
|
+
end
|
109
112
|
end
|
110
113
|
|
111
114
|
def modelmeta
|
115
|
+
keys = ::FFI::MemoryPointer.new(:pointer)
|
116
|
+
num_keys = ::FFI::MemoryPointer.new(:int64_t)
|
112
117
|
description = ::FFI::MemoryPointer.new(:string)
|
113
118
|
domain = ::FFI::MemoryPointer.new(:string)
|
114
119
|
graph_name = ::FFI::MemoryPointer.new(:string)
|
@@ -117,36 +122,61 @@ module OnnxRuntime
|
|
117
122
|
|
118
123
|
metadata = ::FFI::MemoryPointer.new(:pointer)
|
119
124
|
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
125
|
+
|
126
|
+
custom_metadata_map = {}
|
127
|
+
check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
128
|
+
num_keys.read(:int64_t).times do |i|
|
129
|
+
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
130
|
+
value = ::FFI::MemoryPointer.new(:string)
|
131
|
+
check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
|
132
|
+
custom_metadata_map[key] = value.read_pointer.read_string
|
133
|
+
end
|
134
|
+
|
120
135
|
check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
|
121
136
|
check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
|
122
137
|
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
|
123
138
|
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
|
124
139
|
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
|
125
|
-
api[:ReleaseModelMetadata].call(metadata.read_pointer)
|
126
|
-
|
127
|
-
# TODO add custom_metadata_map
|
128
|
-
# need a way to get keys
|
129
140
|
|
130
141
|
{
|
142
|
+
custom_metadata_map: custom_metadata_map,
|
131
143
|
description: description.read_pointer.read_string,
|
132
144
|
domain: domain.read_pointer.read_string,
|
133
145
|
graph_name: graph_name.read_pointer.read_string,
|
134
146
|
producer_name: producer_name.read_pointer.read_string,
|
135
147
|
version: version.read(:int64_t)
|
136
148
|
}
|
149
|
+
ensure
|
150
|
+
release :ModelMetadata, metadata
|
137
151
|
end
|
138
152
|
|
153
|
+
# return value has double underscore like Python
|
139
154
|
def end_profiling
|
140
155
|
out = ::FFI::MemoryPointer.new(:string)
|
141
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
142
157
|
out.read_pointer.read_string
|
143
158
|
end
|
144
159
|
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
173
|
+
end
|
174
|
+
|
145
175
|
private
|
146
176
|
|
147
177
|
def create_input_tensor(input_feed)
|
148
178
|
allocator_info = ::FFI::MemoryPointer.new(:pointer)
|
149
|
-
check_status
|
179
|
+
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
|
150
180
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
151
181
|
|
152
182
|
input_feed.each_with_index do |(input_name, input), idx|
|
@@ -164,7 +194,7 @@ module OnnxRuntime
|
|
164
194
|
|
165
195
|
# TODO support more types
|
166
196
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
167
|
-
raise "Unknown input: #{input_name}" unless inp
|
197
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
168
198
|
|
169
199
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
170
200
|
input_node_dims.write_array_of_int64(shape)
|
@@ -206,7 +236,7 @@ module OnnxRuntime
|
|
206
236
|
|
207
237
|
def create_from_onnx_value(out_ptr)
|
208
238
|
out_type = ::FFI::MemoryPointer.new(:int)
|
209
|
-
check_status
|
239
|
+
check_status api[:GetValueType].call(out_ptr, out_type)
|
210
240
|
type = FFI::OnnxType[out_type.read_int]
|
211
241
|
|
212
242
|
case type
|
@@ -221,7 +251,9 @@ module OnnxRuntime
|
|
221
251
|
|
222
252
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
223
253
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
224
|
-
output_tensor_size =
|
254
|
+
output_tensor_size = out_size.read(:size_t)
|
255
|
+
|
256
|
+
release :TensorTypeAndShapeInfo, typeinfo
|
225
257
|
|
226
258
|
# TODO support more types
|
227
259
|
type = FFI::TensorElementDataType[type]
|
@@ -240,7 +272,7 @@ module OnnxRuntime
|
|
240
272
|
out = ::FFI::MemoryPointer.new(:size_t)
|
241
273
|
check_status api[:GetValueCount].call(out_ptr, out)
|
242
274
|
|
243
|
-
|
275
|
+
out.read(:size_t).times.map do |i|
|
244
276
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
245
277
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
246
278
|
create_from_onnx_value(seq.read_pointer)
|
@@ -255,6 +287,7 @@ module OnnxRuntime
|
|
255
287
|
check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
|
256
288
|
check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
|
257
289
|
check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
|
290
|
+
release :TensorTypeAndShapeInfo, type_shape
|
258
291
|
|
259
292
|
# TODO support more types
|
260
293
|
elem_type = FFI::TensorElementDataType[elem_type.read_int]
|
@@ -281,9 +314,9 @@ module OnnxRuntime
|
|
281
314
|
|
282
315
|
def check_status(status)
|
283
316
|
unless status.null?
|
284
|
-
message = api[:GetErrorMessage].call(status)
|
317
|
+
message = api[:GetErrorMessage].call(status).read_string
|
285
318
|
api[:ReleaseStatus].call(status)
|
286
|
-
raise
|
319
|
+
raise Error, message
|
287
320
|
end
|
288
321
|
end
|
289
322
|
|
@@ -295,6 +328,7 @@ module OnnxRuntime
|
|
295
328
|
case type
|
296
329
|
when :tensor
|
297
330
|
tensor_info = ::FFI::MemoryPointer.new(:pointer)
|
331
|
+
# don't free tensor_info
|
298
332
|
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
|
299
333
|
|
300
334
|
type, shape = tensor_type_and_shape(tensor_info)
|
@@ -335,7 +369,7 @@ module OnnxRuntime
|
|
335
369
|
unsupported_type("ONNX", type)
|
336
370
|
end
|
337
371
|
ensure
|
338
|
-
|
372
|
+
release :TypeInfo, typeinfo
|
339
373
|
end
|
340
374
|
|
341
375
|
def tensor_type_and_shape(tensor_info)
|
@@ -344,7 +378,7 @@ module OnnxRuntime
|
|
344
378
|
|
345
379
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
346
380
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
347
|
-
num_dims =
|
381
|
+
num_dims = num_dims_ptr.read(:size_t)
|
348
382
|
|
349
383
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
350
384
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -353,20 +387,43 @@ module OnnxRuntime
|
|
353
387
|
end
|
354
388
|
|
355
389
|
def unsupported_type(name, type)
|
356
|
-
raise "Unsupported #{name} type: #{type}"
|
390
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
357
391
|
end
|
358
392
|
|
359
|
-
|
360
|
-
|
361
|
-
if RUBY_PLATFORM == "java"
|
362
|
-
ptr.read_long
|
363
|
-
else
|
364
|
-
ptr.read(:size_t)
|
365
|
-
end
|
393
|
+
def api
|
394
|
+
self.class.api
|
366
395
|
end
|
367
396
|
|
368
|
-
def
|
369
|
-
|
397
|
+
def release(*args)
|
398
|
+
self.class.release(*args)
|
399
|
+
end
|
400
|
+
|
401
|
+
def self.api
|
402
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
403
|
+
end
|
404
|
+
|
405
|
+
def self.release(type, pointer)
|
406
|
+
api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
|
407
|
+
end
|
408
|
+
|
409
|
+
def self.finalize(session)
|
410
|
+
# must use proc instead of stabby lambda
|
411
|
+
proc { release :Session, session }
|
412
|
+
end
|
413
|
+
|
414
|
+
# wide string on Windows
|
415
|
+
# char string on Linux
|
416
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
417
|
+
def ort_string(str)
|
418
|
+
if Gem.win_platform?
|
419
|
+
max = str.size + 1 # for null byte
|
420
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
421
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
422
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
423
|
+
dest
|
424
|
+
else
|
425
|
+
str
|
426
|
+
end
|
370
427
|
end
|
371
428
|
|
372
429
|
def env
|
@@ -375,7 +432,7 @@ module OnnxRuntime
|
|
375
432
|
@@env ||= begin
|
376
433
|
env = ::FFI::MemoryPointer.new(:pointer)
|
377
434
|
check_status api[:CreateEnv].call(3, "Default", env)
|
378
|
-
at_exit {
|
435
|
+
at_exit { release :Env, env }
|
379
436
|
# disable telemetry
|
380
437
|
# https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
|
381
438
|
check_status api[:DisableTelemetryEvents].call(env)
|