onnxruntime 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fe15097b48e0c08cc25878384b74239ff9a462a23ac99ba26331aff7d9315dd8
4
- data.tar.gz: 94bbc725a419d5e4712d0c0fefb46e30acd3d9c675d0a0e156b2356b490a2617
3
+ metadata.gz: d37fddde85ef4715dc9c74d37f3465b9ba773ecfe491789ea7f07476c9078842
4
+ data.tar.gz: 7898982a8b0070a8b85467fe0c496c1713ac79f7505887425e1b72120ff471a1
5
5
  SHA512:
6
- metadata.gz: 9e1970fbf2430692e875b0ce450b16ac6930ba4fea3a7c686604f37ec3eca7e526209fb9df3be18cce26a24c50e59fafb98c3ca5b054d6a5b47c430b3316a450
7
- data.tar.gz: c2c57c89b52aa70acde18dca516adcdbb51a53d31d979aa498a16636b841b1f83aa0156e948150628f2663a52c2c0562e0837bf8edc70d4380220f2b548d1fee
6
+ metadata.gz: 4f78f1cf39a22280c1d3cfb385cbc27d527a909e8281eb9681328b963f391b3343b8b093687665479d673a4a865fa449ea7555bd2d4b45e49764716a253699cf
7
+ data.tar.gz: 52f7481494f15a22110c52dbe83947282db4da193804d79b3ebd22a194f0adda8e3a4ae69b96211bd0ea72f937670a5876d70271f7716bafd18d6c94ff23120d
@@ -1,3 +1,9 @@
1
+ ## 0.1.2
2
+
3
+ - Added support for Numo::NArray
4
+ - Made thread-safe
5
+ - Fixed error with JRuby
6
+
1
7
  ## 0.1.1
2
8
 
3
9
  - Packaged ONNX Runtime with gem
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: [ONNX Runtime](https://github.com/Microsoft/onnxruntime) - the high performance scoring engine for ML models - for Ruby
4
4
 
5
+ Check out [an example](https://ankane.org/tensorflow-ruby)
6
+
5
7
  [![Build Status](https://travis-ci.org/ankane/onnxruntime.svg?branch=master)](https://travis-ci.org/ankane/onnxruntime) [![Build status](https://ci.appveyor.com/api/projects/status/f2bq6ruqjf4jx671/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/onnxruntime/branch/master)
6
8
 
7
9
  ## Installation
@@ -21,6 +23,8 @@ model = OnnxRuntime::Model.new("model.onnx")
21
23
  model.predict(x: [1, 2, 3])
22
24
  ```
23
25
 
26
+ > Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
27
+
24
28
  Get inputs
25
29
 
26
30
  ```ruby
@@ -33,7 +33,7 @@ module OnnxRuntime
33
33
  # input
34
34
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
35
35
  check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes)
36
- num_input_nodes.read(:size_t).times do |i|
36
+ read_size_t(num_input_nodes).times do |i|
37
37
  name_ptr = ::FFI::MemoryPointer.new(:string)
38
38
  check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr)
39
39
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -44,7 +44,7 @@ module OnnxRuntime
44
44
  # output
45
45
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
46
46
  check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes)
47
- num_output_nodes.read(:size_t).times do |i|
47
+ read_size_t(num_output_nodes).times do |i|
48
48
  name_ptr = ::FFI::MemoryPointer.new(:string)
49
49
  check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr)
50
50
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -77,6 +77,8 @@ module OnnxRuntime
77
77
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
78
78
 
79
79
  input_feed.each_with_index do |(input_name, input), idx|
80
+ input = input.to_a unless input.is_a?(Array)
81
+
80
82
  shape = []
81
83
  s = input
82
84
  while s.is_a?(Array)
@@ -146,7 +148,7 @@ module OnnxRuntime
146
148
 
147
149
  out_size = ::FFI::MemoryPointer.new(:size_t)
148
150
  output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size)
149
- output_tensor_size = out_size.read(:size_t)
151
+ output_tensor_size = read_size_t(out_size)
150
152
 
151
153
  # TODO support more types
152
154
  type = FFI::TensorElementDataType[type]
@@ -165,7 +167,7 @@ module OnnxRuntime
165
167
  out = ::FFI::MemoryPointer.new(:size_t)
166
168
  check_status FFI.OrtGetValueCount(out_ptr, out)
167
169
 
168
- out.read(:size_t).times.map do |i|
170
+ read_size_t(out).times.map do |i|
169
171
  seq = ::FFI::MemoryPointer.new(:pointer)
170
172
  check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq)
171
173
  create_from_onnx_value(seq.read_pointer)
@@ -252,7 +254,7 @@ module OnnxRuntime
252
254
 
253
255
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
254
256
  check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr)
255
- num_dims = num_dims_ptr.read(:size_t)
257
+ num_dims = read_size_t(num_dims_ptr)
256
258
 
257
259
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
258
260
  check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)
@@ -264,14 +266,24 @@ module OnnxRuntime
264
266
  raise "Unsupported #{name} type: #{type}"
265
267
  end
266
268
 
267
- # share env
268
- # TODO mutex around creation?
269
+ # read(:size_t) not supported in FFI JRuby
270
+ def read_size_t(ptr)
271
+ if RUBY_PLATFORM == "java"
272
+ ptr.read_long
273
+ else
274
+ ptr.read(:size_t)
275
+ end
276
+ end
277
+
269
278
  def env
270
- @@env ||= begin
271
- env = ::FFI::MemoryPointer.new(:pointer)
272
- check_status FFI.OrtCreateEnv(3, "Default", env)
273
- at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
274
- env
279
+ # use mutex for thread-safety
280
+ Utils.mutex.synchronize do
281
+ @@env ||= begin
282
+ env = ::FFI::MemoryPointer.new(:pointer)
283
+ check_status FFI.OrtCreateEnv(3, "Default", env)
284
+ at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
285
+ env
286
+ end
275
287
  end
276
288
  end
277
289
  end
@@ -1,6 +1,12 @@
1
1
  module OnnxRuntime
2
2
  module Utils
3
+ class << self
4
+ attr_accessor :mutex
5
+ end
6
+ self.mutex = Mutex.new
7
+
3
8
  def self.reshape(arr, dims)
9
+ arr = arr.flatten
4
10
  dims[1..-1].reverse.each do |dim|
5
11
  arr = arr.each_slice(dim)
6
12
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.1.1"
2
+ VERSION = "0.1.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-09-03 00:00:00.000000000 Z
11
+ date: 2019-10-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi