onnxruntime 0.2.1 → 0.3.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +2 -1
- data/README.md +51 -2
- data/lib/onnxruntime.rb +2 -1
- data/lib/onnxruntime/datasets.rb +10 -0
- data/lib/onnxruntime/ffi.rb +46 -23
- data/lib/onnxruntime/inference_session.rb +111 -12
- data/lib/onnxruntime/model.rb +8 -4
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0b017fa8896c64bbeeda0ba6582ca9bfa626de3f13023ec9f3f91000031e88b8
|
4
|
+
data.tar.gz: 16344db9151ca3f388539e05772a7172129613f949dbff77d89bfb9c307d0de2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 21011ae8d875230858ff148a27ff29caf005210aae6f6e6a347ed9a02dfb85bceac6714957718d3f17872f07f0ace21c3ba7d7ba8884c1dff4b79fd08b43e40b
|
7
|
+
data.tar.gz: 000f4575890f92c2b1de308e052977f159ead243aef3984cd4190a5a873ecbe2868fe586fdbe14f2963f0518a75d8c4600d47ba2fbdd6a77c84cf3da1fc84292
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.2 (2020-06-16)
|
2
|
+
|
3
|
+
- Fixed error with FFI 1.13.0+
|
4
|
+
- Added friendly graph optimization levels
|
5
|
+
|
6
|
+
## 0.3.1 (2020-05-18)
|
7
|
+
|
8
|
+
- Updated ONNX Runtime to 1.3.0
|
9
|
+
- Added `custom_metadata_map` to model metadata
|
10
|
+
|
11
|
+
## 0.3.0 (2020-03-11)
|
12
|
+
|
13
|
+
- Updated ONNX Runtime to 1.2.0
|
14
|
+
- Added model metadata
|
15
|
+
- Added `end_profiling` method
|
16
|
+
- Added support for loading from IO objects
|
17
|
+
- Improved `input` and `output` for `seq` and `map` types
|
18
|
+
|
19
|
+
## 0.2.3 (2020-01-23)
|
20
|
+
|
21
|
+
- Updated ONNX Runtime to 1.1.1
|
22
|
+
|
23
|
+
## 0.2.2 (2019-12-24)
|
24
|
+
|
25
|
+
- Added support for session options
|
26
|
+
- Added support for run options
|
27
|
+
- Added `Datasets` module
|
28
|
+
|
1
29
|
## 0.2.1 (2019-12-19)
|
2
30
|
|
3
31
|
- Updated ONNX Runtime to 1.1.0
|
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -20,7 +20,7 @@ Load a model and make predictions
|
|
20
20
|
|
21
21
|
```ruby
|
22
22
|
model = OnnxRuntime::Model.new("model.onnx")
|
23
|
-
model.predict(x: [1, 2, 3])
|
23
|
+
model.predict({x: [1, 2, 3]})
|
24
24
|
```
|
25
25
|
|
26
26
|
> Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
|
@@ -37,10 +37,16 @@ Get outputs
|
|
37
37
|
model.outputs
|
38
38
|
```
|
39
39
|
|
40
|
+
Get metadata
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
model.metadata
|
44
|
+
```
|
45
|
+
|
40
46
|
Load a model from a string
|
41
47
|
|
42
48
|
```ruby
|
43
|
-
byte_str =
|
49
|
+
byte_str = StringIO.new("...")
|
44
50
|
model = OnnxRuntime::Model.new(byte_str)
|
45
51
|
```
|
46
52
|
|
@@ -50,6 +56,35 @@ Get specific outputs
|
|
50
56
|
model.predict({x: [1, 2, 3]}, output_names: ["label"])
|
51
57
|
```
|
52
58
|
|
59
|
+
## Session Options
|
60
|
+
|
61
|
+
```ruby
|
62
|
+
OnnxRuntime::Model.new(path_or_bytes, {
|
63
|
+
enable_cpu_mem_arena: true,
|
64
|
+
enable_mem_pattern: true,
|
65
|
+
enable_profiling: false,
|
66
|
+
execution_mode: :sequential, # :sequential or :parallel
|
67
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
68
|
+
inter_op_num_threads: nil,
|
69
|
+
intra_op_num_threads: nil,
|
70
|
+
log_severity_level: 2,
|
71
|
+
log_verbosity_level: 0,
|
72
|
+
logid: nil,
|
73
|
+
optimized_model_filepath: nil
|
74
|
+
})
|
75
|
+
```
|
76
|
+
|
77
|
+
## Run Options
|
78
|
+
|
79
|
+
```ruby
|
80
|
+
model.predict(input_feed, {
|
81
|
+
log_severity_level: 2,
|
82
|
+
log_verbosity_level: 0,
|
83
|
+
logid: nil,
|
84
|
+
terminate: false
|
85
|
+
})
|
86
|
+
```
|
87
|
+
|
53
88
|
## Inference Session API
|
54
89
|
|
55
90
|
You can also use the Inference Session API, which follows the [Python API](https://microsoft.github.io/onnxruntime/python/api_summary.html).
|
@@ -59,6 +94,20 @@ session = OnnxRuntime::InferenceSession.new("model.onnx")
|
|
59
94
|
session.run(nil, {x: [1, 2, 3]})
|
60
95
|
```
|
61
96
|
|
97
|
+
The Python example models are included as well.
|
98
|
+
|
99
|
+
```ruby
|
100
|
+
OnnxRuntime::Datasets.example("sigmoid.onnx")
|
101
|
+
```
|
102
|
+
|
103
|
+
## GPU Support
|
104
|
+
|
105
|
+
To enable GPU support on Linux and Windows, download the appropriate [GPU release](https://github.com/microsoft/onnxruntime/releases) and set:
|
106
|
+
|
107
|
+
```ruby
|
108
|
+
OnnxRuntime.ffi_lib = "path/to/lib/libonnxruntime.so" # onnxruntime.dll for Windows
|
109
|
+
```
|
110
|
+
|
62
111
|
## History
|
63
112
|
|
64
113
|
View the [changelog](https://github.com/ankane/onnxruntime/blob/master/CHANGELOG.md)
|
data/lib/onnxruntime.rb
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
require "ffi"
|
3
3
|
|
4
4
|
# modules
|
5
|
+
require "onnxruntime/datasets"
|
5
6
|
require "onnxruntime/inference_session"
|
6
7
|
require "onnxruntime/model"
|
7
8
|
require "onnxruntime/utils"
|
@@ -18,7 +19,7 @@ module OnnxRuntime
|
|
18
19
|
self.ffi_lib = [vendor_lib]
|
19
20
|
|
20
21
|
def self.lib_version
|
21
|
-
FFI.OrtGetApiBase[:GetVersionString].call
|
22
|
+
FFI.OrtGetApiBase[:GetVersionString].call.read_string
|
22
23
|
end
|
23
24
|
|
24
25
|
# friendlier error message
|
@@ -0,0 +1,10 @@
|
|
1
|
+
module OnnxRuntime
  module Datasets
    # Resolves the on-disk path of one of the example models that ship with the gem.
    #
    # @param name [String] file name of the example model (e.g. "sigmoid.onnx")
    # @return [String] absolute path under the gem's top-level `datasets` directory
    # @raise [ArgumentError] when +name+ is not one of the bundled examples
    def self.example(name)
      bundled = %w(logreg_iris.onnx mul_1.onnx sigmoid.onnx)
      raise ArgumentError, "Unable to find example '#{name}'" unless bundled.include?(name)

      # two levels up from this file's directory reaches the gem root
      File.expand_path("../../datasets/#{name}", __dir__)
    end
  end
end
|
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,7 +3,7 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib OnnxRuntime.ffi_lib
|
6
|
+
ffi_lib Array(OnnxRuntime.ffi_lib)
|
7
7
|
rescue LoadError => e
|
8
8
|
raise e if ENV["ONNXRUNTIME_DEBUG"]
|
9
9
|
raise LoadError, "Could not find ONNX Runtime"
|
@@ -20,7 +20,7 @@ module OnnxRuntime
|
|
20
20
|
layout \
|
21
21
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
22
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :
|
23
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
24
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
25
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
26
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
@@ -29,21 +29,21 @@ module OnnxRuntime
|
|
29
29
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
30
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
31
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[], :pointer),
|
32
|
+
:SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
|
33
33
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
34
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[], :pointer),
|
36
|
-
:DisableProfiling, callback(%i[], :pointer),
|
37
|
-
:EnableMemPattern, callback(%i[], :pointer),
|
38
|
-
:DisableMemPattern, callback(%i[], :pointer),
|
39
|
-
:EnableCpuMemArena, callback(%i[], :pointer),
|
40
|
-
:DisableCpuMemArena, callback(%i[], :pointer),
|
41
|
-
:SetSessionLogId, callback(%i[], :pointer),
|
42
|
-
:SetSessionLogVerbosityLevel, callback(%i[], :pointer),
|
43
|
-
:SetSessionLogSeverityLevel, callback(%i[], :pointer),
|
44
|
-
:SetSessionGraphOptimizationLevel, callback(%i[], :pointer),
|
45
|
-
:SetIntraOpNumThreads, callback(%i[], :pointer),
|
46
|
-
:SetInterOpNumThreads, callback(%i[], :pointer),
|
35
|
+
:EnableProfiling, callback(%i[pointer string], :pointer),
|
36
|
+
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
|
+
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
|
+
:DisableMemPattern, callback(%i[pointer], :pointer),
|
39
|
+
:EnableCpuMemArena, callback(%i[pointer], :pointer),
|
40
|
+
:DisableCpuMemArena, callback(%i[pointer], :pointer),
|
41
|
+
:SetSessionLogId, callback(%i[pointer string], :pointer),
|
42
|
+
:SetSessionLogVerbosityLevel, callback(%i[pointer int], :pointer),
|
43
|
+
:SetSessionLogSeverityLevel, callback(%i[pointer int], :pointer),
|
44
|
+
:SetSessionGraphOptimizationLevel, callback(%i[pointer int], :pointer),
|
45
|
+
:SetIntraOpNumThreads, callback(%i[pointer int], :pointer),
|
46
|
+
:SetInterOpNumThreads, callback(%i[pointer int], :pointer),
|
47
47
|
:CreateCustomOpDomain, callback(%i[], :pointer),
|
48
48
|
:CustomOpDomain_Add, callback(%i[], :pointer),
|
49
49
|
:AddCustomOpDomain, callback(%i[], :pointer),
|
@@ -57,15 +57,15 @@ module OnnxRuntime
|
|
57
57
|
:SessionGetInputName, callback(%i[pointer size_t pointer pointer], :pointer),
|
58
58
|
:SessionGetOutputName, callback(%i[pointer size_t pointer pointer], :pointer),
|
59
59
|
:SessionGetOverridableInitializerName, callback(%i[], :pointer),
|
60
|
-
:CreateRunOptions, callback(%i[], :pointer),
|
61
|
-
:RunOptionsSetRunLogVerbosityLevel, callback(%i[], :pointer),
|
62
|
-
:RunOptionsSetRunLogSeverityLevel, callback(%i[], :pointer),
|
63
|
-
:RunOptionsSetRunTag, callback(%i[], :pointer),
|
60
|
+
:CreateRunOptions, callback(%i[pointer], :pointer),
|
61
|
+
:RunOptionsSetRunLogVerbosityLevel, callback(%i[pointer int], :pointer),
|
62
|
+
:RunOptionsSetRunLogSeverityLevel, callback(%i[pointer int], :pointer),
|
63
|
+
:RunOptionsSetRunTag, callback(%i[pointer string], :pointer),
|
64
64
|
:RunOptionsGetRunLogVerbosityLevel, callback(%i[], :pointer),
|
65
65
|
:RunOptionsGetRunLogSeverityLevel, callback(%i[], :pointer),
|
66
66
|
:RunOptionsGetRunTag, callback(%i[], :pointer),
|
67
|
-
:RunOptionsSetTerminate, callback(%i[], :pointer),
|
68
|
-
:RunOptionsUnsetTerminate, callback(%i[], :pointer),
|
67
|
+
:RunOptionsSetTerminate, callback(%i[pointer], :pointer),
|
68
|
+
:RunOptionsUnsetTerminate, callback(%i[pointer], :pointer),
|
69
69
|
:CreateTensorAsOrtValue, callback(%i[pointer pointer size_t int pointer], :pointer),
|
70
70
|
:CreateTensorWithDataAsOrtValue, callback(%i[pointer pointer size_t pointer size_t int pointer], :pointer),
|
71
71
|
:IsTensor, callback(%i[], :pointer),
|
@@ -119,7 +119,30 @@ module OnnxRuntime
|
|
119
119
|
:ReleaseTypeInfo, callback(%i[pointer], :void),
|
120
120
|
:ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
|
121
121
|
:ReleaseSessionOptions, callback(%i[pointer], :void),
|
122
|
-
:ReleaseCustomOpDomain, callback(%i[pointer], :void)
|
122
|
+
:ReleaseCustomOpDomain, callback(%i[pointer], :void),
|
123
|
+
:GetDenotationFromTypeInfo, callback(%i[], :pointer),
|
124
|
+
:CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
|
125
|
+
:CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
|
126
|
+
:GetMapKeyType, callback(%i[pointer pointer], :pointer),
|
127
|
+
:GetMapValueType, callback(%i[pointer pointer], :pointer),
|
128
|
+
:GetSequenceElementType, callback(%i[pointer pointer], :pointer),
|
129
|
+
:ReleaseMapTypeInfo, callback(%i[pointer], :void),
|
130
|
+
:ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
|
131
|
+
:SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
|
132
|
+
:SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
|
133
|
+
:ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
|
134
|
+
:ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
|
135
|
+
:ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
|
136
|
+
:ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
|
137
|
+
:ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
|
138
|
+
:ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
|
139
|
+
:ReleaseModelMetadata, callback(%i[pointer], :void),
|
140
|
+
:CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
|
141
|
+
:DisablePerSessionThreads, callback(%i[], :pointer),
|
142
|
+
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
|
+
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
|
+
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
123
146
|
end
|
124
147
|
|
125
148
|
class ApiBase < ::FFI::Struct
|
@@ -127,7 +150,7 @@ module OnnxRuntime
|
|
127
150
|
# to prevent "unable to resolve type" error on Ubuntu
|
128
151
|
layout \
|
129
152
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
130
|
-
:GetVersionString, callback(%i[], :
|
153
|
+
:GetVersionString, callback(%i[], :pointer)
|
131
154
|
end
|
132
155
|
|
133
156
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
@@ -2,21 +2,50 @@ module OnnxRuntime
|
|
2
2
|
class InferenceSession
|
3
3
|
attr_reader :inputs, :outputs
|
4
4
|
|
5
|
-
def initialize(path_or_bytes)
|
5
|
+
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
|
6
6
|
# session options
|
7
7
|
session_options = ::FFI::MemoryPointer.new(:pointer)
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
|
+
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
|
+
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
12
|
+
if execution_mode
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
16
|
+
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
17
|
+
end
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
# TODO raise error in 0.4.0
|
21
|
+
level = optimization_levels[graph_optimization_level] || graph_optimization_level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
24
|
+
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
25
|
+
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
26
|
+
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
27
|
+
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
28
|
+
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
9
30
|
|
10
31
|
# session
|
11
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
12
|
-
|
33
|
+
from_memory =
|
34
|
+
if path_or_bytes.respond_to?(:read)
|
35
|
+
path_or_bytes = path_or_bytes.read
|
36
|
+
true
|
37
|
+
else
|
38
|
+
path_or_bytes = path_or_bytes.to_str
|
39
|
+
path_or_bytes.encoding == Encoding::BINARY
|
40
|
+
end
|
13
41
|
|
14
42
|
# fix for Windows "File doesn't exist"
|
15
|
-
if Gem.win_platform? &&
|
43
|
+
if Gem.win_platform? && !from_memory
|
16
44
|
path_or_bytes = File.binread(path_or_bytes)
|
45
|
+
from_memory = true
|
17
46
|
end
|
18
47
|
|
19
|
-
if
|
48
|
+
if from_memory
|
20
49
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
21
50
|
else
|
22
51
|
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
@@ -53,7 +82,8 @@ module OnnxRuntime
|
|
53
82
|
end
|
54
83
|
end
|
55
84
|
|
56
|
-
|
85
|
+
# TODO support logid
|
86
|
+
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
|
57
87
|
input_tensor = create_input_tensor(input_feed)
|
58
88
|
|
59
89
|
output_names ||= @outputs.map { |v| v[:name] }
|
@@ -61,14 +91,66 @@ module OnnxRuntime
|
|
61
91
|
output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
|
62
92
|
input_node_names = create_node_names(input_feed.keys.map(&:to_s))
|
63
93
|
output_node_names = create_node_names(output_names.map(&:to_s))
|
64
|
-
|
65
|
-
|
94
|
+
|
95
|
+
# run options
|
96
|
+
run_options = ::FFI::MemoryPointer.new(:pointer)
|
97
|
+
check_status api[:CreateRunOptions].call(run_options)
|
98
|
+
check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
|
99
|
+
check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
100
|
+
check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
|
101
|
+
check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate
|
102
|
+
|
103
|
+
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
|
66
104
|
|
67
105
|
output_names.size.times.map do |i|
|
68
106
|
create_from_onnx_value(output_tensor[i].read_pointer)
|
69
107
|
end
|
70
108
|
end
|
71
109
|
|
110
|
+
# Reads the model metadata attached to this session.
#
# Returns a Hash with:
#   :custom_metadata_map - Hash of String => String custom key/value pairs
#   :description, :domain, :graph_name, :producer_name - Strings
#   :version - Integer (read as int64)
# Raises OnnxRuntime::Error (via check_status) when any ONNX Runtime call fails.
def modelmeta
  keys = ::FFI::MemoryPointer.new(:pointer)
  num_keys = ::FFI::MemoryPointer.new(:int64_t)
  description = ::FFI::MemoryPointer.new(:string)
  domain = ::FFI::MemoryPointer.new(:string)
  graph_name = ::FFI::MemoryPointer.new(:string)
  producer_name = ::FFI::MemoryPointer.new(:string)
  version = ::FFI::MemoryPointer.new(:int64_t)

  metadata = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)

  custom_metadata_map = {}
  begin
    # the status must be passed to check_status (not assigned to a local),
    # otherwise a failure here is silently ignored
    check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
    num_keys.read(:int64_t).times do |i|
      # keys points to an array of C string pointers
      key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
      value = ::FFI::MemoryPointer.new(:string)
      check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
      custom_metadata_map[key] = value.read_pointer.read_string
    end

    check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
    check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
    check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
    check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
    check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
  ensure
    # release the metadata handle even if one of the calls above raised
    api[:ReleaseModelMetadata].call(metadata.read_pointer)
  end

  {
    custom_metadata_map: custom_metadata_map,
    description: description.read_pointer.read_string,
    domain: domain.read_pointer.read_string,
    graph_name: graph_name.read_pointer.read_string,
    producer_name: producer_name.read_pointer.read_string,
    version: version.read(:int64_t)
  }
