onnxruntime 0.3.2 → 0.5.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0b017fa8896c64bbeeda0ba6582ca9bfa626de3f13023ec9f3f91000031e88b8
4
- data.tar.gz: 16344db9151ca3f388539e05772a7172129613f949dbff77d89bfb9c307d0de2
3
+ metadata.gz: c4df8503d4840fee0e897a8fc7b093e74c8af0bb7bf438b0ca4e1dd20b9d26e1
4
+ data.tar.gz: cd58fd60bfa2b57ed04797bebd2f6062f5bfb7a5ec5e8707c6666e2eb771493d
5
5
  SHA512:
6
- metadata.gz: 21011ae8d875230858ff148a27ff29caf005210aae6f6e6a347ed9a02dfb85bceac6714957718d3f17872f07f0ace21c3ba7d7ba8884c1dff4b79fd08b43e40b
7
- data.tar.gz: 000f4575890f92c2b1de308e052977f159ead243aef3984cd4190a5a873ecbe2868fe586fdbe14f2963f0518a75d8c4600d47ba2fbdd6a77c84cf3da1fc84292
6
+ metadata.gz: 15538d6be0754132d2fb5f9b2ac75fb4ee263fdc75a39a169a276b8c578586b170b07bda2194ca671a5773172e89749353adcefee6531559851690b4e04a78b9
7
+ data.tar.gz: e9aaaf9e5e64c54579fe279e5d0554ddefa3fc8ee01bb51523d1ecd0a22ded6a3d4b5024ffe3b677dabcf18d5d1bcb7712e12b82ad6844a1ef80425986fe3aee
@@ -1,3 +1,32 @@
1
+ ## 0.5.2 (2020-12-27)
2
+
3
+ - Updated ONNX Runtime to 1.6.0
4
+ - Fixed error with `execution_mode` option
5
+ - Fixed error with `bool` input
6
+
7
+ ## 0.5.1 (2020-11-01)
8
+
9
+ - Updated ONNX Runtime to 1.5.2
10
+ - Added support for string output
11
+ - Added `output_type` option
12
+ - Improved performance for Numo array inputs
13
+
14
+ ## 0.5.0 (2020-10-01)
15
+
16
+ - Updated ONNX Runtime to 1.5.1
17
+ - OpenMP is now required on Mac
18
+ - Fixed `mul_1.onnx` example
19
+
20
+ ## 0.4.0 (2020-07-20)
21
+
22
+ - Updated ONNX Runtime to 1.4.0
23
+ - Added `providers` method
24
+ - Fixed errors on Windows
25
+
26
+ ## 0.3.3 (2020-06-17)
27
+
28
+ - Fixed segmentation fault on exit on Linux
29
+
1
30
  ## 0.3.2 (2020-06-16)
2
31
 
3
32
  - Fixed error with FFI 1.13.0+
@@ -1,23 +1,22 @@
1
- Copyright (c) 2019-2020 Andrew Kane
2
- Datasets Copyright (c) Microsoft Corporation
3
-
4
1
  MIT License
5
2
 
6
- Permission is hereby granted, free of charge, to any person obtaining
7
- a copy of this software and associated documentation files (the
8
- "Software"), to deal in the Software without restriction, including
9
- without limitation the rights to use, copy, modify, merge, publish,
10
- distribute, sublicense, and/or sell copies of the Software, and to
11
- permit persons to whom the Software is furnished to do so, subject to
12
- the following conditions:
3
+ Copyright (c) 2018 Microsoft Corporation
4
+ Copyright (c) 2019-2020 Andrew Kane
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
13
12
 
14
- The above copyright notice and this permission notice shall be
15
- included in all copies or substantial portions of the Software.
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
16
15
 
17
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
data/README.md CHANGED
@@ -4,7 +4,7 @@
4
4
 
