onnxruntime 0.3.1 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2d4b9ddc710644eab07227fe005c92ad039643e989236e5f112a83df2266fc55
4
- data.tar.gz: d15c9ac4ae98130499d1411d3be2e7b822df8248550c0ad44c2e7a02d88f95d8
3
+ metadata.gz: 316527780be2781a474d0813aff47d840654423823ad83b8b52b13752caf6814
4
+ data.tar.gz: 5954ba2dc4223b8330fb52e8474be78542360c2f4c1cb5946529d062c0b1b864
5
5
  SHA512:
6
- metadata.gz: 83a2e3699506460b8b280376ec8935a680e7e2cffe6bf7dd6766a85eea84bf06273e13889b6a2961c9df3c277be6ea7d8afbb67a4d4be62d50b266d7ccd8dc55
7
- data.tar.gz: 5086c2c04f76ec67224e7dbc9ccc36366b394840296536232b29d92009c7a4ef637fb4f39daec8abb7a5e3fdcca591675cfaf3e4a8e31ec939f9aa1e2ae6f112
6
+ metadata.gz: c127874dd75a10b8cb9d9d033e607f9d73069bb5709d7b9ecee04c1d59f969f61db6ec88fd569188b0a35c8276f7611619f81662f3735fbd53e61938f15c6822
7
+ data.tar.gz: 77a7c9f8c98b25fd82ee8818c63176d2005f846db489b5973330220344be8fd47d8e5279517a716f8c1222f729a248fc8ac80cb610633dbdc35c49418a42ea1c
@@ -1,3 +1,31 @@
1
+ ## 0.5.1 (2020-11-01)
2
+
3
+ - Updated ONNX Runtime to 1.5.2
4
+ - Added support for string output
5
+ - Added `output_type` option
6
+ - Improved performance for Numo array inputs
7
+
8
+ ## 0.5.0 (2020-10-01)
9
+
10
+ - Updated ONNX Runtime to 1.5.1
11
+ - OpenMP is now required on Mac
12
+ - Fixed `mul_1.onnx` example
13
+
14
+ ## 0.4.0 (2020-07-20)
15
+
16
+ - Updated ONNX Runtime to 1.4.0
17
+ - Added `providers` method
18
+ - Fixed errors on Windows
19
+
20
+ ## 0.3.3 (2020-06-17)
21
+
22
+ - Fixed segmentation fault on exit on Linux
23
+
24
+ ## 0.3.2 (2020-06-16)
25
+
26
+ - Fixed error with FFI 1.13.0+
27
+ - Added friendly graph optimization levels
28
+
1
29
  ## 0.3.1 (2020-05-18)
2
30
 
3
31
  - Updated ONNX Runtime to 1.3.0
@@ -1,23 +1,22 @@
1
- Copyright (c) 2019-2020 Andrew Kane
2
- Datasets Copyright (c) Microsoft Corporation
3
-
4
1
  MIT License
5
2
 
6
- Permission is hereby granted, free of charge, to any person obtaining
7
- a copy of this software and associated documentation files (the
8
- "Software"), to deal in the Software without restriction, including
9
- without limitation the rights to use, copy, modify, merge, publish,
10
- distribute, sublicense, and/or sell copies of the Software, and to
11
- permit persons to whom the Software is furnished to do so, subject to
12
- the following conditions:
3
+ Copyright (c) 2018 Microsoft Corporation
4
+ Copyright (c) 2019-2020 Andrew Kane
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
13
12
 
14
- The above copyright notice and this permission notice shall be
15
- included in all copies or substantial portions of the Software.
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
16
15
 
17
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
data/README.md CHANGED
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
14
14
  gem 'onnxruntime'
15
15
  ```
16
16
 
17
+ On Mac, also install OpenMP:
18
+
19
+ ```sh
20
+ brew install libomp
21
+ ```
22
+
17
23
  ## Getting Started
18
24
 
19
25
  Load a model and make predictions
@@ -63,8 +69,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
63
69
  enable_cpu_mem_arena: true,
64
70
  enable_mem_pattern: true,
65
71
  enable_profiling: false,
66
- execution_mode: :sequential,
67
- graph_optimization_level: nil,
72
+ execution_mode: :sequential, # :sequential or :parallel
73
+ graph_optimization_level: nil, # :none, :basic, :extended, or :all
68
74
  inter_op_num_threads: nil,
69
75
  intra_op_num_threads: nil,
70
76
  log_severity_level: 2,
@@ -81,7 +87,8 @@ model.predict(input_feed, {
81
87
  log_severity_level: 2,
82
88
  log_verbosity_level: 0,
83
89
  logid: nil,
84
- terminate: false
90
+ terminate: false,
91
+ output_type: :ruby # :ruby or :numo
85
92
  })
86
93
  ```