end
|
147
|
+
|
148
|
+
# Stops profiling for this session and returns the profile file name
# reported by ONNX Runtime.
# Raises OnnxRuntime::Error (via check_status) if the call fails.
def end_profiling
  profile_file = ::FFI::MemoryPointer.new(:string)
  status = api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, profile_file)
  check_status(status)
  profile_file.read_pointer.read_string
end
|
153
|
+
|
72
154
|
private
|
73
155
|
|
74
156
|
def create_input_tensor(input_feed)
|
@@ -208,7 +290,7 @@ module OnnxRuntime
|
|
208
290
|
|
209
291
|
def check_status(status)
|
210
292
|
unless status.null?
|
211
|
-
message = api[:GetErrorMessage].call(status)
|
293
|
+
message = api[:GetErrorMessage].call(status).read_string
|
212
294
|
api[:ReleaseStatus].call(status)
|
213
295
|
raise OnnxRuntime::Error, message
|
214
296
|
end
|
@@ -230,15 +312,32 @@ module OnnxRuntime
|
|
230
312
|
shape: shape
|
231
313
|
}
|
232
314
|
when :sequence
|
233
|
-
|
315
|
+
sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
|
316
|
+
check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
|
317
|
+
nested_type_info = ::FFI::MemoryPointer.new(:pointer)
|
318
|
+
check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
|
319
|
+
v = node_info(nested_type_info)[:type]
|
320
|
+
|
234
321
|
{
|
235
|
-
type: "seq",
|
322
|
+
type: "seq(#{v})",
|
236
323
|
shape: []
|
237
324
|
}
|
238
325
|
when :map
|
239
|
-
|
326
|
+
map_type_info = ::FFI::MemoryPointer.new(:pointer)
|
327
|
+
check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
|
328
|
+
|
329
|
+
# key
|
330
|
+
key_type = ::FFI::MemoryPointer.new(:int)
|
331
|
+
check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
|
332
|
+
k = FFI::TensorElementDataType[key_type.read_int]
|
333
|
+
|
334
|
+
# value
|
335
|
+
value_type_info = ::FFI::MemoryPointer.new(:pointer)
|
336
|
+
check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
|
337
|
+
v = node_info(value_type_info)[:type]
|
338
|
+
|
240
339
|
{
|
241
|
-
type: "map",
|
340
|
+
type: "map(#{k},#{v})",
|
242
341
|
shape: []
|
243
342
|
}
|
244
343
|
else
|
data/lib/onnxruntime/model.rb
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
module OnnxRuntime
|
2
2
|
class Model
|
3
|
-
def initialize(path_or_bytes)
|
4
|
-
@session = InferenceSession.new(path_or_bytes)
|
3
|
+
def initialize(path_or_bytes, **session_options)
|
4
|
+
@session = InferenceSession.new(path_or_bytes, **session_options)
|
5
5
|
end
|
6
6
|
|
7
|
-
def predict(input_feed, output_names: nil)
|
8
|
-
predictions = @session.run(output_names, input_feed)
|
7
|
+
def predict(input_feed, output_names: nil, **run_options)
|
8
|
+
predictions = @session.run(output_names, input_feed, **run_options)
|
9
9
|
output_names ||= outputs.map { |o| o[:name] }
|
10
10
|
|
11
11
|
result = {}
|
@@ -22,5 +22,9 @@ module OnnxRuntime
|
|
22
22
|
def outputs
|
23
23
|
@session.outputs
|
24
24
|
end
|
25
|
+
|
26
|
+
# Model metadata as reported by the underlying InferenceSession
# (delegates to InferenceSession#modelmeta).
def metadata
  session = @session
  session.modelmeta
end
|
25
29
|
end
|
26
30
|
end
|
data/lib/onnxruntime/version.rb
CHANGED
data/vendor/libonnxruntime.dylib
CHANGED
Binary file
|
data/vendor/libonnxruntime.so
CHANGED
Binary file
|
data/vendor/onnxruntime.dll
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: onnxruntime
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.1
|
4
|
+
version: 0.3.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2020-06-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|
@@ -76,6 +76,7 @@ files:
|
|
76
76
|
- LICENSE.txt
|
77
77
|
- README.md
|
78
78
|
- lib/onnxruntime.rb
|
79
|
+
- lib/onnxruntime/datasets.rb
|
79
80
|
- lib/onnxruntime/ffi.rb
|
80
81
|
- lib/onnxruntime/inference_session.rb
|
81
82
|
- lib/onnxruntime/model.rb
|
@@ -105,7 +106,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
105
106
|
- !ruby/object:Gem::Version
|
106
107
|
version: '0'
|
107
108
|
requirements: []
|
108
|
-
rubygems_version: 3.
|
109
|
+
rubygems_version: 3.1.2
|
109
110
|
signing_key:
|
110
111
|
specification_version: 4
|
111
112
|
summary: High performance scoring engine for ML models
|