onnxruntime 0.3.1 → 0.5.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2d4b9ddc710644eab07227fe005c92ad039643e989236e5f112a83df2266fc55
4
- data.tar.gz: d15c9ac4ae98130499d1411d3be2e7b822df8248550c0ad44c2e7a02d88f95d8
3
+ metadata.gz: 316527780be2781a474d0813aff47d840654423823ad83b8b52b13752caf6814
4
+ data.tar.gz: 5954ba2dc4223b8330fb52e8474be78542360c2f4c1cb5946529d062c0b1b864
5
5
  SHA512:
6
- metadata.gz: 83a2e3699506460b8b280376ec8935a680e7e2cffe6bf7dd6766a85eea84bf06273e13889b6a2961c9df3c277be6ea7d8afbb67a4d4be62d50b266d7ccd8dc55
7
- data.tar.gz: 5086c2c04f76ec67224e7dbc9ccc36366b394840296536232b29d92009c7a4ef637fb4f39daec8abb7a5e3fdcca591675cfaf3e4a8e31ec939f9aa1e2ae6f112
6
+ metadata.gz: c127874dd75a10b8cb9d9d033e607f9d73069bb5709d7b9ecee04c1d59f969f61db6ec88fd569188b0a35c8276f7611619f81662f3735fbd53e61938f15c6822
7
+ data.tar.gz: 77a7c9f8c98b25fd82ee8818c63176d2005f846db489b5973330220344be8fd47d8e5279517a716f8c1222f729a248fc8ac80cb610633dbdc35c49418a42ea1c
@@ -1,3 +1,31 @@
1
+ ## 0.5.1 (2020-11-01)
2
+
3
+ - Updated ONNX Runtime to 1.5.2
4
+ - Added support for string output
5
+ - Added `output_type` option
6
+ - Improved performance for Numo array inputs
7
+
8
+ ## 0.5.0 (2020-10-01)
9
+
10
+ - Updated ONNX Runtime to 1.5.1
11
+ - OpenMP is now required on Mac
12
+ - Fixed `mul_1.onnx` example
13
+
14
+ ## 0.4.0 (2020-07-20)
15
+
16
+ - Updated ONNX Runtime to 1.4.0
17
+ - Added `providers` method
18
+ - Fixed errors on Windows
19
+
20
+ ## 0.3.3 (2020-06-17)
21
+
22
+ - Fixed segmentation fault on exit on Linux
23
+
24
+ ## 0.3.2 (2020-06-16)
25
+
26
+ - Fixed error with FFI 1.13.0+
27
+ - Added friendly graph optimization levels
28
+
1
29
  ## 0.3.1 (2020-05-18)
2
30
 
3
31
  - Updated ONNX Runtime to 1.3.0
@@ -1,23 +1,22 @@
1
- Copyright (c) 2019-2020 Andrew Kane
2
- Datasets Copyright (c) Microsoft Corporation
3
-
4
1
  MIT License
5
2
 
6
- Permission is hereby granted, free of charge, to any person obtaining
7
- a copy of this software and associated documentation files (the
8
- "Software"), to deal in the Software without restriction, including
9
- without limitation the rights to use, copy, modify, merge, publish,
10
- distribute, sublicense, and/or sell copies of the Software, and to
11
- permit persons to whom the Software is furnished to do so, subject to
12
- the following conditions:
3
+ Copyright (c) 2018 Microsoft Corporation
4
+ Copyright (c) 2019-2020 Andrew Kane
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
13
12
 
14
- The above copyright notice and this permission notice shall be
15
- included in all copies or substantial portions of the Software.
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
16
15
 
17
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
data/README.md CHANGED
@@ -14,6 +14,12 @@ Add this line to your application’s Gemfile:
14
14
  gem 'onnxruntime'
15
15
  ```
16
16
 
17
+ On Mac, also install OpenMP:
18
+
19
+ ```sh
20
+ brew install libomp
21
+ ```
22
+
17
23
  ## Getting Started
18
24
 
19
25
  Load a model and make predictions
@@ -63,8 +69,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
63
69
  enable_cpu_mem_arena: true,
64
70
  enable_mem_pattern: true,
65
71
  enable_profiling: false,
66
- execution_mode: :sequential,
67
- graph_optimization_level: nil,
72
+ execution_mode: :sequential, # :sequential or :parallel
73
+ graph_optimization_level: nil, # :none, :basic, :extended, or :all
68
74
  inter_op_num_threads: nil,
69
75
  intra_op_num_threads: nil,
70
76
  log_severity_level: 2,
@@ -81,7 +87,8 @@ model.predict(input_feed, {
81
87
  log_severity_level: 2,
82
88
  log_verbosity_level: 0,
83
89
  logid: nil,
84
- terminate: false
90
+ terminate: false,
91
+ output_type: :ruby # :ruby or :numo
85
92
  })
86
93
  ```
87
94
 
@@ -19,7 +19,7 @@ module OnnxRuntime
19
19
  self.ffi_lib = [vendor_lib]
20
20
 
21
21
  def self.lib_version
22
- FFI.OrtGetApiBase[:GetVersionString].call
22
+ FFI.OrtGetApiBase[:GetVersionString].call.read_string
23
23
  end
24
24
 
25
25
  # friendlier error message
@@ -3,10 +3,13 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
6
+ ffi_lib OnnxRuntime.ffi_lib
7
7
  rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
8
+ if e.message.include?("Library not loaded: /usr/local/opt/libomp/lib/libomp.dylib") && e.message.include?("Reason: image not found")
9
+ raise LoadError, "OpenMP not found. Run `brew install libomp`"
10
+ else
11
+ raise e
12
+ end
10
13
  end
11
14
 
12
15
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -20,19 +23,19 @@ module OnnxRuntime
20
23
  layout \
21
24
  :CreateStatus, callback(%i[int string], :pointer),
22
25
  :GetErrorCode, callback(%i[pointer], :pointer),
23
- :GetErrorMessage, callback(%i[pointer], :string),
26
+ :GetErrorMessage, callback(%i[pointer], :pointer),
24
27
  :CreateEnv, callback(%i[int string pointer], :pointer),
25
28
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
29
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
30
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
31
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
32
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
33
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
34
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
35
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
36
  :CloneSessionOptions, callback(%i[], :pointer),
34
37
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
38
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
39
  :DisableProfiling, callback(%i[pointer], :pointer),
37
40
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
41
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -71,8 +74,8 @@ module OnnxRuntime
71
74
  :IsTensor, callback(%i[], :pointer),
72
75
  :GetTensorMutableData, callback(%i[pointer pointer], :pointer),
73
76
  :FillStringTensor, callback(%i[pointer pointer size_t], :pointer),
74
- :GetStringTensorDataLength, callback(%i[], :pointer),
75
- :GetStringTensorContent, callback(%i[], :pointer),
77
+ :GetStringTensorDataLength, callback(%i[pointer pointer], :pointer),
78
+ :GetStringTensorContent, callback(%i[pointer pointer size_t pointer size_t], :pointer),
76
79
  :CastTypeInfoToTensorInfo, callback(%i[pointer pointer], :pointer),
77
80
  :GetOnnxTypeFromTypeInfo, callback(%i[pointer pointer], :pointer),
78
81
  :CreateTensorTypeAndShapeInfo, callback(%i[], :pointer),
@@ -142,7 +145,9 @@ module OnnxRuntime
142
145
  :CreateThreadingOptions, callback(%i[], :pointer),
143
146
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
147
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
148
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
149
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
150
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
151
  end
147
152
 
148
153
  class ApiBase < ::FFI::Struct
@@ -150,9 +155,17 @@ module OnnxRuntime
150
155
  # to prevent "unable to resolve type" error on Ubuntu
151
156
  layout \
152
157
  :GetApi, callback(%i[uint32], Api.by_ref),
153
- :GetVersionString, callback(%i[], :string)
158
+ :GetVersionString, callback(%i[], :pointer)
154
159
  end
155
160
 
156
161
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
162
+
163
+ if Gem.win_platform?
164
+ class Libc
165
+ extend ::FFI::Library
166
+ ffi_lib ::FFI::Library::LIBC
167
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
168
+ end
169
+ end
157
170
  end
158
171
  end
@@ -8,26 +8,25 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
- mode =
14
- case execution_mode
15
- when :sequential
16
- 0
17
- when :parallel
18
- 1
19
- else
20
- raise ArgumentError, "Invalid execution mode"
21
- end
13
+ execution_modes = {sequential: 0, parallel: 1}
14
+ mode = execution_modes[execution_mode]
15
+ raise ArgumentError, "Invalid execution mode" unless mode
22
16
  check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
23
17
  end
24
- check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
18
+ if graph_optimization_level
19
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
+ end
25
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
26
25
  check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
27
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
28
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
29
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
30
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
31
30
 
32
31
  # session
33
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -40,17 +39,12 @@ module OnnxRuntime
40
39
  path_or_bytes.encoding == Encoding::BINARY
41
40
  end
42
41
 
43
- # fix for Windows "File doesn't exist"
44
- if Gem.win_platform? && !from_memory
45
- path_or_bytes = File.binread(path_or_bytes)
46
- from_memory = true
47
- end
48
-
49
42
  if from_memory
50
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
51
44
  else
52
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
53
46
  end
47
+ ObjectSpace.define_finalizer(self, self.class.finalize(@session))
54
48
 
55
49
  # input info
56
50
  allocator = ::FFI::MemoryPointer.new(:pointer)
@@ -63,7 +57,7 @@ module OnnxRuntime
63
57
  # input
64
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
65
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
66
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
67
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
68
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
69
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -74,17 +68,19 @@ module OnnxRuntime
74
68
  # output
75
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
76
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
77
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
78
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
79
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
80
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
81
75
  check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
82
76
  @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
83
77
  end
78
+ ensure
79
+ # release :SessionOptions, session_options
84
80
  end
85
81
 
86
82
  # TODO support logid
87
- def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
83
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
88
84
  input_tensor = create_input_tensor(input_feed)
89
85
 
90
86
  output_names ||= @outputs.map { |v| v[:name] }
@@ -104,7 +100,14 @@ module OnnxRuntime
104
100
  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
105
101
 
106
102
  output_names.size.times.map do |i|
107
- create_from_onnx_value(output_tensor[i].read_pointer)
103
+ create_from_onnx_value(output_tensor[i].read_pointer, output_type)
104
+ end
105
+ ensure
106
+ release :RunOptions, run_options
107
+ if input_tensor
108
+ input_feed.size.times do |i|
109
+ release :Value, input_tensor[i]
110
+ end
108
111
  end
109
112
  end
110
113
 
@@ -121,7 +124,7 @@ module OnnxRuntime
121
124
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
122
125
 
123
126
  custom_metadata_map = {}
124
- check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
127
+ check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
125
128
  num_keys.read(:int64_t).times do |i|
126
129
  key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
127
130
  value = ::FFI::MemoryPointer.new(:string)
@@ -134,7 +137,6 @@ module OnnxRuntime
134
137
  check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
135
138
  check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
136
139
  check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
137
- api[:ReleaseModelMetadata].call(metadata.read_pointer)
138
140
 
139
141
  {
140
142
  custom_metadata_map: custom_metadata_map,
@@ -144,58 +146,90 @@ module OnnxRuntime
144
146
  producer_name: producer_name.read_pointer.read_string,
145
147
  version: version.read(:int64_t)
146
148
  }
149
+ ensure
150
+ release :ModelMetadata, metadata
147
151
  end
148
152
 
153
+ # return value has double underscore like Python
149
154
  def end_profiling
150
155
  out = ::FFI::MemoryPointer.new(:string)
151
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
152
157
  out.read_pointer.read_string
153
158
  end
154
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
155
175
  private
156
176
 
157
177
  def create_input_tensor(input_feed)
158
178
  allocator_info = ::FFI::MemoryPointer.new(:pointer)
159
- check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
179
+ check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
160
180
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
161
181
 
162
182
  input_feed.each_with_index do |(input_name, input), idx|
163
- input = input.to_a unless input.is_a?(Array)
183
+ if numo_array?(input)
184
+ shape = input.shape
185
+ else
186
+ input = input.to_a unless input.is_a?(Array)
164
187
 
165
- shape = []
166
- s = input
167
- while s.is_a?(Array)
168
- shape << s.size
169
- s = s.first
188
+ shape = []
189
+ s = input
190
+ while s.is_a?(Array)
191
+ shape << s.size
192
+ s = s.first
193
+ end
170
194
  end
171
195
 
172
- flat_input = input.flatten
173
- input_tensor_size = flat_input.size
174
-
175
196
  # TODO support more types
176
197
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
177
- raise "Unknown input: #{input_name}" unless inp
198
+ raise Error, "Unknown input: #{input_name}" unless inp
178
199
 
179
200
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
180
201
  input_node_dims.write_array_of_int64(shape)
181
202
 
182
203
  if inp[:type] == "tensor(string)"
183
- input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
184
- input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
204
+ if numo_array?(input)
205
+ input_tensor_size = input.size
206
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
207
+ input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
208
+ else
209
+ flat_input = input.flatten.to_a
210
+ input_tensor_size = flat_input.size
211
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
212
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
213
+ end
185
214
  type_enum = FFI::TensorElementDataType[:string]
186
215
  check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
187
- check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
216
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
188
217
  else
189
- tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
190
218
  tensor_type = tensor_types[inp[:type]]
191
219
 
192
220
  if tensor_type
193
- input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
194
- if tensor_type == :bool
195
- tensor_type = :uchar
196
- flat_input = flat_input.map { |v| v ? 1 : 0 }
221
+ if numo_array?(input)
222
+ input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
223
+ else
224
+ flat_input = input.flatten.to_a
225
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
226
+ if tensor_type == :bool
227
+ tensor_type = :uchar
228
+ flat_input = flat_input.map { |v| v ? 1 : 0 }
229
+ end
230
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
197
231
  end
198
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
232
+
199
233
  type_enum = FFI::TensorElementDataType[tensor_type]
200
234
  else
201
235
  unsupported_type("input", inp[:type])
@@ -214,9 +248,9 @@ module OnnxRuntime
214
248
  ptr
215
249
  end
216
250
 
217
- def create_from_onnx_value(out_ptr)
251
+ def create_from_onnx_value(out_ptr, output_type)
218
252
  out_type = ::FFI::MemoryPointer.new(:int)
219
- check_status = api[:GetValueType].call(out_ptr, out_type)
253
+ check_status api[:GetValueType].call(out_ptr, out_type)
220
254
  type = FFI::OnnxType[out_type.read_int]
221
255
 
222
256
  case type
@@ -231,29 +265,50 @@ module OnnxRuntime
231
265
 
232
266
  out_size = ::FFI::MemoryPointer.new(:size_t)
233
267
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
234
- output_tensor_size = read_size_t(out_size)
268
+ output_tensor_size = out_size.read(:size_t)
269
+
270
+ release :TensorTypeAndShapeInfo, typeinfo
235
271
 
236
272
  # TODO support more types
237
273
  type = FFI::TensorElementDataType[type]
238
- arr =
274
+
275
+ case output_type
276
+ when :numo
239
277
  case type
240
- when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
241
- tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
242
- when :bool
243
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
278
+ when :string
279
+ result = Numo::RObject.new(shape)
280
+ result.allocate
281
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
244
282
  else
245
- unsupported_type("element", type)
283
+ numo_type = numo_types[type]
284
+ unsupported_type("element", type) unless numo_type
285
+ numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
246
286
  end
287
+ when :ruby
288
+ arr =
289
+ case type
290
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
291
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
292
+ when :bool
293
+ tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
294
+ when :string
295
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
296
+ else
297
+ unsupported_type("element", type)
298
+ end
247
299
 
248
- Utils.reshape(arr, shape)
300
+ Utils.reshape(arr, shape)
301
+ else
302
+ raise ArgumentError, "Invalid output type: #{output_type}"
303
+ end
249
304
  when :sequence
250
305
  out = ::FFI::MemoryPointer.new(:size_t)
251
306
  check_status api[:GetValueCount].call(out_ptr, out)
252
307
 
253
- read_size_t(out).times.map do |i|
308
+ out.read(:size_t).times.map do |i|
254
309
  seq = ::FFI::MemoryPointer.new(:pointer)
255
310
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
256
- create_from_onnx_value(seq.read_pointer)
311
+ create_from_onnx_value(seq.read_pointer, output_type)
257
312
  end
258
313
  when :map
259
314
  type_shape = ::FFI::MemoryPointer.new(:pointer)
@@ -265,14 +320,15 @@ module OnnxRuntime
265
320
  check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
266
321
  check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
267
322
  check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
323
+ release :TensorTypeAndShapeInfo, type_shape
268
324
 
269
325
  # TODO support more types
270
326
  elem_type = FFI::TensorElementDataType[elem_type.read_int]
271
327
  case elem_type
272
328
  when :int64
273
329
  ret = {}
274
- keys = create_from_onnx_value(map_keys.read_pointer)
275
- values = create_from_onnx_value(map_values.read_pointer)
330
+ keys = create_from_onnx_value(map_keys.read_pointer, output_type)
331
+ values = create_from_onnx_value(map_values.read_pointer, output_type)
276
332
  keys.zip(values).each do |k, v|
277
333
  ret[k] = v
278
334
  end
@@ -285,15 +341,32 @@ module OnnxRuntime
285
341
  end
286
342
  end
287
343
 
344
+ def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
345
+ len = ::FFI::MemoryPointer.new(:size_t)
346
+ check_status api[:GetStringTensorDataLength].call(out_ptr, len)
347
+
348
+ s_len = len.read(:size_t)
349
+ s = ::FFI::MemoryPointer.new(:uchar, s_len)
350
+ offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
351
+ check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
352
+
353
+ offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
354
+ offsets << s_len
355
+ output_tensor_size.times do |i|
356
+ result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
357
+ end
358
+ result
359
+ end
360
+
288
361
  def read_pointer
289
362
  @session.read_pointer
290
363
  end
291
364
 
292
365
  def check_status(status)
293
366
  unless status.null?
294
- message = api[:GetErrorMessage].call(status)
367
+ message = api[:GetErrorMessage].call(status).read_string
295
368
  api[:ReleaseStatus].call(status)
296
- raise OnnxRuntime::Error, message
369
+ raise Error, message
297
370
  end
298
371
  end
299
372
 
@@ -305,6 +378,7 @@ module OnnxRuntime
305
378
  case type
306
379
  when :tensor
307
380
  tensor_info = ::FFI::MemoryPointer.new(:pointer)
381
+ # don't free tensor_info
308
382
  check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
309
383
 
310
384
  type, shape = tensor_type_and_shape(tensor_info)
@@ -345,7 +419,7 @@ module OnnxRuntime
345
419
  unsupported_type("ONNX", type)
346
420
  end
347
421
  ensure
348
- api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
422
+ release :TypeInfo, typeinfo
349
423
  end
350
424
 
351
425
  def tensor_type_and_shape(tensor_info)
@@ -354,7 +428,7 @@ module OnnxRuntime
354
428
 
355
429
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
356
430
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
357
- num_dims = read_size_t(num_dims_ptr)
431
+ num_dims = num_dims_ptr.read(:size_t)
358
432
 
359
433
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
360
434
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -363,20 +437,67 @@ module OnnxRuntime
363
437
  end
364
438
 
365
439
  def unsupported_type(name, type)
366
- raise "Unsupported #{name} type: #{type}"
440
+ raise Error, "Unsupported #{name} type: #{type}"
367
441
  end
368
442
 
369
- # read(:size_t) not supported in FFI JRuby
370
- def read_size_t(ptr)
371
- if RUBY_PLATFORM == "java"
372
- ptr.read_long
373
- else
374
- ptr.read(:size_t)
375
- end
443
+ def tensor_types
444
+ @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
445
+ end
446
+
447
+ def numo_array?(obj)
448
+ defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
449
+ end
450
+
451
+ def numo_types
452
+ @numo_types ||= {
453
+ float: Numo::SFloat,
454
+ uint8: Numo::UInt8,
455
+ int8: Numo::Int8,
456
+ uint16: Numo::UInt16,
457
+ int16: Numo::Int16,
458
+ int32: Numo::Int32,
459
+ int64: Numo::Int64,
460
+ bool: Numo::UInt8,
461
+ double: Numo::DFloat,
462
+ uint32: Numo::UInt32,
463
+ uint64: Numo::UInt64
464
+ }
376
465
  end
377
466
 
378
467
  def api
379
- @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
468
+ self.class.api
469
+ end
470
+
471
+ def release(*args)
472
+ self.class.release(*args)
473
+ end
474
+
475
+ def self.api
476
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
477
+ end
478
+
479
+ def self.release(type, pointer)
480
+ api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
481
+ end
482
+
483
+ def self.finalize(session)
484
+ # must use proc instead of stabby lambda
485
+ proc { release :Session, session }
486
+ end
487
+
488
+ # wide string on Windows
489
+ # char string on Linux
490
+ # see ORTCHAR_T in onnxruntime_c_api.h
491
+ def ort_string(str)
492
+ if Gem.win_platform?
493
+ max = str.size + 1 # for null byte
494
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
495
+ ret = FFI::Libc.mbstowcs(dest, str, max)
496
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
497
+ dest
498
+ else
499
+ str
500
+ end
380
501
  end
381
502
 
382
503
  def env
@@ -385,7 +506,7 @@ module OnnxRuntime
385
506
  @@env ||= begin
386
507
  env = ::FFI::MemoryPointer.new(:pointer)
387
508
  check_status api[:CreateEnv].call(3, "Default", env)
388
- at_exit { api[:ReleaseEnv].call(env.read_pointer) }
509
+ at_exit { release :Env, env }
389
510
  # disable telemetry
390
511
  # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
391
512
  check_status api[:DisableTelemetryEvents].call(env)