87
94
 
@@ -19,7 +19,7 @@ module OnnxRuntime
19
19
  self.ffi_lib = [vendor_lib]
20
20
 
21
21
  def self.lib_version
22
- FFI.OrtGetApiBase[:GetVersionString].call
22
+ FFI.OrtGetApiBase[:GetVersionString].call.read_string
23
23
  end
24
24
 
25
25
  # friendlier error message
@@ -3,10 +3,13 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
6
+ ffi_lib OnnxRuntime.ffi_lib
7
7
  rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
8
+ if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
9
+ raise LoadError, "OpenMP not found. Run `brew install libomp`"
10
+ else
11
+ raise e
12
+ end
10
13
  end
11
14
 
12
15
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -20,19 +23,19 @@ module OnnxRuntime
20
23
  layout \
21
24
  :CreateStatus, callback(%i[int string], :pointer),
22
25
  :GetErrorCode, callback(%i[pointer], :pointer),
23
- :GetErrorMessage, callback(%i[pointer], :string),
26
+ :GetErrorMessage, callback(%i[pointer], :pointer),
24
27
  :CreateEnv, callback(%i[int string pointer], :pointer),
25
28
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
29
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
30
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
31
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
32
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
33
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
34
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
35
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
36
  :CloneSessionOptions, callback(%i[], :pointer),
34
37
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
38
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
39
  :DisableProfiling, callback(%i[pointer], :pointer),
37
40
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
41
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -71,8 +74,8 @@ module OnnxRuntime
71
74
  :IsTensor, callback(%i[], :pointer),
72
75
  :GetTensorMutableData, callback(%i[pointer pointer], :pointer),
73
76
  :FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
74
- :GetStringTensorDataLength, callback(%i[], :pointer),
75
- :GetStringTensorContent, callback(%i[], :pointer),
77
+ :GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
78
+ :GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
76
79
  :CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
77
80
  :GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
78
81
  :CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
@@ -142,7 +145,9 @@ module OnnxRuntime
142
145
  :CreateThreadingOptions, callback(%i[], :pointer),
143
146
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
147
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
148
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
149
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
150
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
151
  end
147
152
 
148
153
  class ApiBase < ::FFI::Struct
@@ -150,9 +155,17 @@ module OnnxRuntime
150
155
  # to prevent "unable to resolve type" error on Ubuntu
151
156
  layout \
152
157
  :GetApi, callback(%i[uint32], Api.by_ref),
153
- :GetVersionString, callback(%i[], :string)
158
+ :GetVersionString, callback(%i[], :pointer)
154
159
  end
155
160
 
156
161
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
162
+
163
+ if Gem.win_platform?
164
+ class Libc
165
+ extend ::FFI::Library
166
+ ffi_lib ::FFI::Library::LIBC
167
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
168
+ end
169
+ end
157
170
  end
158
171
  end
@@ -8,26 +8,25 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
- mode =
14
- case execution_mode
15
- when :sequential
16
- 0
17
- when :parallel
18
- 1
19
- else
20
- raise ArgumentError, "Invalid execution mode"
21
- end
13
+ execution_modes = {sequential: 0, parallel: 1}
14
+ mode = execution_modes[execution_mode]
15
+ raise ArgumentError, "Invalid execution mode" unless mode
22
16
  check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
23
17
  end
24
- check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
18
+ if graph_optimization_level
19
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
+ end
25
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
26
25
  check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
27
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
28
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
29
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
30
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
31
30
 
32
31
  # session
33
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -40,17 +39,12 @@ module OnnxRuntime
40
39
  path_or_bytes.encoding == Encoding::BINARY
41
40
  end
42
41
 
43
- # fix for Windows "File doesn't exist"
44
- if Gem.win_platform? && !from_memory
45
- path_or_bytes = File.binread(path_or_bytes)
46
- from_memory = true
47
- end
48
-
49
42
  if from_memory
50
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
51
44
  else
