onnxruntime 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +95 -0
- data/lib/onnxruntime/ffi.rb +71 -0
- data/lib/onnxruntime/inference_session.rb +259 -0
- data/lib/onnxruntime/model.rb +26 -0
- data/lib/onnxruntime/utils.rb +10 -0
- data/lib/onnxruntime/version.rb +3 -0
- data/lib/onnxruntime.rb +24 -0
- metadata +107 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: c7a7fd8412c39749689e27196bbb472e006a7686225dacf5f6946dbaa9af0496
|
4
|
+
data.tar.gz: '019cec927488c9ab07a76d19913c99f70cab1df81a0bd62bbb3233d383b1c427'
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 730bf9550ddbd2e8a6cc1d0b58d02464ae59f18c5f2eaf2709f34d701838ff44ab42b478c3c014e97fc2f578a0ec937d204b91ab6ec596f0a1324c119d64863e
|
7
|
+
data.tar.gz: 16972280e527e11588f478be8e2ddb850b3a090432c39868f969d9cd9627890b4de721fa66a58a9b226430708aa20316da62456176bb24b1e4fe984b245658d3
|
data/CHANGELOG.md
ADDED
data/LICENSE.txt
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
Copyright (c) 2019 Andrew Kane
|
2
|
+
|
3
|
+
MIT License
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining
|
6
|
+
a copy of this software and associated documentation files (the
|
7
|
+
"Software"), to deal in the Software without restriction, including
|
8
|
+
without limitation the rights to use, copy, modify, merge, publish,
|
9
|
+
distribute, sublicense, and/or sell copies of the Software, and to
|
10
|
+
permit persons to whom the Software is furnished to do so, subject to
|
11
|
+
the following conditions:
|
12
|
+
|
13
|
+
The above copyright notice and this permission notice shall be
|
14
|
+
included in all copies or substantial portions of the Software.
|
15
|
+
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
17
|
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
18
|
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
19
|
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
20
|
+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
21
|
+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
22
|
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
data/README.md
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
# ONNX Runtime
|
2
|
+
|
3
|
+
:fire: [ONNX Runtime](https://github.com/Microsoft/onnxruntime) - the high performance scoring engine for ML models - for Ruby
|
4
|
+
|
5
|
+
## Installation
|
6
|
+
|
7
|
+
First, [install ONNX Runtime](#onnx-runtime-installation).
|
8
|
+
|
9
|
+
Add this line to your application’s Gemfile:
|
10
|
+
|
11
|
+
```ruby
|
12
|
+
gem 'onnxruntime'
|
13
|
+
```
|
14
|
+
|
15
|
+
## Getting Started
|
16
|
+
|
17
|
+
Load a model and make predictions
|
18
|
+
|
19
|
+
```ruby
|
20
|
+
model = OnnxRuntime::Model.new("model.onnx")
|
21
|
+
model.predict(x: [1, 2, 3])
|
22
|
+
```
|
23
|
+
|
24
|
+
Get inputs
|
25
|
+
|
26
|
+
```ruby
|
27
|
+
model.inputs
|
28
|
+
```
|
29
|
+
|
30
|
+
Get outputs
|
31
|
+
|
32
|
+
```ruby
|
33
|
+
model.outputs
|
34
|
+
```
|
35
|
+
|
36
|
+
Load a model from a string
|
37
|
+
|
38
|
+
```ruby
|
39
|
+
byte_str = File.binread("model.onnx")
|
40
|
+
model = OnnxRuntime::Model.new(byte_str)
|
41
|
+
```
|
42
|
+
|
43
|
+
Get specific outputs
|
44
|
+
|
45
|
+
```ruby
|
46
|
+
model.predict({x: [1, 2, 3]}, output_names: ["label"])
|
47
|
+
```
|
48
|
+
|
49
|
+
## Inference Session API
|
50
|
+
|
51
|
+
You can also use the Inference Session API, which follows the [Python API](https://microsoft.github.io/onnxruntime/api_summary.html).
|
52
|
+
|
53
|
+
```ruby
|
54
|
+
session = OnnxRuntime::InferenceSession.new("model.onnx")
|
55
|
+
session.run(nil, {x: [1, 2, 3]})
|
56
|
+
```
|
57
|
+
|
58
|
+
## ONNX Runtime Installation
|
59
|
+
|
60
|
+
ONNX Runtime provides [prebuilt libraries](https://github.com/microsoft/onnxruntime/releases).
|
61
|
+
|
62
|
+
### Mac
|
63
|
+
|
64
|
+
```sh
|
65
|
+
wget https://github.com/microsoft/onnxruntime/releases/download/v0.5.0/onnxruntime-osx-x64-0.5.0.tgz
|
66
|
+
tar xf onnxruntime-osx-x64-0.5.0.tgz
|
67
|
+
cd onnxruntime-osx-x64-0.5.0
|
68
|
+
cp lib/libonnxruntime.0.5.0.dylib /usr/local/lib/libonnxruntime.dylib
|
69
|
+
```
|
70
|
+
|
71
|
+
### Linux
|
72
|
+
|
73
|
+
```sh
|
74
|
+
wget https://github.com/microsoft/onnxruntime/releases/download/v0.5.0/onnxruntime-linux-x64-0.5.0.tgz
|
75
|
+
tar xf onnxruntime-linux-x64-0.5.0.tgz
|
76
|
+
cd onnxruntime-linux-x64-0.5.0
|
77
|
+
cp lib/libonnxruntime.so.0.5.0 /usr/local/lib/libonnxruntime.so
|
78
|
+
```
|
79
|
+
|
80
|
+
### Windows
|
81
|
+
|
82
|
+
Download [ONNX Runtime](https://github.com/microsoft/onnxruntime/releases/download/v0.5.0/onnxruntime-win-x64-0.5.0.zip). Unzip and move `lib/onnxruntime.dll` to `C:\Windows\System32\onnxruntime.dll`.
|
83
|
+
|
84
|
+
## History
|
85
|
+
|
86
|
+
View the [changelog](https://github.com/ankane/onnxruntime/blob/master/CHANGELOG.md)
|
87
|
+
|
88
|
+
## Contributing
|
89
|
+
|
90
|
+
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
91
|
+
|
92
|
+
- [Report bugs](https://github.com/ankane/onnxruntime/issues)
|
93
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/onnxruntime/pulls)
|
94
|
+
- Write, clarify, or fix documentation
|
95
|
+
- Suggest or add new features
|
@@ -0,0 +1,71 @@
|
|
1
|
+
module OnnxRuntime
  # Low-level bindings to the ONNX Runtime C API, attached with the ffi gem.
  #
  # Most functions return a status pointer that is NULL on success (callers pass
  # it to InferenceSession#check_status); results are usually written through a
  # trailing pointer-to-pointer argument.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib OnnxRuntime.ffi_lib
    rescue LoadError => e
      # Set ONNXRUNTIME_DEBUG to see the original loader error instead of the summary.
      raise e if ENV["ONNXRUNTIME_DEBUG"]

      raise LoadError, "Could not find ONNX Runtime"
    end

    # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
    # keep same order

    # enums
    TensorElementDataType = enum(:undefined, :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :string, :bool, :float16, :double, :uint32, :uint64, :complex64, :complex128, :bfloat16)
    OnnxType = enum(:unknown, :tensor, :sequence, :map, :opaque, :sparsetensor)

    # session
    attach_function :OrtCreateEnv, %i[int string pointer], :pointer
    attach_function :OrtCreateSession, %i[pointer string pointer pointer], :pointer
    attach_function :OrtCreateSessionFromArray, %i[pointer pointer size_t pointer pointer], :pointer
    attach_function :OrtRun, %i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer
    attach_function :OrtCreateSessionOptions, %i[pointer], :pointer
    attach_function :OrtSetSessionGraphOptimizationLevel, %i[pointer int], :pointer
    attach_function :OrtSetSessionThreadPoolSize, %i[pointer int], :pointer

    # input and output
    attach_function :OrtSessionGetInputCount, %i[pointer pointer], :pointer
    attach_function :OrtSessionGetOutputCount, %i[pointer pointer], :pointer
    attach_function :OrtSessionGetInputTypeInfo, %i[pointer size_t pointer], :pointer
    attach_function :OrtSessionGetOutputTypeInfo, %i[pointer size_t pointer], :pointer
    attach_function :OrtSessionGetInputName, %i[pointer size_t pointer pointer], :pointer
    attach_function :OrtSessionGetOutputName, %i[pointer size_t pointer pointer], :pointer

    # tensor
    attach_function :OrtCreateTensorWithDataAsOrtValue, %i[pointer pointer size_t pointer size_t int pointer], :pointer
    attach_function :OrtGetTensorMutableData, %i[pointer pointer], :pointer
    attach_function :OrtIsTensor, %i[pointer pointer], :pointer
    attach_function :OrtCastTypeInfoToTensorInfo, %i[pointer pointer], :pointer
    attach_function :OrtOnnxTypeFromTypeInfo, %i[pointer pointer], :pointer
    attach_function :OrtGetTensorElementType, %i[pointer pointer], :pointer
    attach_function :OrtGetDimensionsCount, %i[pointer pointer], :pointer
    attach_function :OrtGetDimensions, %i[pointer pointer size_t], :pointer
    attach_function :OrtGetTensorShapeElementCount, %i[pointer pointer], :pointer
    attach_function :OrtGetTensorTypeAndShape, %i[pointer pointer], :pointer

    # value
    attach_function :OrtGetTypeInfo, %i[pointer pointer], :pointer
    attach_function :OrtGetValueType, %i[pointer pointer], :pointer

    # maps and sequences
    attach_function :OrtGetValue, %i[pointer int pointer pointer], :pointer
    attach_function :OrtGetValueCount, %i[pointer pointer], :pointer

    # version
    attach_function :OrtGetVersionString, %i[], :string

    # error
    attach_function :OrtGetErrorMessage, %i[pointer], :string

    # allocator
    attach_function :OrtCreateCpuAllocatorInfo, %i[int int pointer], :pointer
    attach_function :OrtCreateDefaultAllocator, %i[pointer], :pointer

    # release
    # The Ort*Release* functions return void in the C API; declaring them :void
    # keeps FFI from wrapping an undefined return register as a Pointer.
    attach_function :OrtReleaseEnv, %i[pointer], :void
    attach_function :OrtReleaseTypeInfo, %i[pointer], :void
    attach_function :OrtReleaseStatus, %i[pointer], :void
  end
end
|
@@ -0,0 +1,259 @@
|
|
1
|
+
module OnnxRuntime
  # Runs an ONNX model through the ONNX Runtime C API bound in OnnxRuntime::FFI,
  # following the shape of the Python InferenceSession API.
  #
  # NOTE(review): OnnxRuntime::FFI binds no release functions for sessions,
  # session options, OrtValues, allocator info or tensor type/shape info, so
  # those native objects are not freed by this class.
  class InferenceSession
    # ONNX element types whose names are also FFI type names, so values can be
    # copied with FFI's read_array_of_* / write_array_of_* helpers.
    NUMERIC_TYPES = %i[float double int8 uint8 int16 uint16 int32 uint32 int64 uint64].freeze

    # @return [Array<Hash>] model inputs / outputs, each with :name, :type and :shape
    attr_reader :inputs, :outputs

    # Creates a session from a model file path or from serialized model bytes.
    #
    # @param path_or_bytes [String] a file path, or the model itself as a binary
    #   (ASCII-8BIT) string such as the result of File.binread
    # @raise [OnnxRuntime::Error] if ONNX Runtime reports a failure
    def initialize(path_or_bytes)
      session_options = ::FFI::MemoryPointer.new(:pointer)
      check_status FFI.OrtCreateSessionOptions(session_options)

      @session = ::FFI::MemoryPointer.new(:pointer)
      path_or_bytes = path_or_bytes.to_str
      # The string's encoding decides its meaning: binary => model bytes, else a path.
      if path_or_bytes.encoding == Encoding::BINARY
        check_status FFI.OrtCreateSessionFromArray(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
      else
        check_status FFI.OrtCreateSession(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
      end

      # Passed to the calls below that allocate node names and sequence/map elements.
      @allocator = ::FFI::MemoryPointer.new(:pointer)
      check_status FFI.OrtCreateDefaultAllocator(@allocator)

      @inputs = node_infos(:OrtSessionGetInputCount, :OrtSessionGetInputName, :OrtSessionGetInputTypeInfo)
      @outputs = node_infos(:OrtSessionGetOutputCount, :OrtSessionGetOutputName, :OrtSessionGetOutputTypeInfo)
    end

    # Runs the model.
    #
    # @param output_names [Array<String, Symbol>, nil] outputs to fetch, in the order
    #   the results should be returned; nil fetches every output in model order
    # @param input_feed [Hash] input name => value (a nested Array or a scalar)
    # @return [Array] one Ruby value per fetched output
    # @raise [OnnxRuntime::Error] on an unknown output name, a ragged input array,
    #   or a failure reported by ONNX Runtime
    def run(output_names, input_feed)
      outputs = resolve_outputs(output_names)

      # ONNX Runtime only receives raw addresses of the name strings and input data
      # (OrtCreateTensorWithDataAsOrtValue uses the caller's buffer rather than a
      # copy), so the Ruby objects owning that memory are held here until OrtRun
      # returns; otherwise the GC could free them mid-call.
      keep_alive = []
      input_tensor = create_input_tensor(input_feed, keep_alive)
      input_node_names = create_node_names(input_feed.keys.map(&:to_s), keep_alive)
      output_node_names = create_node_names(outputs.map { |o| o[:name] }, keep_alive)
      output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)

      # TODO support run options
      check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, outputs.size, output_tensor)

      Array.new(outputs.size) { |i| create_from_onnx_value(output_tensor[i].read_pointer) }
    end

    private

    # Reads name/type/shape metadata for every input or output of the session.
    # The three arguments name the matching FFI count, name and type-info functions.
    def node_infos(count_fn, name_fn, typeinfo_fn)
      count = ::FFI::MemoryPointer.new(:size_t)
      check_status FFI.public_send(count_fn, read_pointer, count)

      Array.new(count.read(:size_t)) do |i|
        name_ptr = ::FFI::MemoryPointer.new(:pointer)
        check_status FFI.public_send(name_fn, read_pointer, i, @allocator.read_pointer, name_ptr)
        typeinfo = ::FFI::MemoryPointer.new(:pointer)
        check_status FFI.public_send(typeinfo_fn, read_pointer, i, typeinfo)
        {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
      end
    end

    # Maps requested output names to their metadata, preserving the requested order.
    # @raise [OnnxRuntime::Error] if a name is not an output of the model
    def resolve_outputs(output_names)
      return @outputs unless output_names

      output_names.map(&:to_s).map do |name|
        @outputs.find { |o| o[:name] == name } || raise(OnnxRuntime::Error, "Invalid output name: #{name}")
      end
    end

    # Builds a native array of OrtValue pointers, one per input in input_feed.
    # Native buffers the OrtValues point into are appended to keep_alive.
    def create_input_tensor(input_feed, keep_alive)
      allocator_info = ::FFI::MemoryPointer.new(:pointer)
      # NOTE(review): 1, 0 are presumably OrtArenaAllocator / OrtMemTypeDefault — confirm against the C header.
      check_status FFI.OrtCreateCpuAllocatorInfo(1, 0, allocator_info)
      input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)

      input_feed.each_with_index do |(input_name, input), idx|
        shape, flat_input = shape_and_values(input_name, input)

        inp = @inputs.find { |i| i[:name] == input_name.to_s } || {}
        input_tensor_values, element_type = tensor_buffer(tensor_element_type(inp[:type]), flat_input)

        input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
        input_node_dims.write_array_of_int64(shape)
        keep_alive << input_tensor_values << input_node_dims

        check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, FFI::TensorElementDataType[element_type], input_tensor[idx])
      end

      input_tensor
    end

    # Returns [shape, flat_values] for an input; the shape is read from the first
    # element at each nesting level, and a non-Array input yields shape [].
    # @raise [OnnxRuntime::Error] if the nested arrays are not rectangular
    def shape_and_values(input_name, input)
      shape = []
      level = input
      while level.is_a?(Array)
        shape << level.size
        level = level.first
      end

      values = [input].flatten
      expected = shape.inject(1, :*)
      unless values.size == expected
        raise OnnxRuntime::Error, "Input #{input_name} is not a rectangular array: shape #{shape.inspect} needs #{expected} values, got #{values.size}"
      end

      [shape, values]
    end

    # Extracts the element type from a type string such as "tensor(int64)".
    # Returns nil for nil or non-tensor type strings.
    def tensor_element_type(type)
      match = /\Atensor\((\w+)\)\z/.match(type.to_s)
      match && match[1].to_sym
    end

    # Copies values into a new native buffer for element_type and returns
    # [buffer, element_type_used]. Booleans are stored as 0/1 bytes; types without
    # a numeric mapping (including unknown inputs) fall back to float.
    def tensor_buffer(element_type, values)
      if element_type == :bool
        buffer = ::FFI::MemoryPointer.new(:uchar, values.size)
        buffer.write_array_of_uchar(values.map { |v| v ? 1 : 0 })
        return [buffer, :bool]
      end

      element_type = :float unless NUMERIC_TYPES.include?(element_type)
      buffer = ::FFI::MemoryPointer.new(element_type, values.size)
      buffer.public_send(:"write_array_of_#{element_type}", values)
      [buffer, element_type]
    end

    # Builds a native array of C strings; the string buffers go into keep_alive
    # because the array only stores their addresses.
    def create_node_names(names, keep_alive)
      str_ptrs = names.map { |v| ::FFI::MemoryPointer.from_string(v) }
      keep_alive.concat(str_ptrs)
      ptr = ::FFI::MemoryPointer.new(:pointer, names.size)
      ptr.write_array_of_pointer(str_ptrs)
      ptr
    end

    # Converts an OrtValue into Ruby: tensors become (nested) Arrays via
    # Utils.reshape, sequences become Arrays, int64-keyed maps become Hashes.
    def create_from_onnx_value(out_ptr)
      out_type = ::FFI::MemoryPointer.new(:int)
      check_status FFI.OrtGetValueType(out_ptr, out_type)
      onnx_type = FFI::OnnxType[out_type.read_int]

      case onnx_type
      when :tensor
        tensor_from_onnx_value(out_ptr)
      when :sequence
        count = ::FFI::MemoryPointer.new(:size_t)
        check_status FFI.OrtGetValueCount(out_ptr, count)

        Array.new(count.read(:size_t)) do |i|
          seq = ::FFI::MemoryPointer.new(:pointer)
          check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq)
          create_from_onnx_value(seq.read_pointer)
        end
      when :map
        map_from_onnx_value(out_ptr)
      else
        raise "Unsupported ONNX type: #{onnx_type}"
      end
    end

    # Reads a tensor OrtValue into a nested Array shaped like the tensor.
    def tensor_from_onnx_value(out_ptr)
      typeinfo = ::FFI::MemoryPointer.new(:pointer)
      check_status FFI.OrtGetTensorTypeAndShape(out_ptr, typeinfo)
      type, shape = tensor_type_and_shape(typeinfo)

      tensor_data = ::FFI::MemoryPointer.new(:pointer)
      check_status FFI.OrtGetTensorMutableData(out_ptr, tensor_data)

      out_size = ::FFI::MemoryPointer.new(:size_t)
      check_status FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size)
      output_tensor_size = out_size.read(:size_t)

      # TODO support string tensors
      element_type = FFI::TensorElementDataType[type]
      data = tensor_data.read_pointer
      arr =
        if element_type == :bool
          data.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
        elsif NUMERIC_TYPES.include?(element_type)
          data.public_send(:"read_array_of_#{element_type}", output_tensor_size)
        else
          raise "Unsupported element type: #{element_type}"
        end

      Utils.reshape(arr, shape)
    end

    # Reads a map OrtValue (element 0 = keys tensor, element 1 = values) into a Hash.
    def map_from_onnx_value(out_ptr)
      map_keys = ::FFI::MemoryPointer.new(:pointer)
      map_values = ::FFI::MemoryPointer.new(:pointer)
      check_status FFI.OrtGetValue(out_ptr, 0, @allocator.read_pointer, map_keys)
      check_status FFI.OrtGetValue(out_ptr, 1, @allocator.read_pointer, map_values)

      type_shape = ::FFI::MemoryPointer.new(:pointer)
      elem_type = ::FFI::MemoryPointer.new(:int)
      check_status FFI.OrtGetTensorTypeAndShape(map_keys.read_pointer, type_shape)
      check_status FFI.OrtGetTensorElementType(type_shape.read_pointer, elem_type)

      # TODO support more key types
      key_type = FFI::TensorElementDataType[elem_type.read_int]
      raise "Unsupported element type: #{key_type}" unless key_type == :int64

      keys = create_from_onnx_value(map_keys.read_pointer)
      values = create_from_onnx_value(map_values.read_pointer)
      keys.zip(values).to_h
    end

    # Native OrtSession pointer.
    def read_pointer
      @session.read_pointer
    end

    # Raises OnnxRuntime::Error with ONNX Runtime's message when status is non-NULL,
    # releasing the status first.
    def check_status(status)
      return if status.null?

      message = FFI.OrtGetErrorMessage(status)
      FFI.OrtReleaseStatus(status)
      raise OnnxRuntime::Error, message
    end

    # Describes a node's type info as {type:, shape:}; always releases typeinfo.
    def node_info(typeinfo)
      onnx_type = ::FFI::MemoryPointer.new(:int)
      check_status FFI.OrtOnnxTypeFromTypeInfo(typeinfo.read_pointer, onnx_type)

      type = FFI::OnnxType[onnx_type.read_int]
      case type
      when :tensor
        tensor_info = ::FFI::MemoryPointer.new(:pointer)
        # The cast result is not released separately; only typeinfo is released below.
        check_status FFI.OrtCastTypeInfoToTensorInfo(typeinfo.read_pointer, tensor_info)

        element_type, shape = tensor_type_and_shape(tensor_info)
        {
          type: "tensor(#{FFI::TensorElementDataType[element_type]})",
          shape: shape
        }
      when :sequence
        # TODO show nested
        {
          type: "seq",
          shape: []
        }
      when :map
        # TODO show nested
        {
          type: "map",
          shape: []
        }
      else
        raise "Unsupported ONNX type: #{type}"
      end
    ensure
      FFI.OrtReleaseTypeInfo(typeinfo.read_pointer)
    end

    # Returns [element_type_int, dims] for a tensor type/shape info pointer.
    def tensor_type_and_shape(tensor_info)
      type = ::FFI::MemoryPointer.new(:int)
      check_status FFI.OrtGetTensorElementType(tensor_info.read_pointer, type)

      num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
      check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr)
      num_dims = num_dims_ptr.read(:size_t)

      node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
      check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)

      [type.read_int, node_dims.read_array_of_int64(num_dims)]
    end

    # share env
    # One OrtEnv is created lazily and shared by all sessions; it is released at exit.
    # TODO mutex around creation?
    def env
      @@env ||= begin
        env = ::FFI::MemoryPointer.new(:pointer)
        check_status FFI.OrtCreateEnv(3, "Default", env)
        at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
        env
      end
    end
  end
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module OnnxRuntime
  # Convenience wrapper around InferenceSession that returns predictions as a
  # Hash keyed by output name.
  class Model
    # @param path_or_bytes [String] model file path, or the model itself as a
    #   binary (ASCII-8BIT) string
    def initialize(path_or_bytes)
      @session = InferenceSession.new(path_or_bytes)
    end

    # Runs the model and labels each prediction with its output name.
    #
    # @param input_feed [Hash] input name => value
    # @param output_names [Array<String, Symbol>, nil] outputs to return; all when nil
    # @return [Hash{String => Object}] output name => prediction
    # @raise [OnnxRuntime::Error] if a requested output is not an output of the model
    def predict(input_feed, output_names: nil)
      names = outputs.map { |o| o[:name] }
      if output_names
        requested = output_names.map(&:to_s)
        unknown = requested - names
        raise Error, "Unknown output names: #{unknown.join(", ")}" unless unknown.empty?

        # Ask for outputs in model order so every name lines up with its prediction
        # whether the session returns results in requested order or in model order.
        names &= requested
      end

      names.zip(@session.run(names, input_feed)).to_h
    end

    # @return [Array<Hash>] input metadata (:name, :type, :shape) from the session
    def inputs
      @session.inputs
    end

    # @return [Array<Hash>] output metadata (:name, :type, :shape) from the session
    def outputs
      @session.outputs
    end
  end
end
|
data/lib/onnxruntime.rb
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "ffi"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "onnxruntime/inference_session"
|
6
|
+
require "onnxruntime/model"
|
7
|
+
require "onnxruntime/utils"
|
8
|
+
require "onnxruntime/version"
|
9
|
+
|
10
|
+
module OnnxRuntime
  # Raised when ONNX Runtime reports a failure.
  class Error < StandardError; end

  class << self
    # Library name(s) or path(s) passed to ffi_lib when OnnxRuntime::FFI loads.
    attr_accessor :ffi_lib

    # @return [String] the version string reported by the loaded ONNX Runtime library
    def lib_version
      FFI.OrtGetVersionString
    end
  end

  self.ffi_lib = ["onnxruntime"]

  # Deferring the load until FFI is first referenced lets ffi.rb replace a raw
  # LoadError with a friendlier message.
  autoload :FFI, "onnxruntime/ffi"
end
|
metadata
ADDED
@@ -0,0 +1,107 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: onnxruntime
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-08-26 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: ffi
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: minitest
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '5'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '5'
|
69
|
+
description:
|
70
|
+
email: andrew@chartkick.com
|
71
|
+
executables: []
|
72
|
+
extensions: []
|
73
|
+
extra_rdoc_files: []
|
74
|
+
files:
|
75
|
+
- CHANGELOG.md
|
76
|
+
- LICENSE.txt
|
77
|
+
- README.md
|
78
|
+
- lib/onnxruntime.rb
|
79
|
+
- lib/onnxruntime/ffi.rb
|
80
|
+
- lib/onnxruntime/inference_session.rb
|
81
|
+
- lib/onnxruntime/model.rb
|
82
|
+
- lib/onnxruntime/utils.rb
|
83
|
+
- lib/onnxruntime/version.rb
|
84
|
+
homepage: https://github.com/ankane/onnxruntime
|
85
|
+
licenses:
|
86
|
+
- MIT
|
87
|
+
metadata: {}
|
88
|
+
post_install_message:
|
89
|
+
rdoc_options: []
|
90
|
+
require_paths:
|
91
|
+
- lib
|
92
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '2.4'
|
97
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
98
|
+
requirements:
|
99
|
+
- - ">="
|
100
|
+
- !ruby/object:Gem::Version
|
101
|
+
version: '0'
|
102
|
+
requirements: []
|
103
|
+
rubygems_version: 3.0.3
|
104
|
+
signing_key:
|
105
|
+
specification_version: 4
|
106
|
+
summary: High performance scoring engine for ML models
|
107
|
+
test_files: []
|