onnxruntime 0.3.3 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/onnxruntime/ffi.rb +15 -10
- data/lib/onnxruntime/inference_session.rb +45 -29
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
|
4
|
+
data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
|
7
|
+
data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
|
data/CHANGELOG.md
CHANGED
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -2,12 +2,7 @@ module OnnxRuntime
|
|
2
2
|
module FFI
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
|
-
|
6
|
-
ffi_lib Array(OnnxRuntime.ffi_lib)
|
7
|
-
rescue LoadError => e
|
8
|
-
raise e if ENV["ONNXRUNTIME_DEBUG"]
|
9
|
-
raise LoadError, "Could not find ONNX Runtime"
|
10
|
-
end
|
5
|
+
ffi_lib Array(OnnxRuntime.ffi_lib)
|
11
6
|
|
12
7
|
# https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
|
13
8
|
# keep same order
|
@@ -25,14 +20,14 @@ module OnnxRuntime
|
|
25
20
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
21
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
27
22
|
:DisableTelemetryEvents, callback(%i[pointer], :pointer),
|
28
|
-
:CreateSession, callback(%i[pointer string pointer pointer], :pointer),
|
23
|
+
:CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
|
29
24
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
25
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
26
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
|
27
|
+
:SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
|
33
28
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
29
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[pointer string], :pointer),
|
30
|
+
:EnableProfiling, callback(%i[pointer pointer], :pointer),
|
36
31
|
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
32
|
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
33
|
:DisableMemPattern, callback(%i[pointer], :pointer),
|
@@ -142,7 +137,9 @@ module OnnxRuntime
|
|
142
137
|
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
138
|
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
139
|
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
-
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
140
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer),
|
141
|
+
:GetAvailableProviders, callback(%i[pointer pointer], :pointer),
|
142
|
+
:ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
|
146
143
|
end
|
147
144
|
|
148
145
|
class ApiBase < ::FFI::Struct
|
@@ -154,5 +151,13 @@ module OnnxRuntime
|
|
154
151
|
end
|
155
152
|
|
156
153
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
154
|
+
|
155
|
+
if Gem.win_platform?
|
156
|
+
class Libc
|
157
|
+
extend ::FFI::Library
|
158
|
+
ffi_lib ::FFI::Library::LIBC
|
159
|
+
attach_function :mbstowcs, %i[pointer string size_t], :size_t
|
160
|
+
end
|
161
|
+
end
|
157
162
|
end
|
158
163
|
end
|
@@ -8,7 +8,7 @@ module OnnxRuntime
|
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
9
|
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
10
|
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
-
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
|
12
12
|
if execution_mode
|
13
13
|
execution_modes = {sequential: 0, parallel: 1}
|
14
14
|
mode = execution_modes[execution_mode]
|
@@ -17,8 +17,8 @@ module OnnxRuntime
|
|
17
17
|
end
|
18
18
|
if graph_optimization_level
|
19
19
|
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
-
|
21
|
-
|
20
|
+
level = optimization_levels[graph_optimization_level]
|
21
|
+
raise ArgumentError, "Invalid graph optimization level" unless level
|
22
22
|
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
23
|
end
|
24
24
|
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
@@ -26,7 +26,7 @@ module OnnxRuntime
|
|
26
26
|
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
27
27
|
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
28
28
|
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
29
|
-
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
|
30
30
|
|
31
31
|
# session
|
32
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
@@ -39,16 +39,10 @@ module OnnxRuntime
|
|
39
39
|
path_or_bytes.encoding == Encoding::BINARY
|
40
40
|
end
|
41
41
|
|
42
|
-
# fix for Windows "File doesn't exist"
|
43
|
-
if Gem.win_platform? && !from_memory
|
44
|
-
path_or_bytes = File.binread(path_or_bytes)
|
45
|
-
from_memory = true
|
46
|
-
end
|
47
|
-
|
48
42
|
if from_memory
|
49
43
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
50
44
|
else
|
51
|
-
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
45
|
+
check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
|
52
46
|
end
|
53
47
|
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
|
54
48
|
|
@@ -63,7 +57,7 @@ module OnnxRuntime
|
|
63
57
|
# input
|
64
58
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
65
59
|
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
|
66
|
-
|
60
|
+
num_input_nodes.read(:size_t).times do |i|
|
67
61
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
68
62
|
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
|
69
63
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -74,7 +68,7 @@ module OnnxRuntime
|
|
74
68
|
# output
|
75
69
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
76
70
|
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
|
77
|
-
|
71
|
+
num_output_nodes.read(:size_t).times do |i|
|
78
72
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
79
73
|
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
|
80
74
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -156,12 +150,28 @@ module OnnxRuntime
|
|
156
150
|
release :ModelMetadata, metadata
|
157
151
|
end
|
158
152
|
|
153
|
+
# return value has double underscore like Python
|
159
154
|
def end_profiling
|
160
155
|
out = ::FFI::MemoryPointer.new(:string)
|
161
156
|
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
162
157
|
out.read_pointer.read_string
|
163
158
|
end
|
164
159
|
|
160
|
+
# no way to set providers with C API yet
|
161
|
+
# so we can return all available providers
|
162
|
+
def providers
|
163
|
+
out_ptr = ::FFI::MemoryPointer.new(:pointer)
|
164
|
+
length_ptr = ::FFI::MemoryPointer.new(:int)
|
165
|
+
check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
|
166
|
+
length = length_ptr.read_int
|
167
|
+
providers = []
|
168
|
+
length.times do |i|
|
169
|
+
providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
170
|
+
end
|
171
|
+
api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
|
172
|
+
providers
|
173
|
+
end
|
174
|
+
|
165
175
|
private
|
166
176
|
|
167
177
|
def create_input_tensor(input_feed)
|
@@ -184,7 +194,7 @@ module OnnxRuntime
|
|
184
194
|
|
185
195
|
# TODO support more types
|
186
196
|
inp = @inputs.find { |i| i[:name] == input_name.to_s }
|
187
|
-
raise "Unknown input: #{input_name}" unless inp
|
197
|
+
raise Error, "Unknown input: #{input_name}" unless inp
|
188
198
|
|
189
199
|
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
|
190
200
|
input_node_dims.write_array_of_int64(shape)
|
@@ -241,7 +251,7 @@ module OnnxRuntime
|
|
241
251
|
|
242
252
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
243
253
|
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
|
244
|
-
output_tensor_size = read_size_t(out_size)
|
254
|
+
output_tensor_size = out_size.read(:size_t)
|
245
255
|
|
246
256
|
release :TensorTypeAndShapeInfo, typeinfo
|
247
257
|
|
@@ -262,7 +272,7 @@ module OnnxRuntime
|
|
262
272
|
out = ::FFI::MemoryPointer.new(:size_t)
|
263
273
|
check_status api[:GetValueCount].call(out_ptr, out)
|
264
274
|
|
265
|
-
|
275
|
+
out.read(:size_t).times.map do |i|
|
266
276
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
267
277
|
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
|
268
278
|
create_from_onnx_value(seq.read_pointer)
|
@@ -306,7 +316,7 @@ module OnnxRuntime
|
|
306
316
|
unless status.null?
|
307
317
|
message = api[:GetErrorMessage].call(status).read_string
|
308
318
|
api[:ReleaseStatus].call(status)
|
309
|
-
raise
|
319
|
+
raise Error, message
|
310
320
|
end
|
311
321
|
end
|
312
322
|
|
@@ -368,7 +378,7 @@ module OnnxRuntime
|
|
368
378
|
|
369
379
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
370
380
|
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
|
371
|
-
num_dims = read_size_t(num_dims_ptr)
|
381
|
+
num_dims = num_dims_ptr.read(:size_t)
|
372
382
|
|
373
383
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
374
384
|
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -377,16 +387,7 @@ module OnnxRuntime
|
|
377
387
|
end
|
378
388
|
|
379
389
|
def unsupported_type(name, type)
|
380
|
-
raise "Unsupported #{name} type: #{type}"
|
381
|
-
end
|
382
|
-
|
383
|
-
# read(:size_t) not supported in FFI JRuby
|
384
|
-
def read_size_t(ptr)
|
385
|
-
if RUBY_PLATFORM == "java"
|
386
|
-
ptr.read_long
|
387
|
-
else
|
388
|
-
ptr.read(:size_t)
|
389
|
-
end
|
390
|
+
raise Error, "Unsupported #{name} type: #{type}"
|
390
391
|
end
|
391
392
|
|
392
393
|
def api
|
@@ -398,7 +399,7 @@ module OnnxRuntime
|
|
398
399
|
end
|
399
400
|
|
400
401
|
def self.api
|
401
|
-
@api ||= FFI.OrtGetApiBase[:GetApi].call(
|
402
|
+
@api ||= FFI.OrtGetApiBase[:GetApi].call(4)
|
402
403
|
end
|
403
404
|
|
404
405
|
def self.release(type, pointer)
|
@@ -410,6 +411,21 @@ module OnnxRuntime
|
|
410
411
|
proc { release :Session, session }
|
411
412
|
end
|
412
413
|
|
414
|
+
# wide string on Windows
|
415
|
+
# char string on Linux
|
416
|
+
# see ORTCHAR_T in onnxruntime_c_api.h
|
417
|
+
def ort_string(str)
|
418
|
+
if Gem.win_platform?
|
419
|
+
max = str.size + 1 # for null byte
|
420
|
+
dest = ::FFI::MemoryPointer.new(:wchar_t, max)
|
421
|
+
ret = FFI::Libc.mbstowcs(dest, str, max)
|
422
|
+
raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
|
423
|
+
dest
|
424
|
+
else
|
425
|
+
str
|
426
|
+
end
|
427
|
+
end
|
428
|
+
|
413
429
|
def env
|
414
430
|
# use mutex for thread-safety
|
415
431
|
Utils.mutex.synchronize do
|
data/lib/onnxruntime/version.rb
CHANGED
data/vendor/libonnxruntime.dylib
CHANGED
Binary file
|
data/vendor/libonnxruntime.so
CHANGED
Binary file
|
data/vendor/onnxruntime.dll
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: onnxruntime
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.3
|
4
|
+
version: 0.4.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-07-21 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|