52
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
53
46
  end
47
+ ObjectSpace.define_finalizer(self, self.class.finalize(@session))
54
48
 
55
49
  # input info
56
50
  allocator = ::FFI::MemoryPointer.new(:pointer)
@@ -63,7 +57,7 @@ module OnnxRuntime
63
57
  # input
64
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
65
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
66
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
67
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
68
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
69
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -74,17 +68,19 @@ module OnnxRuntime
74
68
  # output
75
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
76
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
77
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
78
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
79
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
80
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
81
75
  check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
82
76
  @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
83
77
  end
78
+ ensure
79
+ # release :SessionOptions, session_options
84
80
  end
85
81
 
86
82
  # TODO support logid
87
- def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
83
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
88
84
  input_tensor = create_input_tensor(input_feed)
89
85
 
90
86
  output_names ||= @outputs.map { |v| v[:name] }
@@ -104,7 +100,14 @@ module OnnxRuntime
104
100
  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
105
101
 
106
102
  output_names.size.times.map do |i|
107
- create_from_onnx_value(output_tensor[i].read_pointer)
103
+ create_from_onnx_value(output_tensor[i].read_pointer, output_type)
104
+ end
105
+ ensure
106
+ release :RunOptions, run_options
107
+ if input_tensor
108
+ input_feed.size.times do |i|
109
+ release :Value, input_tensor[i]
110
+ end
108
111
  end
109
112
  end
110
113
 
@@ -121,7 +124,7 @@ module OnnxRuntime
121
124
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
122
125
 
123
126
  custom_metadata_map = {}
124
- check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
127
+ check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
125
128
  num_keys.read(:int64_t).times do |i|
126
129
  key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
127
130
  value = ::FFI::MemoryPointer.new(:string)
@@ -134,7 +137,6 @@ module OnnxRuntime
134
137
  check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
135
138
  check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
136
139
  check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
137
- api[:ReleaseModelMetadata].call(metadata.read_pointer)
138
140
 
139
141
  {
140
142
  custom_metadata_map: custom_metadata_map,
@@ -144,58 +146,90 @@ module OnnxRuntime
144
146
  producer_name: producer_name.read_pointer.read_string,
145
147
  version: version.read(:int64_t)
146
148
  }
149
+ ensure
150
+ release :ModelMetadata, metadata
147
151
  end
148
152
 
153
+ # return value has double underscore like Python
149
154
  def end_profiling
150
155
  out = ::FFI::MemoryPointer.new(:string)
151
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
152
157
  out.read_pointer.read_string
153
158
  end
154
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
155
175
  private
156
176
 
157
177
  def create_input_tensor(input_feed)
158
178
  allocator_info = ::FFI::MemoryPointer.new(:pointer)
159
- check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
179
+ check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
160
180
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
161
181
 
162
182
  input_feed.each_with_index do |(input_name, input), idx|
163
- input = input.to_a unless input.is_a?(Array)
183
+ if numo_array?(input)
184
+ shape = input.shape
185
+ else
186
+ input = input.to_a unless input.is_a?(Array)
164
187
 
165
- shape = []
166
- s = input
167
- while s.is_a?(Array)
168
- shape << s.size
169
- s = s.first
188
+ shape = []
189
+ s = input
190
+ while s.is_a?(Array)
191
+ shape << s.size
192
+ s = s.first
193
+ end
170
194
  end
171
195
 
172
- flat_input = input.flatten
173
- input_tensor_size = flat_input.size
174
-
175
196
  # TODO support more types
176
197
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
177
- raise "Unknown input: #{input_name}" unless inp
198
+ raise Error, "Unknown input: #{input_name}" unless inp
178
199
 
179
200
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
180
201
  input_node_dims.write_array_of_int64(shape)
181
202
 
182
203
  if inp[:type] == "tensor(string)"
183
- input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
184
- input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
204
+ if numo_array?(input)
205
+ input_tensor_size = input.size
206
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
207
+ input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
208
+ else
209
+ flat_input = input.flatten.to_a
210
+ input_tensor_size = flat_input.size
211
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
212
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
213
+ end
185
214
  type_enum = FFI::TensorElementDataType[:string]
186
215
  check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
187
- check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
216
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
188
217
  else
189
- tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
190
218
  tensor_type = tensor_types[inp[:type]]
191
219
 
192
220
  if tensor_type