5
5
  Check out [an example](https://ankane.org/tensorflow-ruby)
6
6
 
7
- [![Build Status](https://travis-ci.org/ankane/onnxruntime.svg?branch=master)](https://travis-ci.org/ankane/onnxruntime) [![Build status](https://ci.appveyor.com/api/projects/status/f2bq6ruqjf4jx671/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/onnxruntime/branch/master)
7
+ [![Build Status](https://github.com/ankane/onnxruntime/workflows/build/badge.svg?branch=master)](https://github.com/ankane/onnxruntime/actions)
8
8
 
9
9
  ## Installation
10
10
 
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
14
14
  gem 'onnxruntime'
15
15
  ```
16
16
 
17
+ On Mac, also install OpenMP:
18
+
19
+ ```sh
20
+ brew install libomp
21
+ ```
22
+
17
23
  ## Getting Started
18
24
 
19
25
  Load a model and make predictions
@@ -81,7 +87,8 @@ model.predict(input_feed, {
81
87
  log_severity_level: 2,
82
88
  log_verbosity_level: 0,
83
89
  logid: nil,
84
- terminate: false
90
+ terminate: false,
91
+ output_type: :ruby # :ruby or :numo
85
92
  })
86
93
  ```
87
94
 
@@ -3,10 +3,13 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
6
+ ffi_lib OnnxRuntime.ffi_lib
7
7
  rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
8
+ if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
9
+ raise LoadError, "OpenMP not found. Run `brew install libomp`"
10
+ else
11
+ raise e
12
+ end
10
13
  end
11
14
 
12
15
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -25,14 +28,14 @@ module OnnxRuntime
25
28
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
29
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
30
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
31
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
32
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
33
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
34
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
35
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
36
  :CloneSessionOptions, callback(%i[], :pointer),
34
- :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
37
+ :SetSessionExecutionMode, callback(%i[pointer int], :pointer),
38
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
39
  :DisableProfiling, callback(%i[pointer], :pointer),
37
40
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
41
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -71,8 +74,8 @@ module OnnxRuntime
71
74
  :IsTensor, callback(%i[], :pointer),
72
75
  :GetTensorMutableData, callback(%i[pointer pointer], :pointer),
73
76
  :FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
74
- :GetStringTensorDataLength, callback(%i[], :pointer),
75
- :GetStringTensorContent, callback(%i[], :pointer),
77
+ :GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
78
+ :GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
76
79
  :CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
77
80
  :GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
78
81
  :CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
@@ -142,7 +145,9 @@ module OnnxRuntime
142
145
  :CreateThreadingOptions, callback(%i[], :pointer),
143
146
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
147
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
148
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
149
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
150
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
151
  end
147
152
 
148
153
  class ApiBase < ::FFI::Struct
@@ -154,5 +159,13 @@ module OnnxRuntime
154
159
  end
155
160
 
156
161
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
162
+
163
+ if Gem.win_platform?
164
+ class Libc
165
+ extend ::FFI::Library
166
+ ffi_lib ::FFI::Library::LIBC
167
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
168
+ end
169
+ end
157
170
  end
158
171
  end
@@ -8,7 +8,7 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
13
  execution_modes = {sequential: 0, parallel: 1}
14
14
  mode = execution_modes[execution_mode]
@@ -17,8 +17,8 @@ module OnnxRuntime
17
17
  end
18
18
  if graph_optimization_level
19
19
  optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
- # TODO raise error in 0.4.0
21
- level = optimization_levels[graph_optimization_level] || graph_optimization_level
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
22
  check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
23
  end
24
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
@@ -26,7 +26,7 @@ module OnnxRuntime
26
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
30
30
 
31
31
  # session
32
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -39,17 +39,12 @@ module OnnxRuntime
39
39
  path_or_bytes.encoding == Encoding::BINARY
40
40
  end
41
41
 
42
- # fix for Windows "File doesn't exist"
43
- if Gem.win_platform? && !from_memory
44
- path_or_bytes = File.binread(path_or_bytes)
45
- from_memory = true
46
- end
47
-
48
42
  if from_memory
49
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
50
44
  else
51
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
52
46
  end
47
+ ObjectSpace.define_finalizer(self, self.class.finalize(@session))
53
48
 
54
49
  # input info
55
50
  allocator = ::FFI::MemoryPointer.new(:pointer)
@@ -62,7 +57,7 @@ module OnnxRuntime
62
57
  # input
63
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
64
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
65
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
66
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
67
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
68
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -73,17 +68,19 @@ module OnnxRuntime
73
68
  # output
74
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
75
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
76
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
77
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
78
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
79
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
80
75
  check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
81
76
  @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
82
77
  end
78
+ ensure
79
+ # release :SessionOptions, session_options
83
80
  end
84
81
 
85
82
  # TODO support logid
86
- def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
83
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
87
84
  input_tensor = create_input_tensor(input_feed)
88
85
 
89
86
  output_names ||= @outputs.map { |v| v[:name] }
@@ -103,7 +100,14 @@ module OnnxRuntime
103
100
  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
104
101
 
105
102
  output_names.size.times.map do |i|
106
- create_from_onnx_value(output_tensor[i].read_pointer)
103
+ create_from_onnx_value(output_tensor[i].read_pointer, output_type)
104
+ end
105
+ ensure
106
+ release :RunOptions, run_options
107
+ if input_tensor
108
+ input_feed.size.times do |i|
109
+ release :Value, input_tensor[i]
110
+ end
107
111
  end
108
112
  end
109
113
 
@@ -120,7 +124,7 @@ module OnnxRuntime
120
124
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
121
125
 
122
126
  custom_metadata_map = {}
123
- check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
127
+ check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
124
128
  num_keys.read(:int64_t).times do |i|
125
129
  key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
126
130
  value = ::FFI::MemoryPointer.new(:string)
@@ -133,7 +137,6 @@ module OnnxRuntime
133
137
  check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
134
138
  check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
135
139
  check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
136
- api[:ReleaseModelMetadata].call(metadata.read_pointer)
137
140
 
138
141
  {
139
142
  custom_metadata_map: custom_metadata_map,
@@ -143,58 +146,90 @@ module OnnxRuntime
143
146
  producer_name: producer_name.read_pointer.read_string,
144
147
  version: version.read(:int64_t)
145
148
  }
149
+ ensure
150
+ release :ModelMetadata, metadata
146
151
  end
147
152
 
153
+ # return value has double underscore like Python
148
154
  def end_profiling
149
155
  out = ::FFI::MemoryPointer.new(:string)
150
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
151
157
  out.read_pointer.read_string
152
158
  end
153
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
154
175
  private
155
176
 
156
177
  def create_input_tensor(input_feed)
157
178
  allocator_info = ::FFI::MemoryPointer.new(:pointer)
158
- check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
179
+ check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
159
180
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
160
181
 
161
182
  input_feed.each_with_index do |(input_name, input), idx|
162
- input = input.to_a unless input.is_a?(Array)
183
+ if numo_array?(input)
184
+ shape = input.shape
185
+ else
186
+ input = input.to_a unless input.is_a?(Array)
163
187
 
164
- shape = []
165
- s = input
166
- while s.is_a?(Array)
167
- shape << s.size
168
- s = s.first
188
+ shape = []
189
+ s = input
190
+ while s.is_a?(Array)
191
+ shape << s.size
192
+ s = s.first
193
+ end
169
194
  end
170
195
 
171
- flat_input = input.flatten
172
- input_tensor_size = flat_input.size
173
-
174
196
  # TODO support more types
175
197
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
176
- raise "Unknown input: #{input_name}" unless inp
198
+ raise Error, "Unknown input: #{input_name}" unless inp
177
199
 
178
200
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
179
201
  input_node_dims.write_array_of_int64(shape)
180
202
 
181
203
  if inp[:type] == "tensor(string)"
182
- input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
183
- input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
204
+ if numo_array?(input)
205
+ input_tensor_size = input.size
206
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
207
+ input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
208
+ else
209
+ flat_input = input.flatten.to_a
210
+ input_tensor_size = flat_input.size
211
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
212
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
213
+ end
184
214
  type_enum = FFI::TensorElementDataType[:string]
185
215
  check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
186
- check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
216
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
187
217
  else
188
- tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
189
218
  tensor_type = tensor_types[inp[:type]]
190
219
 
191
220
  if tensor_type
192
- input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
193
- if tensor_type == :bool
194
- tensor_type = :uchar
195
- flat_input = flat_input.map { |v| v ? 1 : 0 }
221
+ if numo_array?(input)
222
+ input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
223
+ else
224
+ flat_input = input.flatten.to_a
225
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
226
+ if tensor_type == :bool
227
+ input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
228
+ else
229
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
230
+ end
196
231
  end
197
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
232
+
198
233
  type_enum = FFI::TensorElementDataType[tensor_type]
199
234
  else
200
235
  unsupported_type("input", inp[:type])
@@ -213,9 +248,9 @@ module OnnxRuntime
213
248
  ptr
214
249
  end
215
250
 
216
- def create_from_onnx_value(out_ptr)
251
+ def create_from_onnx_value(out_ptr, output_type)
217
252
  out_type = ::FFI::MemoryPointer.new(:int)
218
- check_status = api[:GetValueType].call(out_ptr, out_type)
253
+ check_status api[:GetValueType].call(out_ptr, out_type)
219
254
  type = FFI::OnnxType[out_type.read_int]
220
255
 
221
256
  case type
@@ -230,29 +265,50 @@ module OnnxRuntime
230
265
 
231
266
  out_size = ::FFI::MemoryPointer.new(:size_t)
232
267
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
233
- output_tensor_size = read_size_t(out_size)
268
+ output_tensor_size = out_size.read(:size_t)
269
+
270
+ release :TensorTypeAndShapeInfo, typeinfo
234
271
 
235
272
  # TODO support more types
236
273
  type = FFI::TensorElementDataType[type]
237
- arr =
274
+
275
+ case output_type
276
+ when :numo
238
277
  case type
239
- when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
240
- tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
241
- when :bool
242
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
278
+ when :string
279
+ result = Numo::RObject.new(shape)
280
+ result.allocate
281
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
243
282
  else
244
- unsupported_type("element", type)
283
+ numo_type = numo_types[type]
284
+ unsupported_type("element", type) unless numo_type
285
+ numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
245
286
  end
287
+ when :ruby
288
+ arr =
289
+ case type
290
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
291
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
292
+ when :bool
293
+ tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
294
+ when :string
295
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
296
+ else
297
+ unsupported_type("element", type)
298
+ end
246
299
 
247
- Utils.reshape(arr, shape)
300
+ Utils.reshape(arr, shape)
301
+ else
302
+ raise ArgumentError, "Invalid output type: #{output_type}"
303
+ end
248
304
  when :sequence
249
305
  out = ::FFI::MemoryPointer.new(:size_t)
250
306
  check_status api[:GetValueCount].call(out_ptr, out)
251
307
 
252
- read_size_t(out).times.map do |i|
308
+ out.read(:size_t).times.map do |i|
253
309
  seq = ::FFI::MemoryPointer.new(:pointer)
254
310
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
255
- create_from_onnx_value(seq.read_pointer)
311
+ create_from_onnx_value(seq.read_pointer, output_type)
256
312
  end
257
313
  when :map
258
314
  type_shape = ::FFI::MemoryPointer.new(:pointer)
@@ -264,14 +320,15 @@ module OnnxRuntime
264
320
  check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
265
321
  check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
266
322
  check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
323
+ release :TensorTypeAndShapeInfo, type_shape
267
324
 
268
325
  # TODO support more types
269
326
  elem_type = FFI::TensorElementDataType[elem_type.read_int]
270
327
  case elem_type
271
328
  when :int64
272
329
  ret = {}
273
- keys = create_from_onnx_value(map_keys.read_pointer)
274
- values = create_from_onnx_value(map_values.read_pointer)
330
+ keys = create_from_onnx_value(map_keys.read_pointer, output_type)
331
+ values = create_from_onnx_value(map_values.read_pointer, output_type)
275
332
  keys.zip(values).each do |k, v|
276
333
  ret[k] = v
277
334
  end
@@ -284,6 +341,23 @@ module OnnxRuntime
284
341
  end
285
342
  end
286
343
 
344
+ def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
345
+ len = ::FFI::MemoryPointer.new(:size_t)
346
+ check_status api[:GetStringTensorDataLength].call(out_ptr, len)
347
+
348
+ s_len = len.read(:size_t)
349
+ s = ::FFI::MemoryPointer.new(:uchar, s_len)
350
+ offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
351
+ check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
352
+
353
+ offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
354
+ offsets << s_len
355
+ output_tensor_size.times do |i|
356
+ result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
357
+ end
358
+ result
359
+ end
360
+
287
361
  def read_pointer
288
362
  @session.read_pointer
289
363
  end
@@ -292,7 +366,7 @@ module OnnxRuntime
292
366
  unless status.null?
293
367
  message = api[:GetErrorMessage].call(status).read_string
294
368
  api[:ReleaseStatus].call(status)
295
- raise OnnxRuntime::Error, message
369
+ raise Error, message
296
370
  end
297
371
  end
298
372
 
@@ -304,6 +378,7 @@ module OnnxRuntime
304
378
  case type
305
379
  when :tensor
306
380
  tensor_info = ::FFI::MemoryPointer.new(:pointer)
381
+ # don't free tensor_info
307
382
  check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
308
383
 
309
384
  type, shape = tensor_type_and_shape(tensor_info)
@@ -344,7 +419,7 @@ module OnnxRuntime
344
419
  unsupported_type("ONNX", type)
345
420
  end
346
421
  ensure
347
- api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
422
+ release :TypeInfo, typeinfo
348
423
  end
349
424
 
350
425
  def tensor_type_and_shape(tensor_info)
@@ -353,7 +428,7 @@ module OnnxRuntime
353
428
 
354
429
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
355
430
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
356
- num_dims = read_size_t(num_dims_ptr)
431
+ num_dims = num_dims_ptr.read(:size_t)
357
432
 
358
433
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
359
434
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -362,20 +437,67 @@ module OnnxRuntime
362
437
  end
363
438
 
364
439
  def unsupported_type(name, type)
365
- raise "Unsupported #{name} type: #{type}"
440
+ raise Error, "Unsupported #{name} type: #{type}"
366
441
  end
367
442
 
368
- # read(:size_t) not supported in FFI JRuby
369
- def read_size_t(ptr)
370
- if RUBY_PLATFORM == "java"
371
- ptr.read_long
372
- else
373
- ptr.read(:size_t)
374
- end
443
+ def tensor_types
444
+ @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
445
+ end
446
+
447
+ def numo_array?(obj)
448
+ defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
449
+ end
450
+
451
+ def numo_types
452
+ @numo_types ||= {
453
+ float: Numo::SFloat,
454
+ uint8: Numo::UInt8,
455
+ int8: Numo::Int8,
456
+ uint16: Numo::UInt16,
457
+ int16: Numo::Int16,
458
+ int32: Numo::Int32,
459
+ int64: Numo::Int64,
460
+ bool: Numo::UInt8,
461
+ double: Numo::DFloat,
462
+ uint32: Numo::UInt32,
463
+ uint64: Numo::UInt64
464
+ }
375
465
  end
376
466
 
377
467
  def api
378
- @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
468
+ self.class.api
469
+ end
470
+
471
+ def release(*args)
472
+ self.class.release(*args)
473
+ end
474
+
475
+ def self.api
476
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
477
+ end
478
+
479
+ def self.release(type, pointer)
480
+ api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
481
+ end
482
+
483
+ def self.finalize(session)
484
+ # must use proc instead of stabby lambda
485
+ proc { release :Session, session }
486
+ end
487
+
488
+ # wide string on Windows
489
+ # char string on Linux
490
+ # see ORTCHAR_T in onnxruntime_c_api.h
491
+ def ort_string(str)
492
+ if Gem.win_platform?
493
+ max = str.size + 1 # for null byte
494
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
495
+ ret = FFI::Libc.mbstowcs(dest, str, max)
496
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
497
+ dest
498
+ else
499
+ str
500
+ end
379
501
  end
380
502
 
381
503
  def env
@@ -384,7 +506,7 @@ module OnnxRuntime
384
506
  @@env ||= begin
385
507
  env = ::FFI::MemoryPointer.new(:pointer)
386
508
  check_status api[:CreateEnv].call(3, "Default", env)
387
- at_exit { api[:ReleaseEnv].call(env.read_pointer) }
509
+ at_exit { release :Env, env }
388
510
  # disable telemetry
389
511
  # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
390
512
  check_status api[:DisableTelemetryEvents].call(env)