onnxruntime 0.2.1 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +2 -1
- data/README.md +51 -2
- data/lib/onnxruntime.rb +2 -1
- data/lib/onnxruntime/datasets.rb +10 -0
- data/lib/onnxruntime/ffi.rb +46 -23
- data/lib/onnxruntime/inference_session.rb +111 -12
- data/lib/onnxruntime/model.rb +8 -4
- data/lib/onnxruntime/version.rb +1 -1
- data/vendor/libonnxruntime.dylib +0 -0
- data/vendor/libonnxruntime.so +0 -0
- data/vendor/onnxruntime.dll +0 -0
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0b017fa8896c64bbeeda0ba6582ca9bfa626de3f13023ec9f3f91000031e88b8
|
4
|
+
data.tar.gz: 16344db9151ca3f388539e05772a7172129613f949dbff77d89bfb9c307d0de2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 21011ae8d875230858ff148a27ff29caf005210aae6f6e6a347ed9a02dfb85bceac6714957718d3f17872f07f0ace21c3ba7d7ba8884c1dff4b79fd08b43e40b
|
7
|
+
data.tar.gz: 000f4575890f92c2b1de308e052977f159ead243aef3984cd4190a5a873ecbe2868fe586fdbe14f2963f0518a75d8c4600d47ba2fbdd6a77c84cf3da1fc84292
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.2 (2020-06-16)
|
2
|
+
|
3
|
+
- Fixed error with FFI 1.13.0+
|
4
|
+
- Added friendly graph optimization levels
|
5
|
+
|
6
|
+
## 0.3.1 (2020-05-18)
|
7
|
+
|
8
|
+
- Updated ONNX Runtime to 1.3.0
|
9
|
+
- Added `custom_metadata_map` to model metadata
|
10
|
+
|
11
|
+
## 0.3.0 (2020-03-11)
|
12
|
+
|
13
|
+
- Updated ONNX Runtime to 1.2.0
|
14
|
+
- Added model metadata
|
15
|
+
- Added `end_profiling` method
|
16
|
+
- Added support for loading from IO objects
|
17
|
+
- Improved `input` and `output` for `seq` and `map` types
|
18
|
+
|
19
|
+
## 0.2.3 (2020-01-23)
|
20
|
+
|
21
|
+
- Updated ONNX Runtime to 1.1.1
|
22
|
+
|
23
|
+
## 0.2.2 (2019-12-24)
|
24
|
+
|
25
|
+
- Added support for session options
|
26
|
+
- Added support for run options
|
27
|
+
- Added `Datasets` module
|
28
|
+
|
1
29
|
## 0.2.1 (2019-12-19)
|
2
30
|
|
3
31
|
- Updated ONNX Runtime to 1.1.0
|
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -20,7 +20,7 @@ Load a model and make predictions
|
|
20
20
|
|
21
21
|
```ruby
|
22
22
|
model = OnnxRuntime::Model.new("model.onnx")
|
23
|
-
model.predict(x: [1, 2, 3])
|
23
|
+
model.predict({x: [1, 2, 3]})
|
24
24
|
```
|
25
25
|
|
26
26
|
> Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
|
@@ -37,10 +37,16 @@ Get outputs
|
|
37
37
|
model.outputs
|
38
38
|
```
|
39
39
|
|
40
|
+
Get metadata
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
model.metadata
|
44
|
+
```
|
45
|
+
|
40
46
|
Load a model from a string
|
41
47
|
|
42
48
|
```ruby
|
43
|
-
byte_str =
|
49
|
+
byte_str = StringIO.new("...")
|
44
50
|
model = OnnxRuntime::Model.new(byte_str)
|
45
51
|
```
|
46
52
|
|
@@ -50,6 +56,35 @@ Get specific outputs
|
|
50
56
|
model.predict({x: [1, 2, 3]}, output_names: ["label"])
|
51
57
|
```
|
52
58
|
|
59
|
+
## Session Options
|
60
|
+
|
61
|
+
```ruby
|
62
|
+
OnnxRuntime::Model.new(path_or_bytes, {
|
63
|
+
enable_cpu_mem_arena: true,
|
64
|
+
enable_mem_pattern: true,
|
65
|
+
enable_profiling: false,
|
66
|
+
execution_mode: :sequential, # :sequential or :parallel
|
67
|
+
graph_optimization_level: nil, # :none, :basic, :extended, or :all
|
68
|
+
inter_op_num_threads: nil,
|
69
|
+
intra_op_num_threads: nil,
|
70
|
+
log_severity_level: 2,
|
71
|
+
log_verbosity_level: 0,
|
72
|
+
logid: nil,
|
73
|
+
optimized_model_filepath: nil
|
74
|
+
})
|
75
|
+
```
|
76
|
+
|
77
|
+
## Run Options
|
78
|
+
|
79
|
+
```ruby
|
80
|
+
model.predict(input_feed, {
|
81
|
+
log_severity_level: 2,
|
82
|
+
log_verbosity_level: 0,
|
83
|
+
logid: nil,
|
84
|
+
terminate: false
|
85
|
+
})
|
86
|
+
```
|
87
|
+
|
53
88
|
## Inference Session API
|
54
89
|
|
55
90
|
You can also use the Inference Session API, which follows the [Python API](https://microsoft.github.io/onnxruntime/python/api_summary.html).
|
@@ -59,6 +94,20 @@ session = OnnxRuntime::InferenceSession.new("model.onnx")
|
|
59
94
|
session.run(nil, {x: [1, 2, 3]})
|
60
95
|
```
|
61
96
|
|
97
|
+
The Python example models are included as well.
|
98
|
+
|
99
|
+
```ruby
|
100
|
+
OnnxRuntime::Datasets.example("sigmoid.onnx")
|
101
|
+
```
|
102
|
+
|
103
|
+
## GPU Support
|
104
|
+
|
105
|
+
To enable GPU support on Linux and Windows, download the appropriate [GPU release](https://github.com/microsoft/onnxruntime/releases) and set:
|
106
|
+
|
107
|
+
```ruby
|
108
|
+
OnnxRuntime.ffi_lib = "path/to/lib/libonnxruntime.so" # onnxruntime.dll for Windows
|
109
|
+
```
|
110
|
+
|
62
111
|
## History
|
63
112
|
|
64
113
|
View the [changelog](https://github.com/ankane/onnxruntime/blob/master/CHANGELOG.md)
|
data/lib/onnxruntime.rb
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
require "ffi"
|
3
3
|
|
4
4
|
# modules
|
5
|
+
require "onnxruntime/datasets"
|
5
6
|
require "onnxruntime/inference_session"
|
6
7
|
require "onnxruntime/model"
|
7
8
|
require "onnxruntime/utils"
|
@@ -18,7 +19,7 @@ module OnnxRuntime
|
|
18
19
|
self.ffi_lib = [vendor_lib]
|
19
20
|
|
20
21
|
def self.lib_version
|
21
|
-
FFI.OrtGetApiBase[:GetVersionString].call
|
22
|
+
FFI.OrtGetApiBase[:GetVersionString].call.read_string
|
22
23
|
end
|
23
24
|
|
24
25
|
# friendlier error message
|
@@ -0,0 +1,10 @@
|
|
1
|
+
module OnnxRuntime
|
2
|
+
module Datasets
|
3
|
+
def self.example(name)
|
4
|
+
unless %w(logreg_iris.onnx mul_1.onnx sigmoid.onnx).include?(name)
|
5
|
+
raise ArgumentError, "Unable to find example '#{name}'"
|
6
|
+
end
|
7
|
+
File.expand_path("../../datasets/#{name}", __dir__)
|
8
|
+
end
|
9
|
+
end
|
10
|
+
end
|
data/lib/onnxruntime/ffi.rb
CHANGED
@@ -3,7 +3,7 @@ module OnnxRuntime
|
|
3
3
|
extend ::FFI::Library
|
4
4
|
|
5
5
|
begin
|
6
|
-
ffi_lib OnnxRuntime.ffi_lib
|
6
|
+
ffi_lib Array(OnnxRuntime.ffi_lib)
|
7
7
|
rescue LoadError => e
|
8
8
|
raise e if ENV["ONNXRUNTIME_DEBUG"]
|
9
9
|
raise LoadError, "Could not find ONNX Runtime"
|
@@ -20,7 +20,7 @@ module OnnxRuntime
|
|
20
20
|
layout \
|
21
21
|
:CreateStatus, callback(%i[int string], :pointer),
|
22
22
|
:GetErrorCode, callback(%i[pointer], :pointer),
|
23
|
-
:GetErrorMessage, callback(%i[pointer], :string),
|
23
|
+
:GetErrorMessage, callback(%i[pointer], :pointer),
|
24
24
|
:CreateEnv, callback(%i[int string pointer], :pointer),
|
25
25
|
:CreateEnvWithCustomLogger, callback(%i[], :pointer),
|
26
26
|
:EnableTelemetryEvents, callback(%i[pointer], :pointer),
|
@@ -29,21 +29,21 @@ module OnnxRuntime
|
|
29
29
|
:CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
|
30
30
|
:Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
|
31
31
|
:CreateSessionOptions, callback(%i[pointer], :pointer),
|
32
|
-
:SetOptimizedModelFilePath, callback(%i[], :pointer),
|
32
|
+
:SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
|
33
33
|
:CloneSessionOptions, callback(%i[], :pointer),
|
34
34
|
:SetSessionExecutionMode, callback(%i[], :pointer),
|
35
|
-
:EnableProfiling, callback(%i[], :pointer),
|
36
|
-
:DisableProfiling, callback(%i[], :pointer),
|
37
|
-
:EnableMemPattern, callback(%i[], :pointer),
|
38
|
-
:DisableMemPattern, callback(%i[], :pointer),
|
39
|
-
:EnableCpuMemArena, callback(%i[], :pointer),
|
40
|
-
:DisableCpuMemArena, callback(%i[], :pointer),
|
41
|
-
:SetSessionLogId, callback(%i[], :pointer),
|
42
|
-
:SetSessionLogVerbosityLevel, callback(%i[], :pointer),
|
43
|
-
:SetSessionLogSeverityLevel, callback(%i[], :pointer),
|
44
|
-
:SetSessionGraphOptimizationLevel, callback(%i[], :pointer),
|
45
|
-
:SetIntraOpNumThreads, callback(%i[], :pointer),
|
46
|
-
:SetInterOpNumThreads, callback(%i[], :pointer),
|
35
|
+
:EnableProfiling, callback(%i[pointer string], :pointer),
|
36
|
+
:DisableProfiling, callback(%i[pointer], :pointer),
|
37
|
+
:EnableMemPattern, callback(%i[pointer], :pointer),
|
38
|
+
:DisableMemPattern, callback(%i[pointer], :pointer),
|
39
|
+
:EnableCpuMemArena, callback(%i[pointer], :pointer),
|
40
|
+
:DisableCpuMemArena, callback(%i[pointer], :pointer),
|
41
|
+
:SetSessionLogId, callback(%i[pointer string], :pointer),
|
42
|
+
:SetSessionLogVerbosityLevel, callback(%i[pointer int], :pointer),
|
43
|
+
:SetSessionLogSeverityLevel, callback(%i[pointer int], :pointer),
|
44
|
+
:SetSessionGraphOptimizationLevel, callback(%i[pointer int], :pointer),
|
45
|
+
:SetIntraOpNumThreads, callback(%i[pointer int], :pointer),
|
46
|
+
:SetInterOpNumThreads, callback(%i[pointer int], :pointer),
|
47
47
|
:CreateCustomOpDomain, callback(%i[], :pointer),
|
48
48
|
:CustomOpDomain_Add, callback(%i[], :pointer),
|
49
49
|
:AddCustomOpDomain, callback(%i[], :pointer),
|
@@ -57,15 +57,15 @@ module OnnxRuntime
|
|
57
57
|
:SessionGetInputName, callback(%i[pointer size_t pointer pointer], :pointer),
|
58
58
|
:SessionGetOutputName, callback(%i[pointer size_t pointer pointer], :pointer),
|
59
59
|
:SessionGetOverridableInitializerName, callback(%i[], :pointer),
|
60
|
-
:CreateRunOptions, callback(%i[], :pointer),
|
61
|
-
:RunOptionsSetRunLogVerbosityLevel, callback(%i[], :pointer),
|
62
|
-
:RunOptionsSetRunLogSeverityLevel, callback(%i[], :pointer),
|
63
|
-
:RunOptionsSetRunTag, callback(%i[], :pointer),
|
60
|
+
:CreateRunOptions, callback(%i[pointer], :pointer),
|
61
|
+
:RunOptionsSetRunLogVerbosityLevel, callback(%i[pointer int], :pointer),
|
62
|
+
:RunOptionsSetRunLogSeverityLevel, callback(%i[pointer int], :pointer),
|
63
|
+
:RunOptionsSetRunTag, callback(%i[pointer string], :pointer),
|
64
64
|
:RunOptionsGetRunLogVerbosityLevel, callback(%i[], :pointer),
|
65
65
|
:RunOptionsGetRunLogSeverityLevel, callback(%i[], :pointer),
|
66
66
|
:RunOptionsGetRunTag, callback(%i[], :pointer),
|
67
|
-
:RunOptionsSetTerminate, callback(%i[], :pointer),
|
68
|
-
:RunOptionsUnsetTerminate, callback(%i[], :pointer),
|
67
|
+
:RunOptionsSetTerminate, callback(%i[pointer], :pointer),
|
68
|
+
:RunOptionsUnsetTerminate, callback(%i[pointer], :pointer),
|
69
69
|
:CreateTensorAsOrtValue, callback(%i[pointer pointer size_t int pointer], :pointer),
|
70
70
|
:CreateTensorWithDataAsOrtValue, callback(%i[pointer pointer size_t pointer size_t int pointer], :pointer),
|
71
71
|
:IsTensor, callback(%i[], :pointer),
|
@@ -119,7 +119,30 @@ module OnnxRuntime
|
|
119
119
|
:ReleaseTypeInfo, callback(%i[pointer], :void),
|
120
120
|
:ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
|
121
121
|
:ReleaseSessionOptions, callback(%i[pointer], :void),
|
122
|
-
:ReleaseCustomOpDomain, callback(%i[pointer], :void)
|
122
|
+
:ReleaseCustomOpDomain, callback(%i[pointer], :void),
|
123
|
+
:GetDenotationFromTypeInfo, callback(%i[], :pointer),
|
124
|
+
:CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
|
125
|
+
:CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
|
126
|
+
:GetMapKeyType, callback(%i[pointer pointer], :pointer),
|
127
|
+
:GetMapValueType, callback(%i[pointer pointer], :pointer),
|
128
|
+
:GetSequenceElementType, callback(%i[pointer pointer], :pointer),
|
129
|
+
:ReleaseMapTypeInfo, callback(%i[pointer], :void),
|
130
|
+
:ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
|
131
|
+
:SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
|
132
|
+
:SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
|
133
|
+
:ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
|
134
|
+
:ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
|
135
|
+
:ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
|
136
|
+
:ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
|
137
|
+
:ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
|
138
|
+
:ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
|
139
|
+
:ReleaseModelMetadata, callback(%i[pointer], :void),
|
140
|
+
:CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
|
141
|
+
:DisablePerSessionThreads, callback(%i[], :pointer),
|
142
|
+
:CreateThreadingOptions, callback(%i[], :pointer),
|
143
|
+
:ReleaseThreadingOptions, callback(%i[], :pointer),
|
144
|
+
:ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
|
145
|
+
:AddFreeDimensionOverrideByName, callback(%i[], :pointer)
|
123
146
|
end
|
124
147
|
|
125
148
|
class ApiBase < ::FFI::Struct
|
@@ -127,7 +150,7 @@ module OnnxRuntime
|
|
127
150
|
# to prevent "unable to resolve type" error on Ubuntu
|
128
151
|
layout \
|
129
152
|
:GetApi, callback(%i[uint32], Api.by_ref),
|
130
|
-
:GetVersionString, callback(%i[], :string)
|
153
|
+
:GetVersionString, callback(%i[], :pointer)
|
131
154
|
end
|
132
155
|
|
133
156
|
attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
|
@@ -2,21 +2,50 @@ module OnnxRuntime
|
|
2
2
|
class InferenceSession
|
3
3
|
attr_reader :inputs, :outputs
|
4
4
|
|
5
|
-
def initialize(path_or_bytes)
|
5
|
+
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
|
6
6
|
# session options
|
7
7
|
session_options = ::FFI::MemoryPointer.new(:pointer)
|
8
8
|
check_status api[:CreateSessionOptions].call(session_options)
|
9
|
+
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
|
10
|
+
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
|
11
|
+
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
|
12
|
+
if execution_mode
|
13
|
+
execution_modes = {sequential: 0, parallel: 1}
|
14
|
+
mode = execution_modes[execution_mode]
|
15
|
+
raise ArgumentError, "Invalid execution mode" unless mode
|
16
|
+
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
|
17
|
+
end
|
18
|
+
if graph_optimization_level
|
19
|
+
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
|
20
|
+
# TODO raise error in 0.4.0
|
21
|
+
level = optimization_levels[graph_optimization_level] || graph_optimization_level
|
22
|
+
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
|
23
|
+
end
|
24
|
+
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
|
25
|
+
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
|
26
|
+
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
|
27
|
+
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
28
|
+
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
|
29
|
+
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
|
9
30
|
|
10
31
|
# session
|
11
32
|
@session = ::FFI::MemoryPointer.new(:pointer)
|
12
|
-
|
33
|
+
from_memory =
|
34
|
+
if path_or_bytes.respond_to?(:read)
|
35
|
+
path_or_bytes = path_or_bytes.read
|
36
|
+
true
|
37
|
+
else
|
38
|
+
path_or_bytes = path_or_bytes.to_str
|
39
|
+
path_or_bytes.encoding == Encoding::BINARY
|
40
|
+
end
|
13
41
|
|
14
42
|
# fix for Windows "File doesn't exist"
|
15
|
-
if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
|
43
|
+
if Gem.win_platform? && !from_memory
|
16
44
|
path_or_bytes = File.binread(path_or_bytes)
|
45
|
+
from_memory = true
|
17
46
|
end
|
18
47
|
|
19
|
-
if path_or_bytes.encoding == Encoding::BINARY
|
48
|
+
if from_memory
|
20
49
|
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
|
21
50
|
else
|
22
51
|
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
|
@@ -53,7 +82,8 @@ module OnnxRuntime
|
|
53
82
|
end
|
54
83
|
end
|
55
84
|
|
56
|
-
|
85
|
+
# TODO support logid
|
86
|
+
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
|
57
87
|
input_tensor = create_input_tensor(input_feed)
|
58
88
|
|
59
89
|
output_names ||= @outputs.map { |v| v[:name] }
|
@@ -61,14 +91,66 @@ module OnnxRuntime
|
|
61
91
|
output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
|
62
92
|
input_node_names = create_node_names(input_feed.keys.map(&:to_s))
|
63
93
|
output_node_names = create_node_names(output_names.map(&:to_s))
|
64
|
-
|
65
|
-
|
94
|
+
|
95
|
+
# run options
|
96
|
+
run_options = ::FFI::MemoryPointer.new(:pointer)
|
97
|
+
check_status api[:CreateRunOptions].call(run_options)
|
98
|
+
check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
|
99
|
+
check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
|
100
|
+
check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
|
101
|
+
check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate
|
102
|
+
|
103
|
+
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
|
66
104
|
|
67
105
|
output_names.size.times.map do |i|
|
68
106
|
create_from_onnx_value(output_tensor[i].read_pointer)
|
69
107
|
end
|
70
108
|
end
|
71
109
|
|
110
|
+
def modelmeta
|
111
|
+
keys = ::FFI::MemoryPointer.new(:pointer)
|
112
|
+
num_keys = ::FFI::MemoryPointer.new(:int64_t)
|
113
|
+
description = ::FFI::MemoryPointer.new(:string)
|
114
|
+
domain = ::FFI::MemoryPointer.new(:string)
|
115
|
+
graph_name = ::FFI::MemoryPointer.new(:string)
|
116
|
+
producer_name = ::FFI::MemoryPointer.new(:string)
|
117
|
+
version = ::FFI::MemoryPointer.new(:int64_t)
|
118
|
+
|
119
|
+
metadata = ::FFI::MemoryPointer.new(:pointer)
|
120
|
+
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
|
121
|
+
|
122
|
+
custom_metadata_map = {}
|
123
|
+
check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
|
124
|
+
num_keys.read(:int64_t).times do |i|
|
125
|
+
key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
|
126
|
+
value = ::FFI::MemoryPointer.new(:string)
|
127
|
+
check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
|
128
|
+
custom_metadata_map[key] = value.read_pointer.read_string
|
129
|
+
end
|
130
|
+
|
131
|
+
check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
|
132
|
+
check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
|
133
|
+
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
|
134
|
+
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
|
135
|
+
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
|
136
|
+
api[:ReleaseModelMetadata].call(metadata.read_pointer)
|
137
|
+
|
138
|
+
{
|
139
|
+
custom_metadata_map: custom_metadata_map,
|
140
|
+
description: description.read_pointer.read_string,
|
141
|
+
domain: domain.read_pointer.read_string,
|
142
|
+
graph_name: graph_name.read_pointer.read_string,
|
143
|
+
producer_name: producer_name.read_pointer.read_string,
|
144
|
+
version: version.read(:int64_t)
|
145
|
+
}
|
146
|
+
end
|
147
|
+
|
148
|
+
def end_profiling
|
149
|
+
out = ::FFI::MemoryPointer.new(:string)
|
150
|
+
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
|
151
|
+
out.read_pointer.read_string
|
152
|
+
end
|
153
|
+
|
72
154
|
private
|
73
155
|
|
74
156
|
def create_input_tensor(input_feed)
|
@@ -208,7 +290,7 @@ module OnnxRuntime
|
|
208
290
|
|
209
291
|
def check_status(status)
|
210
292
|
unless status.null?
|
211
|
-
message = api[:GetErrorMessage].call(status)
|
293
|
+
message = api[:GetErrorMessage].call(status).read_string
|
212
294
|
api[:ReleaseStatus].call(status)
|
213
295
|
raise OnnxRuntime::Error, message
|
214
296
|
end
|
@@ -230,15 +312,32 @@ module OnnxRuntime
|
|
230
312
|
shape: shape
|
231
313
|
}
|
232
314
|
when :sequence
|
233
|
-
|
315
|
+
sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
|
316
|
+
check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
|
317
|
+
nested_type_info = ::FFI::MemoryPointer.new(:pointer)
|
318
|
+
check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
|
319
|
+
v = node_info(nested_type_info)[:type]
|
320
|
+
|
234
321
|
{
|
235
|
-
type: "seq",
|
322
|
+
type: "seq(#{v})",
|
236
323
|
shape: []
|
237
324
|
}
|
238
325
|
when :map
|
239
|
-
|
326
|
+
map_type_info = ::FFI::MemoryPointer.new(:pointer)
|
327
|
+
check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
|
328
|
+
|
329
|
+
# key
|
330
|
+
key_type = ::FFI::MemoryPointer.new(:int)
|
331
|
+
check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
|
332
|
+
k = FFI::TensorElementDataType[key_type.read_int]
|
333
|
+
|
334
|
+
# value
|
335
|
+
value_type_info = ::FFI::MemoryPointer.new(:pointer)
|
336
|
+
check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
|
337
|
+
v = node_info(value_type_info)[:type]
|
338
|
+
|
240
339
|
{
|
241
|
-
type: "map",
|
340
|
+
type: "map(#{k},#{v})",
|
242
341
|
shape: []
|
243
342
|
}
|
244
343
|
else
|
data/lib/onnxruntime/model.rb
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
module OnnxRuntime
|
2
2
|
class Model
|
3
|
-
def initialize(path_or_bytes)
|
4
|
-
@session = InferenceSession.new(path_or_bytes)
|
3
|
+
def initialize(path_or_bytes, **session_options)
|
4
|
+
@session = InferenceSession.new(path_or_bytes, **session_options)
|
5
5
|
end
|
6
6
|
|
7
|
-
def predict(input_feed, output_names: nil)
|
8
|
-
predictions = @session.run(output_names, input_feed)
|
7
|
+
def predict(input_feed, output_names: nil, **run_options)
|
8
|
+
predictions = @session.run(output_names, input_feed, **run_options)
|
9
9
|
output_names ||= outputs.map { |o| o[:name] }
|
10
10
|
|
11
11
|
result = {}
|
@@ -22,5 +22,9 @@ module OnnxRuntime
|
|
22
22
|
def outputs
|
23
23
|
@session.outputs
|
24
24
|
end
|
25
|
+
|
26
|
+
def metadata
|
27
|
+
@session.modelmeta
|
28
|
+
end
|
25
29
|
end
|
26
30
|
end
|
data/lib/onnxruntime/version.rb
CHANGED
data/vendor/libonnxruntime.dylib
CHANGED
Binary file
|
data/vendor/libonnxruntime.so
CHANGED
Binary file
|
data/vendor/onnxruntime.dll
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: onnxruntime
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.1
|
4
|
+
version: 0.3.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2020-06-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|
@@ -76,6 +76,7 @@ files:
|
|
76
76
|
- LICENSE.txt
|
77
77
|
- README.md
|
78
78
|
- lib/onnxruntime.rb
|
79
|
+
- lib/onnxruntime/datasets.rb
|
79
80
|
- lib/onnxruntime/ffi.rb
|
80
81
|
- lib/onnxruntime/inference_session.rb
|
81
82
|
- lib/onnxruntime/model.rb
|
@@ -105,7 +106,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
105
106
|
- !ruby/object:Gem::Version
|
106
107
|
version: '0'
|
107
108
|
requirements: []
|
108
|
-
rubygems_version: 3.
|
109
|
+
rubygems_version: 3.1.2
|
109
110
|
signing_key:
|
110
111
|
specification_version: 4
|
111
112
|
summary: High performance scoring engine for ML models
|