onnxruntime 0.7.0-arm64-darwin
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +117 -0
- data/LICENSE.txt +22 -0
- data/README.md +134 -0
- data/lib/onnxruntime/datasets.rb +10 -0
- data/lib/onnxruntime/ffi.rb +171 -0
- data/lib/onnxruntime/inference_session.rb +531 -0
- data/lib/onnxruntime/model.rb +30 -0
- data/lib/onnxruntime/utils.rb +16 -0
- data/lib/onnxruntime/version.rb +3 -0
- data/lib/onnxruntime.rb +42 -0
- data/vendor/LICENSE +21 -0
- data/vendor/ThirdPartyNotices.txt +4779 -0
- data/vendor/libonnxruntime.arm64.dylib +0 -0
- metadata +69 -0
@@ -0,0 +1,531 @@
|
|
1
|
+
module OnnxRuntime
  # Wraps an ONNX Runtime session created through the ORT C API (bound via FFI).
  # Exposes input/output descriptions and runs inference.
  class InferenceSession
    # Arrays of hashes ({name:, type:, shape:}) describing the model's inputs and
    # outputs; filled in by the constructor from the session's type info.
    attr_reader :inputs, :outputs

    # Creates a session from one of:
    # - a model file path (String),
    # - a binary-encoded String holding the model bytes,
    # - an IO-like object (anything responding to #read), whose contents are loaded from memory.
    #
    # Each keyword maps to an OrtSessionOptions setter; nil/false leaves the ORT default.
    # Raises ArgumentError for an unknown execution_mode or graph_optimization_level,
    # and OnnxRuntime::Error when an ORT call fails.
    def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
      # session options
      session_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateSessionOptions].call(session_options)
      check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
      check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
      check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
      if execution_mode
        execution_modes = {sequential: 0, parallel: 1}
        mode = execution_modes[execution_mode]
        raise ArgumentError, "Invalid execution mode" unless mode
        check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
      end
      if graph_optimization_level
        optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
        level = optimization_levels[graph_optimization_level]
        raise ArgumentError, "Invalid graph optimization level" unless level
        check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
      end
      check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
      check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
      check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
      check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
      check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
      check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath

      # session
      @session = ::FFI::MemoryPointer.new(:pointer)
      from_memory =
        if path_or_bytes.respond_to?(:read)
          path_or_bytes = path_or_bytes.read
          true
        else
          path_or_bytes = path_or_bytes.to_str
          # a BINARY-encoded String is treated as model bytes, anything else as a path
          path_or_bytes.encoding == Encoding::BINARY
        end

      if from_memory
        check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
      else
        check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
      end
      ObjectSpace.define_finalizer(self, self.class.finalize(@session))

      # input info
      allocator = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:GetAllocatorWithDefaultOptions].call(allocator)
      @allocator = allocator

      @inputs = []
      @outputs = []

      # input
      num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
      num_input_nodes.read(:size_t).times do |i|
        name_ptr = ::FFI::MemoryPointer.new(:string)
        check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
        typeinfo = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo)
        @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
      end

      # output
      num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
      num_output_nodes.read(:size_t).times do |i|
        name_ptr = ::FFI::MemoryPointer.new(:string)
        check_status api[:SessionGetOutputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
        typeinfo = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
        @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
      end
    ensure
      # The options are only needed to create the session, so they are freed
      # on every exit path (previously commented out, which leaked them).
      release :SessionOptions, session_options
    end

    # Runs inference.
    #
    # output_names - names of the outputs to fetch; nil fetches every model output.
    # input_feed   - Hash of input name => Array, Numo::NArray, or anything responding to #to_a.
    # log_severity_level / log_verbosity_level / logid / terminate - per-run options;
    #   logid is applied as the run tag (RunOptionsSetRunTag).
    # output_type  - :ruby (nested Arrays) or :numo (Numo arrays).
    #
    # Returns an Array with one converted value per requested output, in request order.
    # Raises OnnxRuntime::Error on ORT failure and ArgumentError for an invalid output_type.
    def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
      # pointer references (keep Ruby-owned buffers alive until Run returns)
      refs = []

      input_tensor = create_input_tensor(input_feed, refs)

      output_names ||= @outputs.map { |v| v[:name] }

      # one OrtValue* slot per *requested* output; sizing this by the model's
      # output count could overflow when more names are requested
      output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
      input_node_names = create_node_names(input_feed.keys.map(&:to_s), refs)
      output_node_names = create_node_names(output_names.map(&:to_s), refs)

      # run options
      run_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateRunOptions].call(run_options)
      check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
      check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
      check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
      check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate

      check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

      output_names.size.times.map do |i|
        create_from_onnx_value(output_tensor[i].read_pointer, output_type)
      end
    ensure
      release :RunOptions, run_options
      if input_tensor
        input_feed.size.times do |i|
          release :Value, input_tensor[i]
        end
      end
      # The values produced by Run are not freed anywhere else. Their contents
      # were copied into Ruby objects above (read_array_of_* / read_bytes /
      # get_bytes), so they can be released here. Unfilled slots are NULL and
      # are skipped by #release.
      if output_tensor
        output_names.size.times do |i|
          release :Value, output_tensor[i]
        end
      end
    end

    # Returns the model metadata as a Hash with keys :custom_metadata_map,
    # :description, :domain, :graph_name, :producer_name and :version.
    # NOTE(review): the strings returned through @allocator are not freed here — confirm
    # whether AllocatorFree should be called for them.
    def modelmeta
      keys = ::FFI::MemoryPointer.new(:pointer)
      num_keys = ::FFI::MemoryPointer.new(:int64_t)
      description = ::FFI::MemoryPointer.new(:string)
      domain = ::FFI::MemoryPointer.new(:string)
      graph_name = ::FFI::MemoryPointer.new(:string)
      producer_name = ::FFI::MemoryPointer.new(:string)
      version = ::FFI::MemoryPointer.new(:int64_t)

      metadata = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)

      custom_metadata_map = {}
      check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
      num_keys.read(:int64_t).times do |i|
        key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
        value = ::FFI::MemoryPointer.new(:string)
        check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
        custom_metadata_map[key] = value.read_pointer.read_string
      end

      check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
      check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
      check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
      check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
      check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)

      {
        custom_metadata_map: custom_metadata_map,
        description: description.read_pointer.read_string,
        domain: domain.read_pointer.read_string,
        graph_name: graph_name.read_pointer.read_string,
        producer_name: producer_name.read_pointer.read_string,
        version: version.read(:int64_t)
      }
    ensure
      release :ModelMetadata, metadata
    end

    # Stops profiling and returns the String produced by SessionEndProfiling.
    # return value has double underscore like Python
    def end_profiling
      out = ::FFI::MemoryPointer.new(:string)
      check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
      out.read_pointer.read_string
    end

    # no way to set providers with C API yet
    # so we can return all available providers
    def providers
      out_ptr = ::FFI::MemoryPointer.new(:pointer)
      length_ptr = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
      length = length_ptr.read_int
      providers = []
      length.times do |i|
        providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
      end
      api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
      providers
    end

    private

    # Builds one OrtValue per entry of input_feed and returns a pointer array
    # holding them, in input_feed order. Buffers that must outlive this call are
    # pushed onto refs. Non-string data is presumably borrowed, not copied, by
    # CreateTensorWithDataAsOrtValue.
    # Raises OnnxRuntime::Error for an unknown input name or an unsupported input type.
    # NOTE(review): the OrtMemoryInfo from CreateCpuMemoryInfo is never released
    # (only its holder is kept in refs) — confirm whether ReleaseMemoryInfo is needed.
    def create_input_tensor(input_feed, refs)
      allocator_info = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
      input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)

      input_feed.each_with_index do |(input_name, input), idx|
        if numo_array?(input)
          shape = input.shape
        else
          input = input.to_a unless input.is_a?(Array)

          # infer the shape by walking the first element at each nesting level
          shape = []
          s = input
          while s.is_a?(Array)
            shape << s.size
            s = s.first
          end
        end

        # TODO support more types
        inp = @inputs.find { |i| i[:name] == input_name.to_s }
        raise Error, "Unknown input: #{input_name}" unless inp

        input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
        input_node_dims.write_array_of_int64(shape)

        if inp[:type] == "tensor(string)"
          str_ptrs =
            if numo_array?(input)
              input.size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) }
            else
              input.flatten.map { |v| ::FFI::MemoryPointer.from_string(v) }
            end

          input_tensor_values = ::FFI::MemoryPointer.new(:pointer, str_ptrs.size)
          input_tensor_values.write_array_of_pointer(str_ptrs)

          type_enum = FFI::TensorElementDataType[:string]
          check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
          check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, str_ptrs.size)

          refs << str_ptrs
        else
          tensor_type = tensor_types[inp[:type]]

          if tensor_type
            if numo_array?(input)
              input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
            else
              flat_input = input.flatten.to_a
              input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
              if tensor_type == :bool
                input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
              else
                input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
              end
            end

            type_enum = FFI::TensorElementDataType[tensor_type]
          else
            unsupported_type("input", inp[:type])
          end

          check_status api[:CreateTensorWithDataAsOrtValue].call(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])

          refs << input_node_dims
          refs << input_tensor_values
        end
      end

      refs << allocator_info

      input_tensor
    end

    # Packs names into a C array of char* (returned); the string buffers are pushed onto refs.
    def create_node_names(names, refs)
      str_ptrs = names.map { |v| ::FFI::MemoryPointer.from_string(v) }
      refs << str_ptrs

      ptr = ::FFI::MemoryPointer.new(:pointer, names.size)
      ptr.write_array_of_pointer(str_ptrs)
      ptr
    end

    # Converts an OrtValue (tensor, sequence or map) into Ruby data according to
    # output_type (:ruby or :numo).
    # Ownership: out_ptr is NOT released here; the caller owns it. Sub-values
    # obtained with GetValue are owned by this method and released after their
    # contents have been copied into Ruby objects.
    def create_from_onnx_value(out_ptr, output_type)
      out_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetValueType].call(out_ptr, out_type)
      type = FFI::OnnxType[out_type.read_int]

      case type
      when :tensor
        typeinfo = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetTensorTypeAndShape].call(out_ptr, typeinfo)

        begin
          type, shape = tensor_type_and_shape(typeinfo)

          tensor_data = ::FFI::MemoryPointer.new(:pointer)
          check_status api[:GetTensorMutableData].call(out_ptr, tensor_data)

          out_size = ::FFI::MemoryPointer.new(:size_t)
          # this returns an OrtStatus like every other call; it was previously ignored
          check_status api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
          output_tensor_size = out_size.read(:size_t)
        ensure
          release :TensorTypeAndShapeInfo, typeinfo
        end

        # TODO support more types
        type = FFI::TensorElementDataType[type]

        case output_type
        when :numo
          case type
          when :string
            result = Numo::RObject.new(shape)
            result.allocate
            create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
          else
            numo_type = numo_types[type]
            unsupported_type("element", type) unless numo_type
            numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
          end
        when :ruby
          arr =
            case type
            when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
              tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
            when :bool
              tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
            when :string
              create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
            else
              unsupported_type("element", type)
            end

          Utils.reshape(arr, shape)
        else
          raise ArgumentError, "Invalid output type: #{output_type}"
        end
      when :sequence
        out = ::FFI::MemoryPointer.new(:size_t)
        check_status api[:GetValueCount].call(out_ptr, out)

        out.read(:size_t).times.map do |i|
          seq = ::FFI::MemoryPointer.new(:pointer)
          begin
            check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
            create_from_onnx_value(seq.read_pointer, output_type)
          ensure
            # the element was produced by GetValue, so it is freed here
            release :Value, seq
          end
        end
      when :map
        type_shape = ::FFI::MemoryPointer.new(:pointer)
        map_keys = ::FFI::MemoryPointer.new(:pointer)
        map_values = ::FFI::MemoryPointer.new(:pointer)
        elem_type = ::FFI::MemoryPointer.new(:int)

        begin
          # index 0 holds the keys tensor, index 1 the values
          check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys)
          check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
          check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
          check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)

          # TODO support more types
          elem_type = FFI::TensorElementDataType[elem_type.read_int]
          case elem_type
          when :int64
            ret = {}
            keys = create_from_onnx_value(map_keys.read_pointer, output_type)
            values = create_from_onnx_value(map_values.read_pointer, output_type)
            keys.zip(values).each do |k, v|
              ret[k] = v
            end
            ret
          else
            unsupported_type("element", elem_type)
          end
        ensure
          release :TensorTypeAndShapeInfo, type_shape
          release :Value, map_keys
          release :Value, map_values
        end
      else
        unsupported_type("ONNX", type)
      end
    end

    # Copies output_tensor_size strings out of a string tensor into result
    # (an Array or Numo::RObject), indexed 0...output_tensor_size, and returns result.
    def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
      len = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetStringTensorDataLength].call(out_ptr, len)

      s_len = len.read(:size_t)
      s = ::FFI::MemoryPointer.new(:uchar, s_len)
      offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
      check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)

      # element i spans [offsets[i], offsets[i + 1]); the sentinel closes the last one
      offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
      offsets << s_len
      output_tensor_size.times do |i|
        result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
      end
      result
    end

    # The raw OrtSession* held by this object.
    def read_pointer
      @session.read_pointer
    end

    # Raises OnnxRuntime::Error carrying ORT's message when status is non-NULL
    # (the status object is released first); does nothing for a NULL status.
    def check_status(status)
      unless status.null?
        message = api[:GetErrorMessage].call(status).read_string
        api[:ReleaseStatus].call(status)
        raise Error, message
      end
    end

    # Describes a type info as {type:, shape:}; type is e.g. "tensor(float)",
    # "seq(...)" or "map(key,value)". Always releases typeinfo.
    def node_info(typeinfo)
      onnx_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type)

      type = FFI::OnnxType[onnx_type.read_int]
      case type
      when :tensor
        tensor_info = ::FFI::MemoryPointer.new(:pointer)
        # don't free tensor_info
        check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)

        type, shape = tensor_type_and_shape(tensor_info)
        {
          type: "tensor(#{FFI::TensorElementDataType[type]})",
          shape: shape
        }
      when :sequence
        sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
        nested_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
        # the recursive call releases nested_type_info
        v = node_info(nested_type_info)[:type]

        {
          type: "seq(#{v})",
          shape: []
        }
      when :map
        map_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)

        # key
        key_type = ::FFI::MemoryPointer.new(:int)
        check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
        k = FFI::TensorElementDataType[key_type.read_int]

        # value
        value_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
        v = node_info(value_type_info)[:type]

        {
          type: "map(#{k},#{v})",
          shape: []
        }
      else
        unsupported_type("ONNX", type)
      end
    ensure
      release :TypeInfo, typeinfo
    end

    # Returns [element type enum (Integer), dimensions (Array of Integer)].
    def tensor_type_and_shape(tensor_info)
      type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type)

      num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
      num_dims = num_dims_ptr.read(:size_t)

      node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
      check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)

      [type.read_int, node_dims.read_array_of_int64(num_dims)]
    end

    def unsupported_type(name, type)
      raise Error, "Unsupported #{name} type: #{type}"
    end

    # ONNX type string => element type symbol for the supported numeric/bool tensors.
    def tensor_types
      @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
    end

    def numo_array?(obj)
      defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
    end

    # Element type symbol => Numo class (bool is stored as UInt8).
    def numo_types
      @numo_types ||= {
        float: Numo::SFloat,
        uint8: Numo::UInt8,
        int8: Numo::Int8,
        uint16: Numo::UInt16,
        int16: Numo::Int16,
        int32: Numo::Int32,
        int64: Numo::Int64,
        bool: Numo::UInt8,
        double: Numo::DFloat,
        uint32: Numo::UInt32,
        uint64: Numo::UInt64
      }
    end

    def api
      self.class.api
    end

    def release(*args)
      self.class.release(*args)
    end

    # The OrtApi function table, requested at API version 4.
    def self.api
      @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
    end

    # Calls Release<type> on the handle stored in pointer (a pointer-to-handle).
    # No-op when the holder is nil/NULL or the handle itself is NULL, which makes
    # it safe to call from ensure blocks after a partial failure.
    def self.release(type, pointer)
      return unless pointer && !pointer.null?

      handle = pointer.read_pointer
      api[:"Release#{type}"].call(handle) unless handle.null?
    end

    def self.finalize(session)
      # must use proc instead of stabby lambda
      proc { release :Session, session }
    end

    # wide string on Windows
    # char string on Linux
    # see ORTCHAR_T in onnxruntime_c_api.h
    def ort_string(str)
      if Gem.win_platform?
        max = str.size + 1 # for null byte
        dest = ::FFI::MemoryPointer.new(:wchar_t, max)
        ret = FFI::Libc.mbstowcs(dest, str, max)
        raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
        dest
      else
        str
      end
    end

    # Lazily creates the process-wide OrtEnv shared by all sessions (released at exit).
    def env
      # use mutex for thread-safety
      Utils.mutex.synchronize do
        @@env ||= begin
          env = ::FFI::MemoryPointer.new(:pointer)
          check_status api[:CreateEnv].call(3, "Default", env)
          at_exit { release :Env, env }
          # disable telemetry
          # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
          # pass the OrtEnv* itself, not the holder that points to it
          check_status api[:DisableTelemetryEvents].call(env.read_pointer)
          env
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module OnnxRuntime
  # Thin convenience layer over InferenceSession: predictions come back as a
  # Hash keyed by output name instead of a positional Array.
  class Model
    # path_or_bytes and every keyword option are handed unchanged to
    # InferenceSession.new.
    def initialize(path_or_bytes, **session_options)
      @session = InferenceSession.new(path_or_bytes, **session_options)
    end

    # Runs the session on input_feed.
    #
    # output_names - which outputs to fetch (and their order); nil means all of them.
    # run_options  - forwarded as keywords to InferenceSession#run.
    #
    # Returns a Hash mapping each output name (as a String) to its value.
    def predict(input_feed, output_names: nil, **run_options)
      values = @session.run(output_names, input_feed, **run_options)
      names = output_names || outputs.map { |o| o[:name] }

      names.zip(values).each_with_object({}) do |(name, value), by_name|
        by_name[name.to_s] = value
      end
    end

    # Input descriptions reported by the underlying session.
    def inputs
      @session.inputs
    end

    # Output descriptions reported by the underlying session.
    def outputs
      @session.outputs
    end

    # Model metadata as returned by InferenceSession#modelmeta.
    def metadata
      @session.modelmeta
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module OnnxRuntime
  # Shared helpers for the gem.
  module Utils
    class << self
      # Mutex used to serialize lazy initialization of shared state
      # (e.g. InferenceSession's environment).
      attr_accessor :mutex
    end
    self.mutex = Mutex.new

    # Turns arr (flattened first) into nested Arrays matching dims.
    #
    # reshape([1, 2, 3, 4, 5, 6], [2, 3]) #=> [[1, 2, 3], [4, 5, 6]]
    #
    # Edge cases that previously raised:
    # - dims == [] (a 0-d tensor) returns the single element itself;
    # - a zero-sized inner dimension, e.g. dims [2, 0], yields [[], []].
    def self.reshape(arr, dims)
      flat = arr.flatten
      return flat.first if dims.empty?

      build_nested(flat, dims)
    end

    # Recursively splits flat into dims[0] chunks, each shaped by dims[1..].
    def self.build_nested(flat, dims)
      head, *rest = dims
      return flat if rest.empty?

      chunk = rest.inject(1, :*)
      Array.new(head) { |i| build_nested(flat[i * chunk, chunk] || [], rest) }
    end
    private_class_method :build_nested
  end
