onnxruntime 0.3.3 → 0.6.0

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the registry.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b7d22851572b35128d1e2bbcc041b4989851e02354cace5389afc25855674b11
4
- data.tar.gz: f1eee285e5dbff1fbf6de4e350560e3e386e1f4d26d53810f154cd16117e38e6
3
+ metadata.gz: 4b0bef4682d8d44fb3cfd3fa619c1e7e620943014a22f259e14bfa1692dbe2ef
4
+ data.tar.gz: bfe59759c177613806ab0ddd97c4d8a68a7e3915299d81aab3e9d369ce672918
5
5
  SHA512:
6
- metadata.gz: f89ba13181bfcc8cdf35356efae175d0e2c7a0787a13af8d46788645046f057bce466e28566c0f2982f52893dbb7cc1553a741b1ca4923066e90c0a5fb230edf
7
- data.tar.gz: 3843b9c1e5a9432d3b72ebb33498aaef0fb6edf8942a9f0b6794bed36fb95df67290e6c4385e19927e5404510d89ad8965b15de0a96246b7f1770c4fb01216f6
6
+ metadata.gz: f0dfc6c0a91621b5637e56061b40ba26648a7e538d84cb2c135a06c1832c9f71aba2e7fb7ab849750756540d02632726b3f45732a8c50649ca4ace03befbe086
7
+ data.tar.gz: 59a01f8d955f9d31f35864a9217fcf79eb35876229e749c5c50b4572903d9dc3e312a4a7a55612e9c36e3095a797da58637442ac0bb191adcb743b62fefb0905
data/CHANGELOG.md CHANGED
@@ -1,3 +1,33 @@
1
+ ## 0.6.0 (2021-03-14)
2
+
3
+ - Updated ONNX Runtime to 1.7.0
4
+ - OpenMP is no longer required
5
+
6
+ ## 0.5.2 (2020-12-27)
7
+
8
+ - Updated ONNX Runtime to 1.6.0
9
+ - Fixed error with `execution_mode` option
10
+ - Fixed error with `bool` input
11
+
12
+ ## 0.5.1 (2020-11-01)
13
+
14
+ - Updated ONNX Runtime to 1.5.2
15
+ - Added support for string output
16
+ - Added `output_type` option
17
+ - Improved performance for Numo array inputs
18
+
19
+ ## 0.5.0 (2020-10-01)
20
+
21
+ - Updated ONNX Runtime to 1.5.1
22
+ - OpenMP is now required on Mac
23
+ - Fixed `mul_1.onnx` example
24
+
25
+ ## 0.4.0 (2020-07-20)
26
+
27
+ - Updated ONNX Runtime to 1.4.0
28
+ - Added `providers` method
29
+ - Fixed errors on Windows
30
+
1
31
  ## 0.3.3 (2020-06-17)
2
32
 
3
33
  - Fixed segmentation fault on exit on Linux
data/LICENSE.txt CHANGED
@@ -1,23 +1,22 @@
1
- Copyright (c) 2019-2020 Andrew Kane
2
- Datasets Copyright (c) Microsoft Corporation
3
-
4
1
  MIT License
5
2
 
6
- Permission is hereby granted, free of charge, to any person obtaining
7
- a copy of this software and associated documentation files (the
8
- "Software"), to deal in the Software without restriction, including
9
- without limitation the rights to use, copy, modify, merge, publish,
10
- distribute, sublicense, and/or sell copies of the Software, and to
11
- permit persons to whom the Software is furnished to do so, subject to
12
- the following conditions:
3
+ Copyright (c) Microsoft Corporation
4
+ Copyright (c) 2019-2021 Andrew Kane
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
13
12
 
14
- The above copyright notice and this permission notice shall be
15
- included in all copies or substantial portions of the Software.
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
16
15
 
17
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
data/README.md CHANGED
@@ -4,7 +4,7 @@
4
4
 
