onnxruntime 0.7.0-x86_64-darwin
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +117 -0
- data/LICENSE.txt +22 -0
- data/README.md +134 -0
- data/lib/onnxruntime/datasets.rb +10 -0
- data/lib/onnxruntime/ffi.rb +171 -0
- data/lib/onnxruntime/inference_session.rb +531 -0
- data/lib/onnxruntime/model.rb +30 -0
- data/lib/onnxruntime/utils.rb +16 -0
- data/lib/onnxruntime/version.rb +3 -0
- data/lib/onnxruntime.rb +42 -0
- data/vendor/LICENSE +21 -0
- data/vendor/ThirdPartyNotices.txt +4779 -0
- data/vendor/libonnxruntime.dylib +0 -0
- metadata +69 -0
@@ -0,0 +1,531 @@
|
|
1
|
+
module OnnxRuntime
  # Ruby wrapper around an ONNX Runtime session, driven through the C API
  # function table returned by OnnxRuntime::FFI.
  #
  # +inputs+ / +outputs+ are Arrays of Hashes with :name, :type (for example
  # "tensor(float)", "seq(...)", "map(k,v)") and :shape keys, read once when
  # the session is created.
  class InferenceSession
    attr_reader :inputs, :outputs

    # Loads a model and creates a session.
    #
    # path_or_bytes - an object responding to #read (its contents are used as
    #                 model bytes), a BINARY-encoded String of model bytes, or
    #                 any other String-like value, which is treated as a path.
    # Keyword options are applied to the session options. Options that are nil
    # (or false for the enable_* flags) are not applied.
    #
    # Raises ArgumentError for an unknown execution_mode or
    # graph_optimization_level, and OnnxRuntime::Error when a C API call fails.
    def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
      # session options
      session_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateSessionOptions].call(session_options)
      # NOTE(review): false for enable_cpu_mem_arena / enable_mem_pattern only
      # skips the Enable* call; no Disable* call is made, so the option stays at
      # whatever the runtime default is. TODO confirm the Disable* callback
      # signatures in ffi.rb before calling them here.
      check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
      check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
      check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
      if execution_mode
        execution_modes = {sequential: 0, parallel: 1}
        mode = execution_modes[execution_mode]
        raise ArgumentError, "Invalid execution mode" unless mode
        check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
      end
      if graph_optimization_level
        optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
        level = optimization_levels[graph_optimization_level]
        raise ArgumentError, "Invalid graph optimization level" unless level
        check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
      end
      check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
      check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
      check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
      check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
      check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
      check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath

      # session
      @session = ::FFI::MemoryPointer.new(:pointer)
      from_memory =
        if path_or_bytes.respond_to?(:read)
          path_or_bytes = path_or_bytes.read
          true
        else
          path_or_bytes = path_or_bytes.to_str
          # a BINARY-encoded String is model bytes; anything else is a path
          path_or_bytes.encoding == Encoding::BINARY
        end

      if from_memory
        check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
      else
        check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
      end
      ObjectSpace.define_finalizer(self, self.class.finalize(@session))

      # default allocator, used for node names, metadata strings and string tensors
      allocator = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:GetAllocatorWithDefaultOptions].call(allocator)
      @allocator = allocator

      @inputs = node_infos(:SessionGetInputCount, :SessionGetInputName, :SessionGetInputTypeInfo)
      @outputs = node_infos(:SessionGetOutputCount, :SessionGetOutputName, :SessionGetOutputTypeInfo)
    ensure
      # NOTE(review): the session options are not released; the release call was
      # commented out in the original. TODO confirm whether that is intentional.
      # release :SessionOptions, session_options
    end

    # Runs the model.
    #
    # output_names - names of the outputs to return, or nil for all outputs
    # input_feed   - Hash of input name (String or Symbol) => value (nested
    #                Array, object responding to #to_a, or Numo::NArray)
    # logid        - passed to RunOptionsSetRunTag
    # output_type  - :ruby (nested Arrays/Hashes) or :numo
    #
    # Returns an Array holding one converted value per requested output, in
    # the order of output_names.
    def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
      # keeps FFI buffers that the C call reads alive until it returns
      refs = []

      input_tensor = create_input_tensor(input_feed, refs)

      output_names ||= @outputs.map { |v| v[:name] }

      # one slot per *requested* output: Run writes output_names.size pointers
      output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
      input_node_names = create_node_names(input_feed.keys.map(&:to_s), refs)
      output_node_names = create_node_names(output_names.map(&:to_s), refs)

      # run options
      run_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateRunOptions].call(run_options)
      check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
      check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
      check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
      check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate

      check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

      output_names.size.times.map do |i|
        create_from_onnx_value(output_tensor[i].read_pointer, output_type)
      end
    ensure
      release :RunOptions, run_options
      if input_tensor
        input_feed.size.times do |i|
          release :Value, input_tensor[i]
        end
      end
      # Run allocates the output values. The converted results are copies
      # (Ruby arrays/strings or Numo arrays built from copied bytes), so the
      # values can be freed once conversion is done or has failed.
      if output_tensor
        output_names.size.times do |i|
          release :Value, output_tensor[i]
        end
      end
    end

    # Returns a Hash with the model's custom metadata map, description,
    # domain, graph name, producer name and version.
    def modelmeta
      keys = ::FFI::MemoryPointer.new(:pointer)
      num_keys = ::FFI::MemoryPointer.new(:int64_t)
      description = ::FFI::MemoryPointer.new(:string)
      domain = ::FFI::MemoryPointer.new(:string)
      graph_name = ::FFI::MemoryPointer.new(:string)
      producer_name = ::FFI::MemoryPointer.new(:string)
      version = ::FFI::MemoryPointer.new(:int64_t)

      metadata = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)

      custom_metadata_map = {}
      check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
      num_keys.read(:int64_t).times do |i|
        key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
        value = ::FFI::MemoryPointer.new(:string)
        check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
        custom_metadata_map[key] = value.read_pointer.read_string
      end

      check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
      check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
      check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
      check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
      check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)

      {
        custom_metadata_map: custom_metadata_map,
        description: description.read_pointer.read_string,
        domain: domain.read_pointer.read_string,
        graph_name: graph_name.read_pointer.read_string,
        producer_name: producer_name.read_pointer.read_string,
        version: version.read(:int64_t)
      }
    ensure
      release :ModelMetadata, metadata
    end

    # Ends profiling and returns the string reported by SessionEndProfiling
    # (presumably the profile file name; the original note says the return
    # value has a double underscore like Python).
    def end_profiling
      out = ::FFI::MemoryPointer.new(:string)
      check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
      out.read_pointer.read_string
    end

    # The C API offers no way to set providers yet, so this returns every
    # provider available in the loaded library.
    def providers
      out_ptr = ::FFI::MemoryPointer.new(:pointer)
      length_ptr = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
      length = length_ptr.read_int
      providers = []
      length.times do |i|
        providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
      end
      api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
      providers
    end

    private

    # Reads name/type/shape for every input or output node.
    # The three arguments are the C API function names for count, name and type info.
    def node_infos(count_fn, name_fn, type_info_fn)
      count = ::FFI::MemoryPointer.new(:size_t)
      check_status api[count_fn].call(read_pointer, count)
      count.read(:size_t).times.map do |i|
        name_ptr = ::FFI::MemoryPointer.new(:string)
        check_status api[name_fn].call(read_pointer, i, @allocator.read_pointer, name_ptr)
        typeinfo = ::FFI::MemoryPointer.new(:pointer)
        check_status api[type_info_fn].call(read_pointer, i, typeinfo)
        {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
      end
    end

    # Builds one OrtValue per entry in input_feed, in input_feed order.
    #
    # Buffers that must outlive the Run call are appended to refs. Returns a
    # pointer array holding one OrtValue* per input; the caller releases them.
    # On failure, values that were already created are released before the
    # error is re-raised.
    def create_input_tensor(input_feed, refs)
      allocator_info = ::FFI::MemoryPointer.new(:pointer)
      # (1, 0) = OrtArenaAllocator, OrtMemTypeDefault in the C API enums
      check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
      input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)

      input_feed.each_with_index do |(input_name, input), idx|
        if numo_array?(input)
          shape = input.shape
        else
          input = input.to_a unless input.is_a?(Array)

          # infer the shape by following the first element of each nesting level
          shape = []
          s = input
          while s.is_a?(Array)
            shape << s.size
            s = s.first
          end
        end

        # TODO support more types
        inp = @inputs.find { |i| i[:name] == input_name.to_s }
        raise Error, "Unknown input: #{input_name}" unless inp

        input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
        input_node_dims.write_array_of_int64(shape)

        if inp[:type] == "tensor(string)"
          str_ptrs =
            if numo_array?(input)
              input.size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) }
            else
              input.flatten.map { |v| ::FFI::MemoryPointer.from_string(v) }
            end

          input_tensor_values = ::FFI::MemoryPointer.new(:pointer, str_ptrs.size)
          input_tensor_values.write_array_of_pointer(str_ptrs)

          type_enum = FFI::TensorElementDataType[:string]
          check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
          check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, str_ptrs.size)

          refs << str_ptrs
        else
          tensor_type = tensor_types[inp[:type]]

          if tensor_type
            if numo_array?(input)
              input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
            else
              flat_input = input.flatten.to_a
              input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
              if tensor_type == :bool
                input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
              else
                input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
              end
            end

            type_enum = FFI::TensorElementDataType[tensor_type]
          else
            unsupported_type("input", inp[:type])
          end

          check_status api[:CreateTensorWithDataAsOrtValue].call(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])

          # the tensor wraps these buffers without copying, so keep them alive
          refs << input_node_dims
          refs << input_tensor_values
        end
      end

      refs << allocator_info

      input_tensor
    rescue
      # release values created before the failure; untouched slots are still
      # NULL because FFI::MemoryPointer zero-fills its memory by default
      input_feed.size.times { |i| release :Value, input_tensor[i] } if input_tensor
      raise
    end

    # Returns a pointer array of NUL-terminated C strings for names.
    # The string buffers are appended to refs to keep them alive.
    def create_node_names(names, refs)
      str_ptrs = names.map { |v| ::FFI::MemoryPointer.from_string(v) }
      refs << str_ptrs

      ptr = ::FFI::MemoryPointer.new(:pointer, names.size)
      ptr.write_array_of_pointer(str_ptrs)
      ptr
    end

    # Converts an OrtValue into Ruby objects (output_type :ruby) or Numo
    # arrays (:numo). Tensors are copied out, and sequences and maps are
    # converted recursively. out_ptr itself is not released here; the caller
    # owns it.
    def create_from_onnx_value(out_ptr, output_type)
      out_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetValueType].call(out_ptr, out_type)
      type = FFI::OnnxType[out_type.read_int]

      case type
      when :tensor
        create_from_tensor_value(out_ptr, output_type)
      when :sequence
        create_from_sequence_value(out_ptr, output_type)
      when :map
        create_from_map_value(out_ptr, output_type)
      else
        unsupported_type("ONNX", type)
      end
    end

    # Copies a tensor OrtValue into a nested Ruby Array or a Numo array.
    def create_from_tensor_value(out_ptr, output_type)
      typeinfo = ::FFI::MemoryPointer.new(:pointer)
      begin
        check_status api[:GetTensorTypeAndShape].call(out_ptr, typeinfo)
        type, shape = tensor_type_and_shape(typeinfo)

        out_size = ::FFI::MemoryPointer.new(:size_t)
        # check the status like every other call so a failure raises instead of being ignored
        check_status api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
        output_tensor_size = out_size.read(:size_t)
      ensure
        release :TensorTypeAndShapeInfo, typeinfo
      end

      tensor_data = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:GetTensorMutableData].call(out_ptr, tensor_data)

      # TODO support more types
      type = FFI::TensorElementDataType[type]

      case output_type
      when :numo
        case type
        when :string
          result = Numo::RObject.new(shape)
          result.allocate
          create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
        else
          numo_type = numo_types[type]
          unsupported_type("element", type) unless numo_type
          numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
        end
      when :ruby
        arr =
          case type
          when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
            tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
          when :bool
            tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
          when :string
            create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
          else
            unsupported_type("element", type)
          end

        Utils.reshape(arr, shape)
      else
        raise ArgumentError, "Invalid output type: #{output_type}"
      end
    end

    # Converts each element of a sequence OrtValue.
    def create_from_sequence_value(out_ptr, output_type)
      out = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetValueCount].call(out_ptr, out)

      out.read(:size_t).times.map do |i|
        seq = ::FFI::MemoryPointer.new(:pointer)
        begin
          check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
          create_from_onnx_value(seq.read_pointer, output_type)
        ensure
          # per the C API docs, GetValue returns a new OrtValue the caller must release
          release :Value, seq
        end
      end
    end

    # Converts a map OrtValue (keys tensor at index 0, values at index 1) into a Hash.
    def create_from_map_value(out_ptr, output_type)
      map_keys = ::FFI::MemoryPointer.new(:pointer)
      map_values = ::FFI::MemoryPointer.new(:pointer)

      check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys)
      check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)

      # TODO support more types
      elem_type = FFI::TensorElementDataType[map_key_element_type(map_keys.read_pointer)]
      case elem_type
      when :int64
        keys = create_from_onnx_value(map_keys.read_pointer, output_type)
        values = create_from_onnx_value(map_values.read_pointer, output_type)
        ret = {}
        keys.zip(values).each do |k, v|
          ret[k] = v
        end
        ret
      else
        unsupported_type("element", elem_type)
      end
    ensure
      # values obtained through GetValue are owned by the caller
      release :Value, map_keys
      release :Value, map_values
    end

    # Returns the raw element type (Integer) of a map's keys tensor.
    def map_key_element_type(keys_value)
      type_shape = ::FFI::MemoryPointer.new(:pointer)
      elem_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetTensorTypeAndShape].call(keys_value, type_shape)
      check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
      elem_type.read_int
    ensure
      release :TensorTypeAndShapeInfo, type_shape
    end

    # Copies output_tensor_size strings out of a string tensor into result
    # (an Array or Numo::RObject) and returns it.
    def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
      len = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetStringTensorDataLength].call(out_ptr, len)

      s_len = len.read(:size_t)
      s = ::FFI::MemoryPointer.new(:uchar, s_len)
      offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
      check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)

      # string i spans offsets[i]...offsets[i + 1]; the last one ends at s_len
      offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
      offsets << s_len
      output_tensor_size.times do |i|
        result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
      end
      result
    end

    # Raw OrtSession* for this instance.
    def read_pointer
      @session.read_pointer
    end

    # Raises OnnxRuntime::Error with the runtime's message when status is
    # non-NULL, releasing the status first.
    def check_status(status)
      unless status.null?
        message = api[:GetErrorMessage].call(status).read_string
        api[:ReleaseStatus].call(status)
        raise Error, message
      end
    end

    # Describes a type info as {type:, shape:} and releases the type info.
    def node_info(typeinfo)
      onnx_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type)

      type = FFI::OnnxType[onnx_type.read_int]
      case type
      when :tensor
        tensor_info = ::FFI::MemoryPointer.new(:pointer)
        # don't free tensor_info
        check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)

        type, shape = tensor_type_and_shape(tensor_info)
        {
          type: "tensor(#{FFI::TensorElementDataType[type]})",
          shape: shape
        }
      when :sequence
        sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
        nested_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
        v = node_info(nested_type_info)[:type]

        {
          type: "seq(#{v})",
          shape: []
        }
      when :map
        map_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)

        # key
        key_type = ::FFI::MemoryPointer.new(:int)
        check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
        k = FFI::TensorElementDataType[key_type.read_int]

        # value
        value_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
        v = node_info(value_type_info)[:type]

        {
          type: "map(#{k},#{v})",
          shape: []
        }
      else
        unsupported_type("ONNX", type)
      end
    ensure
      release :TypeInfo, typeinfo
    end

    # Returns [element type (Integer), dims (Array of Integer)].
    def tensor_type_and_shape(tensor_info)
      type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type)

      num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
      num_dims = num_dims_ptr.read(:size_t)

      node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
      check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)

      [type.read_int, node_dims.read_array_of_int64(num_dims)]
    end

    def unsupported_type(name, type)
      raise Error, "Unsupported #{name} type: #{type}"
    end

    # "tensor(<type>)" => element type symbol, for non-string tensors.
    def tensor_types
      @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
    end

    def numo_array?(obj)
      defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
    end

    def numo_types
      @numo_types ||= {
        float: Numo::SFloat,
        uint8: Numo::UInt8,
        int8: Numo::Int8,
        uint16: Numo::UInt16,
        int16: Numo::Int16,
        int32: Numo::Int32,
        int64: Numo::Int64,
        bool: Numo::UInt8,
        double: Numo::DFloat,
        uint32: Numo::UInt32,
        uint64: Numo::UInt64
      }
    end

    def api
      self.class.api
    end

    def release(*args)
      self.class.release(*args)
    end

    # OrtApi table for API version 4.
    def self.api
      @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
    end

    # Calls Release<type> on the object stored in pointer (a pointer-to-pointer).
    def self.release(type, pointer)
      api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
    end

    def self.finalize(session)
      # must use proc instead of stabby lambda
      proc { release :Session, session }
    end

    # wide string on Windows
    # char string on Linux
    # see ORTCHAR_T in onnxruntime_c_api.h
    def ort_string(str)
      if Gem.win_platform?
        max = str.size + 1 # for null byte
        dest = ::FFI::MemoryPointer.new(:wchar_t, max)
        ret = FFI::Libc.mbstowcs(dest, str, max)
        raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
        dest
      else
        str
      end
    end

    # Process-wide OrtEnv, created once and released at exit.
    def env
      # use mutex for thread-safety
      Utils.mutex.synchronize do
        @@env ||= begin
          env = ::FFI::MemoryPointer.new(:pointer)
          # 3 = ORT_LOGGING_LEVEL_ERROR
          check_status api[:CreateEnv].call(3, "Default", env)
          at_exit { release :Env, env }
          # disable telemetry
          # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
          # pass the OrtEnv* itself, not the pointer that holds it
          check_status api[:DisableTelemetryEvents].call(env.read_pointer)
          env
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module OnnxRuntime
  # Convenience wrapper over InferenceSession that returns predictions as a
  # Hash keyed by output name.
  class Model
    # Arguments are forwarded unchanged to InferenceSession.new.
    def initialize(path_or_bytes, **session_options)
      @session = InferenceSession.new(path_or_bytes, **session_options)
    end

    # Runs the model on input_feed.
    #
    # output_names - outputs to compute, or nil for every model output
    # run_options  - forwarded to InferenceSession#run
    #
    # Returns a Hash of output name (String) => prediction.
    def predict(input_feed, output_names: nil, **run_options)
      predictions = @session.run(output_names, input_feed, **run_options)
      names = output_names || outputs.map { |o| o[:name] }
      names.map(&:to_s).zip(predictions).to_h
    end

    # Input node descriptions of the underlying session.
    def inputs
      @session.inputs
    end

    # Output node descriptions of the underlying session.
    def outputs
      @session.outputs
    end

    # Model metadata Hash from the underlying session.
    def metadata
      @session.modelmeta
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module OnnxRuntime
  # Shared helpers for the session wrapper.
  module Utils
    class << self
      # Mutex used to guard lazily-initialized shared state.
      attr_accessor :mutex
    end
    self.mutex = Mutex.new

    # Rebuilds a nested Array with the given dimensions from the elements of
    # arr, which is flattened first (row-major order).
    #
    # dims - Array of Integer sizes. An empty Array describes a 0-d (scalar)
    #        tensor, for which the single element is returned. Zero-sized
    #        dimensions produce (nested) empty Arrays.
    #
    # The first dimension is not enforced; the outer level holds as many
    # groups as the data fills.
    def self.reshape(arr, dims)
      flat = arr.flatten
      # 0-d tensor: no dimensions and exactly one element
      # (dims[1..-1] would be nil here)
      return flat.first if dims.empty?

      inner = dims[1..-1]
      # each_slice(0) raises, so build the empty rows directly
      return Array.new(dims[0]) { nested_empty(inner) } if inner.include?(0)

      inner.reverse.reduce(flat) { |acc, dim| acc.each_slice(dim) }.to_a
    end

    # Nested empty Arrays for a shape that contains a zero-sized dimension.
    def self.nested_empty(shape)
      size, *rest = shape
      return [] if size.zero?

      Array.new(size) { nested_empty(rest) }
    end
    private_class_method :nested_empty
  end
