onnxruntime 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +4 -0
- data/lib/onnxruntime/inference_session.rb +24 -12
- data/lib/onnxruntime/utils.rb +6 -0
- data/lib/onnxruntime/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d37fddde85ef4715dc9c74d37f3465b9ba773ecfe491789ea7f07476c9078842
|
4
|
+
data.tar.gz: 7898982a8b0070a8b85467fe0c496c1713ac79f7505887425e1b72120ff471a1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4f78f1cf39a22280c1d3cfb385cbc27d527a909e8281eb9681328b963f391b3343b8b093687665479d673a4a865fa449ea7555bd2d4b45e49764716a253699cf
|
7
|
+
data.tar.gz: 52f7481494f15a22110c52dbe83947282db4da193804d79b3ebd22a194f0adda8e3a4ae69b96211bd0ea72f937670a5876d70271f7716bafd18d6c94ff23120d
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: [ONNX Runtime](https://github.com/Microsoft/onnxruntime) - the high performance scoring engine for ML models - for Ruby
|
4
4
|
|
5
|
+
Check out [an example](https://ankane.org/tensorflow-ruby)
|
6
|
+
|
5
7
|
[![Build Status](https://travis-ci.org/ankane/onnxruntime.svg?branch=master)](https://travis-ci.org/ankane/onnxruntime) [![Build status](https://ci.appveyor.com/api/projects/status/f2bq6ruqjf4jx671/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/onnxruntime/branch/master)
|
6
8
|
|
7
9
|
## Installation
|
@@ -21,6 +23,8 @@ model = OnnxRuntime::Model.new("model.onnx")
|
|
21
23
|
model.predict(x: [1, 2, 3])
|
22
24
|
```
|
23
25
|
|
26
|
+
> Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
|
27
|
+
|
24
28
|
Get inputs
|
25
29
|
|
26
30
|
```ruby
|
@@ -33,7 +33,7 @@ module OnnxRuntime
|
|
33
33
|
# input
|
34
34
|
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
|
35
35
|
check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes)
|
36
|
-
num_input_nodes
|
36
|
+
read_size_t(num_input_nodes).times do |i|
|
37
37
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
38
38
|
check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr)
|
39
39
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -44,7 +44,7 @@ module OnnxRuntime
|
|
44
44
|
# output
|
45
45
|
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
|
46
46
|
check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes)
|
47
|
-
num_output_nodes
|
47
|
+
read_size_t(num_output_nodes).times do |i|
|
48
48
|
name_ptr = ::FFI::MemoryPointer.new(:string)
|
49
49
|
check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr)
|
50
50
|
typeinfo = ::FFI::MemoryPointer.new(:pointer)
|
@@ -77,6 +77,8 @@ module OnnxRuntime
|
|
77
77
|
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
|
78
78
|
|
79
79
|
input_feed.each_with_index do |(input_name, input), idx|
|
80
|
+
input = input.to_a unless input.is_a?(Array)
|
81
|
+
|
80
82
|
shape = []
|
81
83
|
s = input
|
82
84
|
while s.is_a?(Array)
|
@@ -146,7 +148,7 @@ module OnnxRuntime
|
|
146
148
|
|
147
149
|
out_size = ::FFI::MemoryPointer.new(:size_t)
|
148
150
|
output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size)
|
149
|
-
output_tensor_size = out_size
|
151
|
+
output_tensor_size = read_size_t(out_size)
|
150
152
|
|
151
153
|
# TODO support more types
|
152
154
|
type = FFI::TensorElementDataType[type]
|
@@ -165,7 +167,7 @@ module OnnxRuntime
|
|
165
167
|
out = ::FFI::MemoryPointer.new(:size_t)
|
166
168
|
check_status FFI.OrtGetValueCount(out_ptr, out)
|
167
169
|
|
168
|
-
out
|
170
|
+
read_size_t(out).times.map do |i|
|
169
171
|
seq = ::FFI::MemoryPointer.new(:pointer)
|
170
172
|
check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq)
|
171
173
|
create_from_onnx_value(seq.read_pointer)
|
@@ -252,7 +254,7 @@ module OnnxRuntime
|
|
252
254
|
|
253
255
|
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
|
254
256
|
check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr)
|
255
|
-
num_dims = num_dims_ptr
|
257
|
+
num_dims = read_size_t(num_dims_ptr)
|
256
258
|
|
257
259
|
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
|
258
260
|
check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)
|
@@ -264,14 +266,24 @@ module OnnxRuntime
|
|
264
266
|
raise "Unsupported #{name} type: #{type}"
|
265
267
|
end
|
266
268
|
|
267
|
-
#
|
268
|
-
|
269
|
+
# read(:size_t) not supported in FFI JRuby
|
270
|
+
def read_size_t(ptr)
|
271
|
+
if RUBY_PLATFORM == "java"
|
272
|
+
ptr.read_long
|
273
|
+
else
|
274
|
+
ptr.read(:size_t)
|
275
|
+
end
|
276
|
+
end
|
277
|
+
|
269
278
|
def env
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
279
|
+
# use mutex for thread-safety
|
280
|
+
Utils.mutex.synchronize do
|
281
|
+
@@env ||= begin
|
282
|
+
env = ::FFI::MemoryPointer.new(:pointer)
|
283
|
+
check_status FFI.OrtCreateEnv(3, "Default", env)
|
284
|
+
at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
|
285
|
+
env
|
286
|
+
end
|
275
287
|
end
|
276
288
|
end
|
277
289
|
end
|
data/lib/onnxruntime/utils.rb
CHANGED
data/lib/onnxruntime/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: onnxruntime
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-10-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|