end
|
data/lib/onnxruntime.rb
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
# dependencies
require "ffi"

# modules
require "onnxruntime/datasets"
require "onnxruntime/inference_session"
require "onnxruntime/model"
require "onnxruntime/utils"
require "onnxruntime/version"

module OnnxRuntime
  # Raised for failures reported by ONNX Runtime and for unsupported inputs.
  class Error < StandardError; end

  class << self
    # Paths of the native library candidates, set below to the bundled vendor
    # library; presumably read by onnxruntime/ffi (not shown here).
    attr_accessor :ffi_lib
  end

  # Choose the bundled shared library for the current platform: Windows uses a
  # DLL, and on macOS and other systems ARM CPUs get an ".arm64" variant.
  arm_cpu = /arm|aarch64/i.match?(RbConfig::CONFIG["host_cpu"])
  library_file =
    if Gem.win_platform?
      "onnxruntime.dll"
    elsif /darwin/i.match?(RbConfig::CONFIG["host_os"])
      arm_cpu ? "libonnxruntime.arm64.dylib" : "libonnxruntime.dylib"
    else
      arm_cpu ? "libonnxruntime.arm64.so" : "libonnxruntime.so"
    end
  self.ffi_lib = [File.expand_path("../vendor/#{library_file}", __dir__)]

  # Version string reported by the loaded ONNX Runtime library.
  def self.lib_version
    FFI.OrtGetApiBase[:GetVersionString].call.read_string
  end

  # The FFI bindings are loaded lazily on first reference to OnnxRuntime::FFI
  # (the original rationale: a friendlier error message).
  autoload :FFI, "onnxruntime/ffi"
end
|
data/vendor/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) Microsoft Corporation
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|