193
- input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
194
- if tensor_type == :bool
195
- tensor_type = :uchar
196
- flat_input = flat_input.map { |v| v ? 1 : 0 }
221
+ if numo_array?(input)
222
+ input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
223
+ else
224
+ flat_input = input.flatten.to_a
225
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
226
+ if tensor_type == :bool
227
+ tensor_type = :uchar
228
+ flat_input = flat_input.map { |v| v ? 1 : 0 }
229
+ end
230
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
197
231
  end
198
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
232
+
199
233
  type_enum = FFI::TensorElementDataType[tensor_type]
200
234
  else
201
235
  unsupported_type("input", inp[:type])
@@ -214,9 +248,9 @@ module OnnxRuntime
214
248
  ptr
215
249
  end
216
250
 
217
- def create_from_onnx_value(out_ptr)
251
+ def create_from_onnx_value(out_ptr, output_type)
218
252
  out_type = ::FFI::MemoryPointer.new(:int)
219
- check_status = api[:GetValueType].call(out_ptr, out_type)
253
+ check_status api[:GetValueType].call(out_ptr, out_type)
220
254
  type = FFI::OnnxType[out_type.read_int]
221
255
 
222
256
  case type
@@ -231,29 +265,50 @@ module OnnxRuntime
231
265
 
232
266
  out_size = ::FFI::MemoryPointer.new(:size_t)
233
267
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
234
- output_tensor_size = read_size_t(out_size)
268
+ output_tensor_size = out_size.read(:size_t)
269
+
270
+ release :TensorTypeAndShapeInfo, typeinfo
235
271
 
236
272
  # TODO support more types
237
273
  type = FFI::TensorElementDataType[type]
238
- arr =
274
+
275
+ case output_type
276
+ when :numo
239
277
  case type
240
- when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
241
- tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
242
- when :bool
243
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
278
+ when :string
279
+ result = Numo::RObject.new(shape)
280
+ result.allocate
281
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
244
282
  else
245
- unsupported_type("element", type)
283
+ numo_type = numo_types[type]
284
+ unsupported_type("element", type) unless numo_type
285
+ numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
246
286
  end
287
+ when :ruby
288
+ arr =
289
+ case type
290
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
291
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
292
+ when :bool
293
+ tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
294
+ when :string
295
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
296
+ else
297
+ unsupported_type("element", type)
298
+ end
247
299
 
248
- Utils.reshape(arr, shape)
300
+ Utils.reshape(arr, shape)
301
+ else
302
+ raise ArgumentError, "Invalid output type: #{output_type}"
303
+ end
249
304
  when :sequence
250
305
  out = ::FFI::MemoryPointer.new(:size_t)
251
306
  check_status api[:GetValueCount].call(out_ptr, out)
252
307
 
253
- read_size_t(out).times.map do |i|
308
+ out.read(:size_t).times.map do |i|
254
309
  seq = ::FFI::MemoryPointer.new(:pointer)
255
310
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
256
- create_from_onnx_value(seq.read_pointer)
311
+ create_from_onnx_value(seq.read_pointer, output_type)
257
312
  end
258
313
  when :map
259
314
  type_shape = ::FFI::MemoryPointer.new(:pointer)
@@ -265,14 +320,15 @@ module OnnxRuntime
265
320
  check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
266
321
  check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
267
322
  check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
323
+ release :TensorTypeAndShapeInfo, type_shape
268
324
 
269
325
  # TODO support more types
270
326
  elem_type = FFI::TensorElementDataType[elem_type.read_int]
271
327
  case elem_type
272
328
  when :int64
273
329
  ret = {}
274
- keys = create_from_onnx_value(map_keys.read_pointer)
275
- values = create_from_onnx_value(map_values.read_pointer)
330
+ keys = create_from_onnx_value(map_keys.read_pointer, output_type)
331
+ values = create_from_onnx_value(map_values.read_pointer, output_type)
276
332
  keys.zip(values).each do |k, v|
277
333
  ret[k] = v
278
334
  end
@@ -285,15 +341,32 @@ module OnnxRuntime
285
341
  end
286
342
  end
287
343
 