5
5
  Check out [an example](https://ankane.org/tensorflow-ruby)
6
6
 
7
- [![Build Status](https://travis-ci.org/ankane/onnxruntime.svg?branch=master)](https://travis-ci.org/ankane/onnxruntime) [![Build status](https://ci.appveyor.com/api/projects/status/f2bq6ruqjf4jx671/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/onnxruntime/branch/master)
7
+ [![Build Status](https://github.com/ankane/onnxruntime/workflows/build/badge.svg?branch=master)](https://github.com/ankane/onnxruntime/actions)
8
8
 
9
9
  ## Installation
10
10
 
@@ -81,7 +81,8 @@ model.predict(input_feed, {
81
81
  log_severity_level: 2,
82
82
  log_verbosity_level: 0,
83
83
  logid: nil,
84
- terminate: false
84
+ terminate: false,
85
+ output_type: :ruby # :ruby or :numo
85
86
  })
86
87
  ```
87
88
 
@@ -3,10 +3,13 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
6
+ ffi_lib OnnxRuntime.ffi_lib
7
7
  rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
8
+ if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
9
+ raise LoadError, "OpenMP not found. Run `brew install libomp`"
10
+ else
11
+ raise e
12
+ end
10
13
  end
11
14
 
12
15
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -25,14 +28,14 @@ module OnnxRuntime
25
28
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
29
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
30
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
31
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
32
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
33
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
34
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
35
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
36
  :CloneSessionOptions, callback(%i[], :pointer),
34
- :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
37
+ :SetSessionExecutionMode, callback(%i[pointer int], :pointer),
38
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
39
  :DisableProfiling, callback(%i[pointer], :pointer),
37
40
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
41
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -71,8 +74,8 @@ module OnnxRuntime
71
74
  :IsTensor, callback(%i[], :pointer),
72
75
  :GetTensorMutableData, callback(%i[pointer pointer], :pointer),
73
76
  :FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
74
- :GetStringTensorDataLength, callback(%i[], :pointer),
75
- :GetStringTensorContent, callback(%i[], :pointer),
77
+ :GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
78
+ :GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
76
79
  :CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
77
80
  :GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
78
81
  :CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
@@ -142,7 +145,9 @@ module OnnxRuntime
142
145
  :CreateThreadingOptions, callback(%i[], :pointer),
143
146
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
147
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
148
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
149
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
150
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
151
  end
147
152
 
148
153
  class ApiBase < ::FFI::Struct
@@ -154,5 +159,13 @@ module OnnxRuntime
154
159
  end
155
160
 
156
161
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
162
+
163
+ if Gem.win_platform?
164
+ class Libc
165
+ extend ::FFI::Library
166
+ ffi_lib ::FFI::Library::LIBC
167
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
168
+ end
169
+ end
157
170
  end
158
171
  end
@@ -8,7 +8,7 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
13
  execution_modes = {sequential: 0, parallel: 1}
14
14
  mode = execution_modes[execution_mode]
@@ -17,8 +17,8 @@ module OnnxRuntime
17
17
  end
18
18
  if graph_optimization_level
19
19
  optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
- # TODO raise error in 0.4.0
21
- level = optimization_levels[graph_optimization_level] || graph_optimization_level
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
22
  check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
23
  end
24
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
@@ -26,7 +26,7 @@ module OnnxRuntime
26
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
30
30
 
31
31
  # session
32
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -39,16 +39,10 @@ module OnnxRuntime
39
39
  path_or_bytes.encoding == Encoding::BINARY
40
40
  end
41
41
 
42
- # fix for Windows "File doesn't exist"
43
- if Gem.win_platform? && !from_memory
44
- path_or_bytes = File.binread(path_or_bytes)
45
- from_memory = true
46
- end
47
-
48
42
  if from_memory
49
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
50
44
  else
51
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
52
46
  end
53
47
  ObjectSpace.define_finalizer(self, self.class.finalize(@session))
54
48
 
@@ -63,7 +57,7 @@ module OnnxRuntime
63
57
  # input
64
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
65
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
66
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
67
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
68
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
69
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -74,7 +68,7 @@ module OnnxRuntime
74
68
  # output
75
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
76
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
77
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
78
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
79
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
80
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -86,7 +80,7 @@ module OnnxRuntime
86
80
  end
87
81
 
88
82
  # TODO support logid
89
- def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
83
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
90
84
  input_tensor = create_input_tensor(input_feed)
91
85
 
92
86
  output_names ||= @outputs.map { |v| v[:name] }
@@ -106,7 +100,7 @@ module OnnxRuntime
106
100
  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
107
101
 
108
102
  output_names.size.times.map do |i|
109
- create_from_onnx_value(output_tensor[i].read_pointer)
103
+ create_from_onnx_value(output_tensor[i].read_pointer, output_type)
110
104
  end
111
105
  ensure
112
106
  release :RunOptions, run_options
@@ -130,7 +124,7 @@ module OnnxRuntime
130
124
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
131
125
 
132
126
  custom_metadata_map = {}
133
- check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
127
+ check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
134
128
  num_keys.read(:int64_t).times do |i|
135
129
  key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
136
130
  value = ::FFI::MemoryPointer.new(:string)
@@ -156,56 +150,86 @@ module OnnxRuntime
156
150
  release :ModelMetadata, metadata
157
151
  end
158
152
 
153
+ # return value has double underscore like Python
159
154
  def end_profiling
160
155
  out = ::FFI::MemoryPointer.new(:string)
161
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
162
157
  out.read_pointer.read_string
163
158
  end
164
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
165
175
  private
166
176
 
167
177
  def create_input_tensor(input_feed)
168
178
  allocator_info = ::FFI::MemoryPointer.new(:pointer)
169
- check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
179
+ check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
170
180
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
171
181
 
172
182
  input_feed.each_with_index do |(input_name, input), idx|
173
- input = input.to_a unless input.is_a?(Array)
183
+ if numo_array?(input)
184
+ shape = input.shape
185
+ else
186
+ input = input.to_a unless input.is_a?(Array)
174
187
 
175
- shape = []
176
- s = input
177
- while s.is_a?(Array)
178
- shape << s.size
179
- s = s.first
188
+ shape = []
189
+ s = input
190
+ while s.is_a?(Array)
191
+ shape << s.size
192
+ s = s.first
193
+ end
180
194
  end
181
195
 
182
- flat_input = input.flatten
183
- input_tensor_size = flat_input.size
184
-
185
196
  # TODO support more types
186
197
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
187
- raise "Unknown input: #{input_name}" unless inp
198
+ raise Error, "Unknown input: #{input_name}" unless inp
188
199
 
189
200
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
190
201
  input_node_dims.write_array_of_int64(shape)
191
202
 
192
203
  if inp[:type] == "tensor(string)"
193
- input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
194
- input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
204
+ if numo_array?(input)
205
+ input_tensor_size = input.size
206
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
207
+ input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
208
+ else
209
+ flat_input = input.flatten.to_a
210
+ input_tensor_size = flat_input.size
211
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
212
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
213
+ end
195
214
  type_enum = FFI::TensorElementDataType[:string]
196
215
  check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
197
- check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
216
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
198
217
  else
199
- tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
200
218
  tensor_type = tensor_types[inp[:type]]
201
219
 
202
220
  if tensor_type
203
- input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
204
- if tensor_type == :bool
205
- tensor_type = :uchar
206
- flat_input = flat_input.map { |v| v ? 1 : 0 }
221
+ if numo_array?(input)
222
+ input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
223
+ else
224
+ flat_input = input.flatten.to_a
225
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
226
+ if tensor_type == :bool
227
+ input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
228
+ else
229
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
230
+ end
207
231
  end
208
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
232
+
209
233
  type_enum = FFI::TensorElementDataType[tensor_type]
210
234
  else
211
235
  unsupported_type("input", inp[:type])
@@ -224,9 +248,9 @@ module OnnxRuntime
224
248
  ptr
225
249
  end
226
250
 
227
- def create_from_onnx_value(out_ptr)
251
+ def create_from_onnx_value(out_ptr, output_type)
228
252
  out_type = ::FFI::MemoryPointer.new(:int)
229
- check_status = api[:GetValueType].call(out_ptr, out_type)
253
+ check_status api[:GetValueType].call(out_ptr, out_type)
230
254
  type = FFI::OnnxType[out_type.read_int]
231
255
 
232
256
  case type
@@ -241,31 +265,50 @@ module OnnxRuntime
241
265
 
242
266
  out_size = ::FFI::MemoryPointer.new(:size_t)
243
267
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
244
- output_tensor_size = read_size_t(out_size)
268
+ output_tensor_size = out_size.read(:size_t)
245
269
 
246
270
  release :TensorTypeAndShapeInfo, typeinfo
247
271
 
248
272
  # TODO support more types
249
273
  type = FFI::TensorElementDataType[type]
250
- arr =
274
+
275
+ case output_type
276
+ when :numo
251
277
  case type
252
- when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
253
- tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
254
- when :bool
255
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
278
+ when :string
279
+ result = Numo::RObject.new(shape)
280
+ result.allocate
281
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
256
282
  else
257
- unsupported_type("element", type)
283
+ numo_type = numo_types[type]
284
+ unsupported_type("element", type) unless numo_type
285
+ numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
258
286
  end
287
+ when :ruby
288
+ arr =
289
+ case type
290
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
291
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
292
+ when :bool
293
+ tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
294
+ when :string
295
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
296
+ else
297
+ unsupported_type("element", type)
298
+ end
259
299
 
260
- Utils.reshape(arr, shape)
300
+ Utils.reshape(arr, shape)
301
+ else
302
+ raise ArgumentError, "Invalid output type: #{output_type}"
303
+ end
261
304
  when :sequence
262
305
  out = ::FFI::MemoryPointer.new(:size_t)
263
306
  check_status api[:GetValueCount].call(out_ptr, out)
264
307
 
265
- read_size_t(out).times.map do |i|
308
+ out.read(:size_t).times.map do |i|
266
309
  seq = ::FFI::MemoryPointer.new(:pointer)
267
310
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
268
- create_from_onnx_value(seq.read_pointer)
311
+ create_from_onnx_value(seq.read_pointer, output_type)
269
312
  end
270
313
  when :map
271
314
  type_shape = ::FFI::MemoryPointer.new(:pointer)
@@ -284,8 +327,8 @@ module OnnxRuntime
284
327
  case elem_type
285
328
  when :int64
286
329
  ret = {}
287
- keys = create_from_onnx_value(map_keys.read_pointer)
288
- values = create_from_onnx_value(map_values.read_pointer)
330
+ keys = create_from_onnx_value(map_keys.read_pointer, output_type)
331
+ values = create_from_onnx_value(map_values.read_pointer, output_type)
289
332
  keys.zip(values).each do |k, v|
290
333
  ret[k] = v
291
334
  end
@@ -298,6 +341,23 @@ module OnnxRuntime
298
341
  end
299
342
  end
300
343
 
344
+ def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
345
+ len = ::FFI::MemoryPointer.new(:size_t)
346
+ check_status api[:GetStringTensorDataLength].call(out_ptr, len)
347
+
348
+ s_len = len.read(:size_t)
349
+ s = ::FFI::MemoryPointer.new(:uchar, s_len)
350
+ offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
351
+ check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
352
+
353
+ offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
354
+ offsets << s_len
355
+ output_tensor_size.times do |i|
356
+ result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
357
+ end
358
+ result
359
+ end
360
+
301
361
  def read_pointer
302
362
  @session.read_pointer
303
363
  end
@@ -306,7 +366,7 @@ module OnnxRuntime
306
366
  unless status.null?
307
367
  message = api[:GetErrorMessage].call(status).read_string
308
368
  api[:ReleaseStatus].call(status)
309
- raise OnnxRuntime::Error, message
369
+ raise Error, message
310
370
  end
311
371
  end
312
372
 
@@ -368,7 +428,7 @@ module OnnxRuntime
368
428
 
369
429
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
370
430
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
371
- num_dims = read_size_t(num_dims_ptr)
431
+ num_dims = num_dims_ptr.read(:size_t)
372
432
 
373
433
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
374
434
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -377,16 +437,31 @@ module OnnxRuntime
377
437
  end
378
438
 
379
439
  def unsupported_type(name, type)
380
- raise "Unsupported #{name} type: #{type}"
440
+ raise Error, "Unsupported #{name} type: #{type}"
381
441
  end
382
442
 
383
- # read(:size_t) not supported in FFI JRuby
384
- def read_size_t(ptr)
385
- if RUBY_PLATFORM == "java"
386
- ptr.read_long
387
- else
388
- ptr.read(:size_t)
389
- end
443
+ def tensor_types
444
+ @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
445
+ end
446
+
447
+ def numo_array?(obj)
448
+ defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
449
+ end
450
+
451
+ def numo_types
452
+ @numo_types ||= {
453
+ float: Numo::SFloat,
454
+ uint8: Numo::UInt8,
455
+ int8: Numo::Int8,
456
+ uint16: Numo::UInt16,
457
+ int16: Numo::Int16,
458
+ int32: Numo::Int32,
459
+ int64: Numo::Int64,
460
+ bool: Numo::UInt8,
461
+ double: Numo::DFloat,
462
+ uint32: Numo::UInt32,
463
+ uint64: Numo::UInt64
464
+ }
390
465
  end
391
466
 
392
467
  def api
@@ -398,7 +473,7 @@ module OnnxRuntime
398
473
  end
399
474
 
400
475
  def self.api
401
- @api ||= FFI.OrtGetApiBase[:GetApi].call(3)
476
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
402
477
  end
403
478
 
404
479
  def self.release(type, pointer)
@@ -410,6 +485,21 @@ module OnnxRuntime
410
485
  proc { release :Session, session }
411
486
  end
412
487
 
488
+ # wide string on Windows
489
+ # char string on Linux
490
+ # see ORTCHAR_T in onnxruntime_c_api.h
491
+ def ort_string(str)
492
+ if Gem.win_platform?
493
+ max = str.size + 1 # for null byte
494
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
495
+ ret = FFI::Libc.mbstowcs(dest, str, max)
496
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
497
+ dest
498
+ else
499
+ str
500
+ end
501
+ end
502
+
413
503
  def env
414
504
  # use mutex for thread-safety
415
505
  Utils.mutex.synchronize do