end
|
data/lib/onnxruntime.rb
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "ffi"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "onnxruntime/datasets"
|
6
|
+
require "onnxruntime/inference_session"
|
7
|
+
require "onnxruntime/model"
|
8
|
+
require "onnxruntime/utils"
|
9
|
+
require "onnxruntime/version"
|
10
|
+
|
11
|
+
module OnnxRuntime
  # Base class for errors raised by this library.
  class Error < StandardError; end

  class << self
    # Library paths handed to FFI when the native library is loaded.
    attr_accessor :ffi_lib
  end

  # Pick the vendored shared library matching the platform and CPU.
  arm_cpu = RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i
  lib_name =
    if Gem.win_platform?
      "onnxruntime.dll"
    elsif RbConfig::CONFIG["host_os"] =~ /darwin/i
      arm_cpu ? "libonnxruntime.arm64.dylib" : "libonnxruntime.dylib"
    else
      arm_cpu ? "libonnxruntime.arm64.so" : "libonnxruntime.so"
    end
  self.ffi_lib = [File.expand_path(File.join("..", "vendor", lib_name), __dir__)]

  # Version string reported by the loaded native library.
  def self.lib_version
    FFI.OrtGetApiBase[:GetVersionString].call.read_string
  end

  # Defer loading the FFI bindings until OnnxRuntime::FFI is first referenced
  # (the original note: for a friendlier error message).
  autoload :FFI, "onnxruntime/ffi"
end
|
data/vendor/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) Microsoft Corporation
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|