onnxruntime 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fe15097b48e0c08cc25878384b74239ff9a462a23ac99ba26331aff7d9315dd8
4
- data.tar.gz: 94bbc725a419d5e4712d0c0fefb46e30acd3d9c675d0a0e156b2356b490a2617
3
+ metadata.gz: d37fddde85ef4715dc9c74d37f3465b9ba773ecfe491789ea7f07476c9078842
4
+ data.tar.gz: 7898982a8b0070a8b85467fe0c496c1713ac79f7505887425e1b72120ff471a1
5
5
  SHA512:
6
- metadata.gz: 9e1970fbf2430692e875b0ce450b16ac6930ba4fea3a7c686604f37ec3eca7e526209fb9df3be18cce26a24c50e59fafb98c3ca5b054d6a5b47c430b3316a450
7
- data.tar.gz: c2c57c89b52aa70acde18dca516adcdbb51a53d31d979aa498a16636b841b1f83aa0156e948150628f2663a52c2c0562e0837bf8edc70d4380220f2b548d1fee
6
+ metadata.gz: 4f78f1cf39a22280c1d3cfb385cbc27d527a909e8281eb9681328b963f391b3343b8b093687665479d673a4a865fa449ea7555bd2d4b45e49764716a253699cf
7
+ data.tar.gz: 52f7481494f15a22110c52dbe83947282db4da193804d79b3ebd22a194f0adda8e3a4ae69b96211bd0ea72f937670a5876d70271f7716bafd18d6c94ff23120d
@@ -1,3 +1,9 @@
1
+ ## 0.1.2
2
+
3
+ - Added support for Numo::NArray
4
+ - Made thread-safe
5
+ - Fixed error with JRuby
6
+
1
7
  ## 0.1.1
2
8
 
3
9
  - Packaged ONNX Runtime with gem
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: [ONNX Runtime](https://github.com/Microsoft/onnxruntime) - the high performance scoring engine for ML models - for Ruby
4
4
 
5
+ Check out [an example](https://ankane.org/tensorflow-ruby)
6
+
5
7
  [![Build Status](https://travis-ci.org/ankane/onnxruntime.svg?branch=master)](https://travis-ci.org/ankane/onnxruntime) [![Build status](https://ci.appveyor.com/api/projects/status/f2bq6ruqjf4jx671/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/onnxruntime/branch/master)
6
8
 
7
9
  ## Installation
@@ -21,6 +23,8 @@ model = OnnxRuntime::Model.new("model.onnx")
21
23
  model.predict(x: [1, 2, 3])
22
24
  ```
23
25
 
26
+ > Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
27
+
24
28
  Get inputs
25
29
 
26
30
  ```ruby
@@ -33,7 +33,7 @@ module OnnxRuntime
33
33
  # input
34
34
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
35
35
  check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes)
36
- num_input_nodes.read(:size_t).times do |i|
36
+ read_size_t(num_input_nodes).times do |i|
37
37
  name_ptr = ::FFI::MemoryPointer.new(:string)
38
38
  check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr)
39
39
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -44,7 +44,7 @@ module OnnxRuntime
44
44
  # output
45
45
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
46
46
  check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes)
47
- num_output_nodes.read(:size_t).times do |i|
47
+ read_size_t(num_output_nodes).times do |i|
48
48
  name_ptr = ::FFI::MemoryPointer.new(:string)
49
49
  check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr)
50
50
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -77,6 +77,8 @@ module OnnxRuntime
77
77
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
78
78
 
79
79
  input_feed.each_with_index do |(input_name, input), idx|
80
+ input = input.to_a unless input.is_a?(Array)
81
+
80
82
  shape = []
81
83
  s = input
82
84
  while s.is_a?(Array)
@@ -146,7 +148,7 @@ module OnnxRuntime
146
148
 
147
149
  out_size = ::FFI::MemoryPointer.new(:size_t)
148
150
  output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size)
149
- output_tensor_size = out_size.read(:size_t)
151
+ output_tensor_size = read_size_t(out_size)
150
152
 
151
153
  # TODO support more types
152
154
  type = FFI::TensorElementDataType[type]
@@ -165,7 +167,7 @@ module OnnxRuntime
165
167
  out = ::FFI::MemoryPointer.new(:size_t)
166
168
  check_status FFI.OrtGetValueCount(out_ptr, out)
167
169
 
168
- out.read(:size_t).times.map do |i|
170
+ read_size_t(out).times.map do |i|
169
171
  seq = ::FFI::MemoryPointer.new(:pointer)
170
172
  check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq)
171
173
  create_from_onnx_value(seq.read_pointer)
@@ -252,7 +254,7 @@ module OnnxRuntime
252
254
 
253
255
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
254
256
  check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr)
255
- num_dims = num_dims_ptr.read(:size_t)
257
+ num_dims = read_size_t(num_dims_ptr)
256
258
 
257
259
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
258
260
  check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)
@@ -264,14 +266,24 @@ module OnnxRuntime
264
266
  raise "Unsupported #{name} type: #{type}"
265
267
  end
266
268
 
267
- # share env
268
- # TODO mutex around creation?
269
+ # read(:size_t) not supported in FFI JRuby
270
+ def read_size_t(ptr)
271
+ if RUBY_PLATFORM == "java"
272
+ ptr.read_long
273
+ else
274
+ ptr.read(:size_t)
275
+ end
276
+ end
277
+
269
278
  def env
270
- @@env ||= begin
271
- env = ::FFI::MemoryPointer.new(:pointer)
272
- check_status FFI.OrtCreateEnv(3, "Default", env)
273
- at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
274
- env
279
+ # use mutex for thread-safety
280
+ Utils.mutex.synchronize do
281
+ @@env ||= begin
282
+ env = ::FFI::MemoryPointer.new(:pointer)
283
+ check_status FFI.OrtCreateEnv(3, "Default", env)
284
+ at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
285
+ env
286
+ end
275
287
  end
276
288
  end
277
289
  end
@@ -1,6 +1,12 @@
1
1
  module OnnxRuntime
2
2
  module Utils
3
+ class << self
4
+ attr_accessor :mutex
5
+ end
6
+ self.mutex = Mutex.new
7
+
3
8
  def self.reshape(arr, dims)
9
+ arr = arr.flatten
4
10
  dims[1..-1].reverse.each do |dim|
5
11
  arr = arr.each_slice(dim)
6
12
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.1.1"
2
+ VERSION = "0.1.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-09-03 00:00:00.000000000 Z
11
+ date: 2019-10-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi