onnxruntime 0.3.3 → 0.6.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b7d22851572b35128d1e2bbcc041b4989851e02354cace5389afc25855674b11
4
- data.tar.gz: f1eee285e5dbff1fbf6de4e350560e3e386e1f4d26d53810f154cd16117e38e6
3
+ metadata.gz: 4b0bef4682d8d44fb3cfd3fa619c1e7e620943014a22f259e14bfa1692dbe2ef
4
+ data.tar.gz: bfe59759c177613806ab0ddd97c4d8a68a7e3915299d81aab3e9d369ce672918
5
5
  SHA512:
6
- metadata.gz: f89ba13181bfcc8cdf35356efae175d0e2c7a0787a13af8d46788645046f057bce466e28566c0f2982f52893dbb7cc1553a741b1ca4923066e90c0a5fb230edf
7
- data.tar.gz: 3843b9c1e5a9432d3b72ebb33498aaef0fb6edf8942a9f0b6794bed36fb95df67290e6c4385e19927e5404510d89ad8965b15de0a96246b7f1770c4fb01216f6
6
+ metadata.gz: f0dfc6c0a91621b5637e56061b40ba26648a7e538d84cb2c135a06c1832c9f71aba2e7fb7ab849750756540d02632726b3f45732a8c50649ca4ace03befbe086
7
+ data.tar.gz: 59a01f8d955f9d31f35864a9217fcf79eb35876229e749c5c50b4572903d9dc3e312a4a7a55612e9c36e3095a797da58637442ac0bb191adcb743b62fefb0905
data/CHANGELOG.md CHANGED
@@ -1,3 +1,33 @@
1
+ ## 0.6.0 (2021-03-14)
2
+
3
+ - Updated ONNX Runtime to 1.7.0
4
+ - OpenMP is no longer required
5
+
6
+ ## 0.5.2 (2020-12-27)
7
+
8
+ - Updated ONNX Runtime to 1.6.0
9
+ - Fixed error with `execution_mode` option
10
+ - Fixed error with `bool` input
11
+
12
+ ## 0.5.1 (2020-11-01)
13
+
14
+ - Updated ONNX Runtime to 1.5.2
15
+ - Added support for string output
16
+ - Added `output_type` option
17
+ - Improved performance for Numo array inputs
18
+
19
+ ## 0.5.0 (2020-10-01)
20
+
21
+ - Updated ONNX Runtime to 1.5.1
22
+ - OpenMP is now required on Mac
23
+ - Fixed `mul_1.onnx` example
24
+
25
+ ## 0.4.0 (2020-07-20)
26
+
27
+ - Updated ONNX Runtime to 1.4.0
28
+ - Added `providers` method
29
+ - Fixed errors on Windows
30
+
1
31
  ## 0.3.3 (2020-06-17)
2
32
 
3
33
  - Fixed segmentation fault on exit on Linux
data/LICENSE.txt CHANGED
@@ -1,23 +1,22 @@
1
- Copyright (c) 2019-2020 Andrew Kane
2
- Datasets Copyright (c) Microsoft Corporation
3
-
4
1
  MIT License
5
2
 
6
- Permission is hereby granted, free of charge, to any person obtaining
7
- a copy of this software and associated documentation files (the
8
- "Software"), to deal in the Software without restriction, including
9
- without limitation the rights to use, copy, modify, merge, publish,
10
- distribute, sublicense, and/or sell copies of the Software, and to
11
- permit persons to whom the Software is furnished to do so, subject to
12
- the following conditions:
3
+ Copyright (c) Microsoft Corporation
4
+ Copyright (c) 2019-2021 Andrew Kane
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
13
12
 
14
- The above copyright notice and this permission notice shall be
15
- included in all copies or substantial portions of the Software.
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
16
15
 
17
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
data/README.md CHANGED
@@ -4,7 +4,7 @@
4
4
 