344
+ def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
345
+ len = ::FFI::MemoryPointer.new(:size_t)
346
+ check_status api[:GetStringTensorDataLength].call(out_ptr, len)
347
+
348
+ s_len = len.read(:size_t)
349
+ s = ::FFI::MemoryPointer.new(:uchar, s_len)
350
+ offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
351
+ check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
352
+
353
+ offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
354
+ offsets << s_len
355
+ output_tensor_size.times do |i|
356
+ result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
357
+ end
358
+ result
359
+ end
360
+
288
361
  def read_pointer
289
362
  @session.read_pointer
290
363
  end
291
364
 
292
365
  def check_status(status)
293
366
  unless status.null?
294
- message = api[:GetErrorMessage].call(status)
367
+ message = api[:GetErrorMessage].call(status).read_string
295
368
  api[:ReleaseStatus].call(status)
296
- raise OnnxRuntime::Error, message
369
+ raise Error, message
297
370
  end
298
371
  end
299
372
 
@@ -305,6 +378,7 @@ module OnnxRuntime
305
378
  case type
306
379
  when :tensor
307
380
  tensor_info = ::FFI::MemoryPointer.new(:pointer)
381
+ # don't free tensor_info
308
382
  check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
309
383
 
310
384
  type, shape = tensor_type_and_shape(tensor_info)
@@ -345,7 +419,7 @@ module OnnxRuntime
345
419
  unsupported_type("ONNX", type)
346
420
  end
347
421
  ensure
348
- api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
422
+ release :TypeInfo, typeinfo
349
423
  end
350
424
 
351
425
  def tensor_type_and_shape(tensor_info)
@@ -354,7 +428,7 @@ module OnnxRuntime
354
428
 
355
429
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
356
430
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
357
- num_dims = read_size_t(num_dims_ptr)
431
+ num_dims = num_dims_ptr.read(:size_t)
358
432
 
359
433
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
360
434
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -363,20 +437,67 @@ module OnnxRuntime
363
437
  end
364
438
 
365
439
  def unsupported_type(name, type)
366
- raise "Unsupported #{name} type: #{type}"
440
+ raise Error, "Unsupported #{name} type: #{type}"
367
441
  end
368
442
 
369
- # read(:size_t) not supported in FFI JRuby
370
- def read_size_t(ptr)
371
- if RUBY_PLATFORM == "java"
372
- ptr.read_long
373
- else
374
- ptr.read(:size_t)
375
- end
443
+ def tensor_types
444
+ @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
445
+ end
446
+
447
+ def numo_array?(obj)
448
+ defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
449
+ end
450
+
451
+ def numo_types
452
+ @numo_types ||= {
453
+ float: Numo::SFloat,
454
+ uint8: Numo::UInt8,
455
+ int8: Numo::Int8,
456
+ uint16: Numo::UInt16,
457
+ int16: Numo::Int16,
458
+ int32: Numo::Int32,
459
+ int64: Numo::Int64,
460
+ bool: Numo::UInt8,
461
+ double: Numo::DFloat,
462
+ uint32: Numo::UInt32,
463
+ uint64: Numo::UInt64
464
+ }
376
465
  end
377
466
 
378
467
  def api
379
- @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
468
+ self.class.api
469
+ end
470
+
471
+ def release(*args)
472
+ self.class.release(*args)
473
+ end
474
+
475
+ def self.api
476
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
477
+ end
478
+
479
+ def self.release(type, pointer)
480
+ api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
481
+ end
482
+
483
+ def self.finalize(session)
484
+ # must use proc instead of stabby lambda
485
+ proc { release :Session, session }
486
+ end
487
+
488
+ # wide string on Windows
489
+ # char string on Linux
490
+ # see ORTCHAR_T in onnxruntime_c_api.h
491
+ def ort_string(str)
492
+ if Gem.win_platform?
493
+ max = str.size + 1 # for null byte
494
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
495
+ ret = FFI::Libc.mbstowcs(dest, str, max)
496
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
497
+ dest
498
+ else
499
+ str
500
+ end
380
501
  end
381
502
 
382
503
  def env
@@ -385,7 +506,7 @@ module OnnxRuntime
385
506
  @@env ||= begin
386
507
  env = ::FFI::MemoryPointer.new(:pointer)
387
508
  check_status api[:CreateEnv].call(3, "Default", env)
388
- at_exit { api[:ReleaseEnv].call(env.read_pointer) }
509
+ at_exit { release :Env, env }
389
510
  # disable telemetry
390
511
  # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
391
512
  check_status api[:DisableTelemetryEvents].call(env)