onnxruntime 0.3.3 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/onnxruntime/ffi.rb +15 -10
- data/lib/onnxruntime/inference_session.rb +45 -29
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
|
4
|
+
data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
|
7
|
+
data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
|
data/CHANGELOG.md
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -2,12 +2,7 @@ module OnnxRuntime
|
|
2
2
|
module FFI
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
|
-
|
6
|
-
ffi_lib Array(OnnxRuntime.ffi_lib)
|
7
|
-
rescue LoadError => e
|
8
|
-
raise e if ENV["ONNXRUNTIME_DEBUG"]
|
9
|
-
raise LoadError, "Could not find ONNX Runtime"
|
10
|
-
end
|
5
|
+
ffi_lib Array(OnnxRuntime.ffi_lib)
|
11
6
|
|
12
7
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
13
8
|
# keep same order
|
@@ -25,14 +20,14 @@ module OnnxRuntime
|
|
25
20
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
21
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
22
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer
|
23
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
24
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
25
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
26
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer
|
27
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
28
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
29
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer
|
30
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
31
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
32
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
33
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -142,7 +137,9 @@ module OnnxRuntime
|
|
142
137
|
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
138
|
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
139
|
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
-
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
140
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
141
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
142
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
146
143
|
end
|
147
144
|
|
148
145
|
class ApiBase < ::FFI::Struct
|
@@ -154,5 +151,13 @@ module OnnxRuntime
|
|
154
151
|
end
|
155
152
|
|
156
153
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
154
|
+
|
155
|
+
if Gem.win_platform?
|
156
|
+
class Libc
|
157
|
+
extend ::FFI::Library
|
158
|
+
ffi_lib ::FFI::Library::LIBC
|
159
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
160
|
+
end
|
161
|
+
end
|
157
162
|
end
|
158
163
|
end
|
@@ -8,7 +8,7 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
13
|
execution_modes = {sequential: 0, parallel: 1}
|
14
14
|
mode = execution_modes[execution_mode]
|
@@ -17,8 +17,8 @@ module OnnxRuntime
|
|
17
17
|
end
|
18
18
|
if graph_optimization_level
|
19
19
|
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
-
|
21
|
-
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
22
|
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
23
|
end
|
24
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
@@ -26,7 +26,7 @@ module OnnxRuntime
|
|
26
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
27
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
28
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
29
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
30
30
|
|
31
31
|
# session
|
32
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -39,16 +39,10 @@ module OnnxRuntime
|
|
39
39
|
path_or_bytes.encoding == Encoding::BINARY
|
40
40
|
end
|
41
41
|
|
42
|
-
# fix for Windows "File doesn't exist"
|
43
|
-
if Gem.win_platform? && !from_memory
|
44
|
-
path_or_bytes = File.binread(path_or_bytes)
|
45
|
-
from_memory = true
|
46
|
-
end
|
47
|
-
|
48
42
|
if from_memory
|
49
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
50
44
|
else
|
51
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
52
46
|
end
|
53
47
|
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,7 +68,7 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -156,12 +150,28 @@ module OnnxRuntime
|
|
156
150
|
release :ModelMetadata, metadata
|
157
151
|
end
|
158
152
|
|
153
|
+
# return value has double underscore like Python
|
159
154
|
def end_profiling
|
160
155
|
out = ::FFI::MemoryPointer.new(:string)
|
161
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
162
157
|
out.read_pointer.read_string
|
163
158
|
end
|
164
159
|
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
173
|
+
end
|
174
|
+
|
165
175
|
private
|
166
176
|
|
167
177
|
def create_input_tensor(input_feed)
|
@@ -184,7 +194,7 @@ module OnnxRuntime
|
|
184
194
|
|
185
195
|
# TODO support more types
|
186
196
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
187
|
-
raise "Unknown input: #{input_name}" unless inp
|
197
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
188
198
|
|
189
199
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
190
200
|
input_node_dims.write_array_of_int64(shape)
|
@@ -241,7 +251,7 @@ module OnnxRuntime
|
|
241
251
|
|
242
252
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
243
253
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
244
|
-
output_tensor_size =
|
254
|
+
output_tensor_size = out_size.read(:size_t)
|
245
255
|
|
246
256
|
release :TensorTypeAndShapeInfo, typeinfo
|
247
257
|
|
@@ -262,7 +272,7 @@ module OnnxRuntime
|
|
262
272
|
out = ::FFI::MemoryPointer.new(:size_t)
|
263
273
|
check_status api[:GetValueCount].call(out_ptr, out)
|
264
274
|
|
265
|
-
|
275
|
+
out.read(:size_t).times.map do |i|
|
266
276
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
267
277
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
268
278
|
create_from_onnx_value(seq.read_pointer)
|
@@ -306,7 +316,7 @@ module OnnxRuntime
|
|
306
316
|
unless status.null?
|
307
317
|
message = api[:GetErrorMessage].call(status).read_string
|
308
318
|
api[:ReleaseStatus].call(status)
|
309
|
-
raise
|
319
|
+
raise Error, message
|
310
320
|
end
|
311
321
|
end
|
312
322
|
|
@@ -368,7 +378,7 @@ module OnnxRuntime
|
|
368
378
|
|
369
379
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
370
380
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
371
|
-
num_dims =
|
381
|
+
num_dims = num_dims_ptr.read(:size_t)
|
372
382
|
|
373
383
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
374
384
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -377,16 +387,7 @@ module OnnxRuntime
|
|
377
387
|
end
|
378
388
|
|
379
389
|
def unsupported_type(name, type)
|
380
|
-
raise "Unsupported #{name} type: #{type}"
|
381
|
-
end
|
382
|
-
|
383
|
-
# read(:size_t) not supported in FFI JRuby
|
384
|
-
def read_size_t(ptr)
|
385
|
-
if RUBY_PLATFORM == "java"
|
386
|
-
ptr.read_long
|
387
|
-
else
|
388
|
-
ptr.read(:size_t)
|
389
|
-
end
|
390
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
390
391
|
end
|
391
392
|
|
392
393
|
def api
|
@@ -398,7 +399,7 @@ module OnnxRuntime
|
|
398
399
|
end
|
399
400
|
|
400
401
|
def self.api
|
401
|
-
@api ||= FFI.OrtGetApiBase[:GetApi].call(
|
402
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
402
403
|
end
|
403
404
|
|
404
405
|
def self.release(type, pointer)
|
@@ -410,6 +411,21 @@ module OnnxRuntime
|
|
410
411
|
proc { release :Session, session }
|
411
412
|
end
|
412
413
|
|
414
|
+
# wide string on Windows
|
415
|
+
# char string on Linux
|
416
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
417
|
+
def ort_string(str)
|
418
|
+
if Gem.win_platform?
|
419
|
+
max = str.size + 1 # for null byte
|
420
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
421
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
422
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
423
|
+
dest
|
424
|
+
else
|
425
|
+
str
|
426
|
+
end
|
427
|
+
end
|
428
|
+
|
413
429
|
def env
|
414
430
|
# use mutex for thread-safety
|
415
431
|
Utils.mutex.synchronize do
|
data/lib/onnxruntime/version.rb
CHANGED
data/vendor/libonnxruntime.dylib
CHANGED
Binary file
|
data/vendor/libonnxruntime.so
CHANGED
Binary file
|
data/vendor/onnxruntime.dll
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: onnxruntime
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.4.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-07-21 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|