5
5
  Check out [an example](https://ankane.org/tensorflow-ruby)
6
6
 
7
- [![Build Status](https://travis-ci.org/ankane/onnxruntime.svg?branch=master)](https://travis-ci.org/ankane/onnxruntime) [![Build status](https://ci.appveyor.com/api/projects/status/f2bq6ruqjf4jx671/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/onnxruntime/branch/master)
7
+ [![Build Status](https://github.com/ankane/onnxruntime/workflows/build/badge.svg?branch=master)](https://github.com/ankane/onnxruntime/actions)
8
8
 
9
9
  ## Installation
10
10
 
@@ -81,7 +81,8 @@ model.predict(input_feed, {
81
81
  log_severity_level: 2,
82
82
  log_verbosity_level: 0,
83
83
  logid: nil,
84
- terminate: false
84
+ terminate: false,
85
+ output_type: :ruby # :ruby or :numo
85
86
  })
86
87
  ```
87
88
 
@@ -3,10 +3,13 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
6
+ ffi_lib OnnxRuntime.ffi_lib
7
7
  rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
8
+ if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
9
+ raise LoadError, "OpenMP not found. Run `brew install libomp`"
10
+ else
11
+ raise e
12
+ end
10
13
  end
11
14
 
12
15
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -25,14 +28,14 @@ module OnnxRuntime
25
28
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
29
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
30
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
31
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
32
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
33
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
34
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
35
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
36
  :CloneSessionOptions, callback(%i[], :pointer),
34
- :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
37
+ :SetSessionExecutionMode, callback(%i[pointer int], :pointer),
38
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
39
  :DisableProfiling, callback(%i[pointer], :pointer),
37
40
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
41
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -71,8 +74,8 @@ module OnnxRuntime
71
74
  :IsTensor, callback(%i[], :pointer),
72
75
  :GetTensorMutableData, callback(%i[pointer pointer], :pointer),
73
76
  :FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
74
- :GetStringTensorDataLength, callback(%i[], :pointer),
75
- :GetStringTensorContent, callback(%i[], :pointer),
77
+ :GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
78
+ :GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
76
79
  :CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
77
80
  :GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
78
81
  :CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
@@ -142,7 +145,9 @@ module OnnxRuntime
142
145
  :CreateThreadingOptions, callback(%i[], :pointer),
143
146
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
147
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
148
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
149
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
150
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
151
  end
147
152
 
148
153
  class ApiBase < ::FFI::Struct
@@ -154,5 +159,13 @@ module OnnxRuntime
154
159
  end
155
160
 
156
161
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
162
+
163
+ if Gem.win_platform?
164
+ class Libc
165
+ extend ::FFI::Library
166
+ ffi_lib ::FFI::Library::LIBC
167
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
168
+ end
169
+ end
157
170
  end
158
171
  end
@@ -8,7 +8,7 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
13
  execution_modes = {sequential: 0, parallel: 1}
14
14
  mode = execution_modes[execution_mode]
@@ -17,8 +17,8 @@ module OnnxRuntime
17
17
  end
18
18
  if graph_optimization_level
19
19
  optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
- # TODO raise error in 0.4.0
21
- level = optimization_levels[graph_optimization_level] || graph_optimization_level
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
22
  check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
23
  end
24
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
@@ -26,7 +26,7 @@ module OnnxRuntime
26
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
30
30
 
31
31
  # session
32
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -39,16 +39,10 @@ module OnnxRuntime
39
39
  path_or_bytes.encoding == Encoding::BINARY
40
40
  end
41
41
 
42
- # fix for Windows "File doesn't exist"
43
- if Gem.win_platform? && !from_memory
44
- path_or_bytes = File.binread(path_or_bytes)
45
- from_memory = true
46
- end
47
-
48
42
  if from_memory
49
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
50
44
  else
51
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
52
46
  end
53
47
  ObjectSpace.define_finalizer(self, self.class.finalize(@session))
54
48
 
@@ -63,7 +57,7 @@ module OnnxRuntime
63
57
  # input
64
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
65
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
66
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
67
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
68
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
69
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -74,7 +68,7 @@ module OnnxRuntime
74
68
  # output
75
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
76
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
77
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
78
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
79
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
80
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -86,7 +80,7 @@ module OnnxRuntime
86
80
  end
87
81
 
88
82
  # TODO support logid
89
- def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
83
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
90
84
  input_tensor = create_input_tensor(input_feed)
91
85
 
92
86
  output_names ||= @outputs.map { |v| v[:name] }
@@ -106,7 +100,7 @@ module OnnxRuntime
106
100
  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
107
101
 
108
102
  output_names.size.times.map do |i|
109
- create_from_onnx_value(output_tensor[i].read_pointer)
103
+ create_from_onnx_value(output_tensor[i].read_pointer, output_type)
110
104
  end
111
105
  ensure
112
106
  release :RunOptions, run_options
@@ -130,7 +124,7 @@ module OnnxRuntime
130
124
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
131
125
 
132
126
  custom_metadata_map = {}
133
- check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
127
+ check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
134
128
  num_keys.read(:int64_t).times do |i|
135
129
  key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
136
130
  value = ::FFI::MemoryPointer.new(:string)
@@ -156,56 +150,86 @@ module OnnxRuntime
156
150
  release :ModelMetadata, metadata
157
151
  end
158
152
 
153
+ # return value has double underscore like Python
159
154
  def end_profiling
160
155
  out = ::FFI::MemoryPointer.new(:string)
161
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
162
157
  out.read_pointer.read_string
163
158
  end
164
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
165
175
  private
166
176
 
167
177
  def create_input_tensor(input_feed)
168
178
  allocator_info = ::FFI::MemoryPointer.new(:pointer)
169
- check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
179
+ check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
170
180
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
171
181
 
172
182
  input_feed.each_with_index do |(input_name, input), idx|
173
- input = input.to_a unless input.is_a?(Array)
183
+ if numo_array?(input)
184
+ shape = input.shape
185
+ else
186
+ input = input.to_a unless input.is_a?(Array)
174
187
 
175
- shape = []
176
- s = input
177
- while s.is_a?(Array)
178
- shape << s.size
179
- s = s.first
188
+ shape = []
189
+ s = input
190
+ while s.is_a?(Array)
191
+ shape << s.size
192
+ s = s.first
193
+ end
180
194
  end
181
195
 
182
- flat_input = input.flatten
183
- input_tensor_size = flat_input.size
184
-
185
196
  # TODO support more types
186
197
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
187
- raise "Unknown input: #{input_name}" unless inp
198
+ raise Error, "Unknown input: #{input_name}" unless inp
188
199
 
189
200
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
190
201
  input_node_dims.write_array_of_int64(shape)
191
202
 
192
203
  if inp[:type] == "tensor(string)"
193
- input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
194
- input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
204
+ if numo_array?(input)
205
+ input_tensor_size = input.size
206
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
207
+ input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
208
+ else
209
+ flat_input = input.flatten.to_a
210
+ input_tensor_size = flat_input.size
211
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
212
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
213
+ end
195
214
  type_enum = FFI::TensorElementDataType[:string]
196
215
  check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
197
- check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
216
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
198
217
  else
199
- tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
200
218
  tensor_type = tensor_types[inp[:type]]
201
219
 
202
220
  if tensor_type
203
- input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
204
- if tensor_type == :bool
205
- tensor_type = :uchar
206
- flat_input = flat_input.map { |v| v ? 1 : 0 }
221
+ if numo_array?(input)
222
+ input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
223
+ else
224
+ flat_input = input.flatten.to_a
225
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
226
+ if tensor_type == :bool
227
+ input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
228
+ else
229
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
230
+ end
207
231
  end
208
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
232
+
209
233
  type_enum = FFI::TensorElementDataType[tensor_type]
210
234
  else
211
235
  unsupported_type("input", inp[:type])
@@ -224,9 +248,9 @@ module OnnxRuntime
224
248
  ptr
225
249
  end
226
250
 
227
- def create_from_onnx_value(out_ptr)
251
+ def create_from_onnx_value(out_ptr, output_type)
228
252
  out_type = ::FFI::MemoryPointer.new(:int)
229
- check_status = api[:GetValueType].call(out_ptr, out_type)
253
+ check_status api[:GetValueType].call(out_ptr, out_type)
230
254
  type = FFI::OnnxType[out_type.read_int]
231
255
 
232
256
  case type
@@ -241,31 +265,50 @@ module OnnxRuntime
241
265
 
242
266
  out_size = ::FFI::MemoryPointer.new(:size_t)
243
267
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
244
- output_tensor_size = read_size_t(out_size)
268
+ output_tensor_size = out_size.read(:size_t)
245
269
 
246
270
  release :TensorTypeAndShapeInfo, typeinfo
247
271
 
248
272
  # TODO support more types
249
273
  type = FFI::TensorElementDataType[type]
250
- arr =
274
+
275
+ case output_type
276
+ when :numo
251
277
  case type
252
- when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
253
- tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
254
- when :bool
255
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
278
+ when :string
279
+ result = Numo::RObject.new(shape)
280
+ result.allocate
281
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
256
282
  else
257
- unsupported_type("element", type)
283
+ numo_type = numo_types[type]
284
+ unsupported_type("element", type) unless numo_type
285
+ numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
258
286
  end
287
+ when :ruby
288
+ arr =
289
+ case type
290
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
291
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
292
+ when :bool
293
+ tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
294
+ when :string
295
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
296
+ else
297
+ unsupported_type("element", type)
298
+ end
259
299
 
260
- Utils.reshape(arr, shape)
300
+ Utils.reshape(arr, shape)
301
+ else
302
+ raise ArgumentError, "Invalid output type: #{output_type}"
303
+ end
261
304
  when :sequence
262
305
  out = ::FFI::MemoryPointer.new(:size_t)
263
306
  check_status api[:GetValueCount].call(out_ptr, out)
264
307
 
265
- read_size_t(out).times.map do |i|
308
+ out.read(:size_t).times.map do |i|
266
309
  seq = ::FFI::MemoryPointer.new(:pointer)
267
310
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
268
- create_from_onnx_value(seq.read_pointer)
311
+ create_from_onnx_value(seq.read_pointer, output_type)
269
312
  end
270
313
  when :map
271
314
  type_shape = ::FFI::MemoryPointer.new(:pointer)
@@ -284,8 +327,8 @@ module OnnxRuntime
284
327
  case elem_type
285
328
  when :int64
286
329
  ret = {}
287
- keys = create_from_onnx_value(map_keys.read_pointer)
288
- values = create_from_onnx_value(map_values.read_pointer)
330
+ keys = create_from_onnx_value(map_keys.read_pointer, output_type)
331
+ values = create_from_onnx_value(map_values.read_pointer, output_type)
289
332
  keys.zip(values).each do |k, v|
290
333
  ret[k] = v
291
334
  end
@@ -298,6 +341,23 @@ module OnnxRuntime
298
341
  end
299
342
  end
300
343
 
344
+ def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
345
+ len = ::FFI::MemoryPointer.new(:size_t)
346
+ check_status api[:GetStringTensorDataLength].call(out_ptr, len)
347
+
348
+ s_len = len.read(:size_t)
349
+ s = ::FFI::MemoryPointer.new(:uchar, s_len)
350
+ offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
351
+ check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
352
+
353
+ offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
354
+ offsets << s_len
355
+ output_tensor_size.times do |i|
356
+ result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
357
+ end
358
+ result
359
+ end
360
+
301
361
  def read_pointer
302
362
  @session.read_pointer
303
363
  end
@@ -306,7 +366,7 @@ module OnnxRuntime
306
366
  unless status.null?
307
367
  message = api[:GetErrorMessage].call(status).read_string
308
368
  api[:ReleaseStatus].call(status)
309
- raise OnnxRuntime::Error, message
369
+ raise Error, message
310
370
  end
311
371
  end
312
372
 
@@ -368,7 +428,7 @@ module OnnxRuntime
368
428
 
369
429
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
370
430
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
371
- num_dims = read_size_t(num_dims_ptr)
431
+ num_dims = num_dims_ptr.read(:size_t)
372
432
 
373
433
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
374
434
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -377,16 +437,31 @@ module OnnxRuntime
377
437
  end
378
438
 
379
439
  def unsupported_type(name, type)
380
- raise "Unsupported #{name} type: #{type}"
440
+ raise Error, "Unsupported #{name} type: #{type}"
381
441
  end
382
442
 
383
- # read(:size_t) not supported in FFI JRuby
384
- def read_size_t(ptr)
385
- if RUBY_PLATFORM == "java"
386
- ptr.read_long
387
- else
388
- ptr.read(:size_t)
389
- end
443
+ def tensor_types
444
+ @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
445
+ end
446
+
447
+ def numo_array?(obj)
448
+ defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
449
+ end
450
+
451
+ def numo_types
452
+ @numo_types ||= {
453
+ float: Numo::SFloat,
454
+ uint8: Numo::UInt8,
455
+ int8: Numo::Int8,
456
+ uint16: Numo::UInt16,
457
+ int16: Numo::Int16,
458
+ int32: Numo::Int32,
459
+ int64: Numo::Int64,
460
+ bool: Numo::UInt8,
461
+ double: Numo::DFloat,
462
+ uint32: Numo::UInt32,
463
+ uint64: Numo::UInt64
464
+ }
390
465
  end
391
466
 
392
467
  def api
@@ -398,7 +473,7 @@ module OnnxRuntime
398
473
  end
399
474
 
400
475
  def self.api
401
- @api ||= FFI.OrtGetApiBase[:GetApi].call(3)
476
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
402
477
  end
403
478
 
404
479
  def self.release(type, pointer)
@@ -410,6 +485,21 @@ module OnnxRuntime
410
485
  proc { release :Session, session }
411
486
  end
412
487
 
488
+ # wide string on Windows
489
+ # char string on Linux
490
+ # see ORTCHAR_T in onnxruntime_c_api.h
491
+ def ort_string(str)
492
+ if Gem.win_platform?
493
+ max = str.size + 1 # for null byte
494
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
495
+ ret = FFI::Libc.mbstowcs(dest, str, max)
496
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
497
+ dest
498
+ else
499
+ str
500
+ end
501
+ end
502
+
413
503
  def env
414
504
  # use mutex for thread-safety
415
505
  Utils.mutex.synchronize do