onnxruntime 0.2.3 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +9 -3
- data/lib/onnxruntime.rb +1 -1
- data/lib/onnxruntime/ffi.rb +40 -12
- data/lib/onnxruntime/inference_session.rb +160 -44
- data/lib/onnxruntime/model.rb +4 -0
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
|
4
|
+
data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
|
7
|
+
data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.4.0 (2020-07-20)
|
2
|
+
|
3
|
+
- Updated ONNX Runtime to 1.4.0
|
4
|
+
- Added `providers` method
|
5
|
+
- Fixed errors on Windows
|
6
|
+
|
7
|
+
## 0.3.3 (2020-06-17)
|
8
|
+
|
9
|
+
- Fixed segmentation fault on exit on Linux
|
10
|
+
|
11
|
+
## 0.3.2 (2020-06-16)
|
12
|
+
|
13
|
+
- Fixed error with FFI 1.13.0+
|
14
|
+
- Added friendly graph optimization levels
|
15
|
+
|
16
|
+
## 0.3.1 (2020-05-18)
|
17
|
+
|
18
|
+
- Updated ONNX Runtime to 1.3.0
|
19
|
+
- Added `custom_metadata_map` to model metadata
|
20
|
+
|
21
|
+
## 0.3.0 (2020-03-11)
|
22
|
+
|
23
|
+
- Updated ONNX Runtime to 1.2.0
|
24
|
+
- Added model metadata
|
25
|
+
- Added `end_profiling` method
|
26
|
+
- Added support for loading from IO objects
|
27
|
+
- Improved `input` and `output` for `seq` and `map` types
|
28
|
+
|
1
29
|
## 0.2.3 (2020-01-23)
|
2
30
|
|
3
31
|
- Updated ONNX Runtime to 1.1.1
|
data/README.md
CHANGED
@@ -37,10 +37,16 @@ Get outputs
|
|
37
37
|
model.outputs
|
38
38
|
```
|
39
39
|
|
40
|
+
Get metadata
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
model.metadata
|
44
|
+
```
|
45
|
+
|
40
46
|
Load a model from a string
|
41
47
|
|
42
48
|
```ruby
|
43
|
-
byte_str =
|
49
|
+
byte_str = StringIO.new("...")
|
44
50
|
model = OnnxRuntime::Model.new(byte_str)
|
45
51
|
```
|
46
52
|
|
@@ -57,8 +63,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
|
|
57
63
|
enable_cpu_mem_arena: true,
|
58
64
|
enable_mem_pattern: true,
|
59
65
|
enable_profiling: false,
|
60
|
-
execution_mode: :sequential,
|
61
|
-
graph_optimization_level: nil,
|
66
|
+
execution_mode: :sequential, # :sequential or :parallel
|
67
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
62
68
|
inter_op_num_threads: nil,
|
63
69
|
intra_op_num_threads: nil,
|
64
70
|
log_severity_level: 2,
|
data/lib/onnxruntime.rb
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -2,12 +2,7 @@ module OnnxRuntime
|
|
2
2
|
module FFI
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
|
-
|
6
|
-
ffi_lib Array(OnnxRuntime.ffi_lib)
|
7
|
-
rescue LoadError => e
|
8
|
-
raise e if ENV["ONNXRUNTIME_DEBUG"]
|
9
|
-
raise LoadError, "Could not find ONNX Runtime"
|
10
|
-
end
|
5
|
+
ffi_lib Array(OnnxRuntime.ffi_lib)
|
11
6
|
|
12
7
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
13
8
|
# keep same order
|
@@ -20,19 +15,19 @@ module OnnxRuntime
|
|
20
15
|
layout \
|
21
16
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
17
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :
|
18
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
19
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
20
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
21
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
22
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
23
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
24
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
25
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
26
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
27
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
28
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
29
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
30
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
31
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
32
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
33
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -119,7 +114,32 @@ module OnnxRuntime
|
|
119
114
|
:ReleaseTypeInfo, callback(%i[pointer], :void),
|
120
115
|
:ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
|
121
116
|
:ReleaseSessionOptions, callback(%i[pointer], :void),
|
122
|
-
:ReleaseCustomOpDomain, callback(%i[pointer], :void)
|
117
|
+
:ReleaseCustomOpDomain, callback(%i[pointer], :void),
|
118
|
+
:GetDenotationFromTypeInfo, callback(%i[], :pointer),
|
119
|
+
:CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
|
120
|
+
:CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
|
121
|
+
:GetMapKeyType, callback(%i[pointer pointer], :pointer),
|
122
|
+
:GetMapValueType, callback(%i[pointer pointer], :pointer),
|
123
|
+
:GetSequenceElementType, callback(%i[pointer pointer], :pointer),
|
124
|
+
:ReleaseMapTypeInfo, callback(%i[pointer], :void),
|
125
|
+
:ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
|
126
|
+
:SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
|
127
|
+
:SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
|
128
|
+
:ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
|
129
|
+
:ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
|
130
|
+
:ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
|
131
|
+
:ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
|
132
|
+
:ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
|
133
|
+
:ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
|
134
|
+
:ReleaseModelMetadata, callback(%i[pointer], :void),
|
135
|
+
:CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
|
136
|
+
:DisablePerSessionThreads, callback(%i[], :pointer),
|
137
|
+
:CreateThreadingOptions, callback(%i[], :pointer),
|
138
|
+
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
139
|
+
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
140
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
141
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
142
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
123
143
|
end
|
124
144
|
|
125
145
|
class ApiBase < ::FFI::Struct
|
@@ -127,9 +147,17 @@ module OnnxRuntime
|
|
127
147
|
# to prevent "unable to resolve type" error on Ubuntu
|
128
148
|
layout \
|
129
149
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
130
|
-
:GetVersionString, callback(%i[], :
|
150
|
+
:GetVersionString, callback(%i[], :pointer)
|
131
151
|
end
|
132
152
|
|
133
153
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
154
|
+
|
155
|
+
if Gem.win_platform?
|
156
|
+
class Libc
|
157
|
+
extend ::FFI::Library
|
158
|
+
ffi_lib ::FFI::Library::LIBC
|
159
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
160
|
+
end
|
161
|
+
end
|
134
162
|
end
|
135
163
|
end
|
@@ -8,41 +8,43 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
0
|
17
|
-
when :parallel
|
18
|
-
1
|
19
|
-
else
|
20
|
-
raise ArgumentError, "Invalid execution mode"
|
21
|
-
end
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
22
16
|
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
23
17
|
end
|
24
|
-
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
25
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
26
25
|
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
27
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
28
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
29
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
30
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
31
30
|
|
32
31
|
# session
|
33
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
33
|
+
from_memory =
|
34
|
+
if path_or_bytes.respond_to?(:read)
|
35
|
+
path_or_bytes = path_or_bytes.read
|
36
|
+
true
|
37
|
+
else
|
38
|
+
path_or_bytes = path_or_bytes.to_str
|
39
|
+
path_or_bytes.encoding == Encoding::BINARY
|
40
|
+
end
|
40
41
|
|
41
|
-
if
|
42
|
+
if from_memory
|
42
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
43
44
|
else
|
44
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
45
46
|
end
|
47
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
46
48
|
|
47
49
|
# input info
|
48
50
|
allocator = ::FFI::MemoryPointer.new(:pointer)
|
@@ -55,7 +57,7 @@ module OnnxRuntime
|
|
55
57
|
# input
|
56
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
57
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
58
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
59
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
60
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
61
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -66,13 +68,15 @@ module OnnxRuntime
|
|
66
68
|
# output
|
67
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
68
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
69
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
70
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
71
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
72
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
73
75
|
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
|
74
76
|
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
|
75
77
|
end
|
78
|
+
ensure
|
79
|
+
# release :SessionOptions, session_options
|
76
80
|
end
|
77
81
|
|
78
82
|
# TODO support logid
|
@@ -98,6 +102,74 @@ module OnnxRuntime
|
|
98
102
|
output_names.size.times.map do |i|
|
99
103
|
create_from_onnx_value(output_tensor[i].read_pointer)
|
100
104
|
end
|
105
|
+
ensure
|
106
|
+
release :RunOptions, run_options
|
107
|
+
if input_tensor
|
108
|
+
input_feed.size.times do |i|
|
109
|
+
release :Value, input_tensor[i]
|
110
|
+
end
|
111
|
+
end
|
112
|
+
end
|
113
|
+
|
114
|
+
def modelmeta
|
115
|
+
keys = ::FFI::MemoryPointer.new(:pointer)
|
116
|
+
num_keys = ::FFI::MemoryPointer.new(:int64_t)
|
117
|
+
description = ::FFI::MemoryPointer.new(:string)
|
118
|
+
domain = ::FFI::MemoryPointer.new(:string)
|
119
|
+
graph_name = ::FFI::MemoryPointer.new(:string)
|
120
|
+
producer_name = ::FFI::MemoryPointer.new(:string)
|
121
|
+
version = ::FFI::MemoryPointer.new(:int64_t)
|
122
|
+
|
123
|
+
metadata = ::FFI::MemoryPointer.new(:pointer)
|
124
|
+
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
125
|
+
|
126
|
+
custom_metadata_map = {}
|
127
|
+
check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
128
|
+
num_keys.read(:int64_t).times do |i|
|
129
|
+
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
130
|
+
value = ::FFI::MemoryPointer.new(:string)
|
131
|
+
check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
|
132
|
+
custom_metadata_map[key] = value.read_pointer.read_string
|
133
|
+
end
|
134
|
+
|
135
|
+
check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
|
136
|
+
check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
|
137
|
+
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
|
138
|
+
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
|
139
|
+
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
|
140
|
+
|
141
|
+
{
|
142
|
+
custom_metadata_map: custom_metadata_map,
|
143
|
+
description: description.read_pointer.read_string,
|
144
|
+
domain: domain.read_pointer.read_string,
|
145
|
+
graph_name: graph_name.read_pointer.read_string,
|
146
|
+
producer_name: producer_name.read_pointer.read_string,
|
147
|
+
version: version.read(:int64_t)
|
148
|
+
}
|
149
|
+
ensure
|
150
|
+
release :ModelMetadata, metadata
|
151
|
+
end
|
152
|
+
|
153
|
+
# return value has double underscore like Python
|
154
|
+
def end_profiling
|
155
|
+
out = ::FFI::MemoryPointer.new(:string)
|
156
|
+
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
157
|
+
out.read_pointer.read_string
|
158
|
+
end
|
159
|
+
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
101
173
|
end
|
102
174
|
|
103
175
|
private
|
@@ -122,7 +194,7 @@ module OnnxRuntime
|
|
122
194
|
|
123
195
|
# TODO support more types
|
124
196
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
125
|
-
raise "Unknown input: #{input_name}" unless inp
|
197
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
126
198
|
|
127
199
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
128
200
|
input_node_dims.write_array_of_int64(shape)
|
@@ -179,7 +251,9 @@ module OnnxRuntime
|
|
179
251
|
|
180
252
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
181
253
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
182
|
-
output_tensor_size =
|
254
|
+
output_tensor_size = out_size.read(:size_t)
|
255
|
+
|
256
|
+
release :TensorTypeAndShapeInfo, typeinfo
|
183
257
|
|
184
258
|
# TODO support more types
|
185
259
|
type = FFI::TensorElementDataType[type]
|
@@ -198,7 +272,7 @@ module OnnxRuntime
|
|
198
272
|
out = ::FFI::MemoryPointer.new(:size_t)
|
199
273
|
check_status api[:GetValueCount].call(out_ptr, out)
|
200
274
|
|
201
|
-
|
275
|
+
out.read(:size_t).times.map do |i|
|
202
276
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
203
277
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
204
278
|
create_from_onnx_value(seq.read_pointer)
|
@@ -213,6 +287,7 @@ module OnnxRuntime
|
|
213
287
|
check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
|
214
288
|
check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
|
215
289
|
check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
|
290
|
+
release :TensorTypeAndShapeInfo, type_shape
|
216
291
|
|
217
292
|
# TODO support more types
|
218
293
|
elem_type = FFI::TensorElementDataType[elem_type.read_int]
|
@@ -239,9 +314,9 @@ module OnnxRuntime
|
|
239
314
|
|
240
315
|
def check_status(status)
|
241
316
|
unless status.null?
|
242
|
-
message = api[:GetErrorMessage].call(status)
|
317
|
+
message = api[:GetErrorMessage].call(status).read_string
|
243
318
|
api[:ReleaseStatus].call(status)
|
244
|
-
raise
|
319
|
+
raise Error, message
|
245
320
|
end
|
246
321
|
end
|
247
322
|
|
@@ -253,6 +328,7 @@ module OnnxRuntime
|
|
253
328
|
case type
|
254
329
|
when :tensor
|
255
330
|
tensor_info = ::FFI::MemoryPointer.new(:pointer)
|
331
|
+
# don't free tensor_info
|
256
332
|
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
|
257
333
|
|
258
334
|
type, shape = tensor_type_and_shape(tensor_info)
|
@@ -261,22 +337,39 @@ module OnnxRuntime
|
|
261
337
|
shape: shape
|
262
338
|
}
|
263
339
|
when :sequence
|
264
|
-
|
340
|
+
sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
|
341
|
+
check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
|
342
|
+
nested_type_info = ::FFI::MemoryPointer.new(:pointer)
|
343
|
+
check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
|
344
|
+
v = node_info(nested_type_info)[:type]
|
345
|
+
|
265
346
|
{
|
266
|
-
type: "seq",
|
347
|
+
type: "seq(#{v})",
|
267
348
|
shape: []
|
268
349
|
}
|
269
350
|
when :map
|
270
|
-
|
351
|
+
map_type_info = ::FFI::MemoryPointer.new(:pointer)
|
352
|
+
check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
|
353
|
+
|
354
|
+
# key
|
355
|
+
key_type = ::FFI::MemoryPointer.new(:int)
|
356
|
+
check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
|
357
|
+
k = FFI::TensorElementDataType[key_type.read_int]
|
358
|
+
|
359
|
+
# value
|
360
|
+
value_type_info = ::FFI::MemoryPointer.new(:pointer)
|
361
|
+
check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
|
362
|
+
v = node_info(value_type_info)[:type]
|
363
|
+
|
271
364
|
{
|
272
|
-
type: "map",
|
365
|
+
type: "map(#{k},#{v})",
|
273
366
|
shape: []
|
274
367
|
}
|
275
368
|
else
|
276
369
|
unsupported_type("ONNX", type)
|
277
370
|
end
|
278
371
|
ensure
|
279
|
-
|
372
|
+
release :TypeInfo, typeinfo
|
280
373
|
end
|
281
374
|
|
282
375
|
def tensor_type_and_shape(tensor_info)
|
@@ -285,7 +378,7 @@ module OnnxRuntime
|
|
285
378
|
|
286
379
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
287
380
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
288
|
-
num_dims =
|
381
|
+
num_dims = num_dims_ptr.read(:size_t)
|
289
382
|
|
290
383
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
291
384
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -294,20 +387,43 @@ module OnnxRuntime
|
|
294
387
|
end
|
295
388
|
|
296
389
|
def unsupported_type(name, type)
|
297
|
-
raise "Unsupported #{name} type: #{type}"
|
390
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
298
391
|
end
|
299
392
|
|
300
|
-
|
301
|
-
|
302
|
-
if RUBY_PLATFORM == "java"
|
303
|
-
ptr.read_long
|
304
|
-
else
|
305
|
-
ptr.read(:size_t)
|
306
|
-
end
|
393
|
+
def api
|
394
|
+
self.class.api
|
307
395
|
end
|
308
396
|
|
309
|
-
def
|
310
|
-
|
397
|
+
def release(*args)
|
398
|
+
self.class.release(*args)
|
399
|
+
end
|
400
|
+
|
401
|
+
def self.api
|
402
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
403
|
+
end
|
404
|
+
|
405
|
+
def self.release(type, pointer)
|
406
|
+
api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
|
407
|
+
end
|
408
|
+
|
409
|
+
def self.finalize(session)
|
410
|
+
# must use proc instead of stabby lambda
|
411
|
+
proc { release :Session, session }
|
412
|
+
end
|
413
|
+
|
414
|
+
# wide string on Windows
|
415
|
+
# char string on Linux
|
416
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
417
|
+
def ort_string(str)
|
418
|
+
if Gem.win_platform?
|
419
|
+
max = str.size + 1 # for null byte
|
420
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
421
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
422
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
423
|
+
dest
|
424
|
+
else
|
425
|
+
str
|
426
|
+
end
|
311
427
|
end
|
312
428
|
|
313
429
|
def env
|
@@ -316,7 +432,7 @@ module OnnxRuntime
|
|
316
432
|
@@env ||= begin
|
317
433
|
env = ::FFI::MemoryPointer.new(:pointer)
|
318
434
|
check_status api[:CreateEnv].call(3, "Default", env)
|
319
|
-
at_exit {
|
435
|
+
at_exit { release :Env, env }
|
320
436
|
# disable telemetry
|
321
437
|
# https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
|
322
438
|
check_status api[:DisableTelemetryEvents].call(env)
|
data/lib/onnxruntime/model.rb
CHANGED
data/lib/onnxruntime/version.rb
CHANGED
data/vendor/libonnxruntime.dylib
CHANGED
Binary file
|
data/vendor/libonnxruntime.so
CHANGED
Binary file
|
data/vendor/onnxruntime.dll
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: onnxruntime
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.4.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-07-21 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|