onnxruntime 0.7.0-x64-mingw

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,531 @@
1
module OnnxRuntime
  # Ruby wrapper around an ONNX Runtime inference session, driven through the
  # C API function table returned by InferenceSession.api.
  class InferenceSession
    # Arrays of hashes (:name, :type, :shape) describing the model's inputs and
    # outputs, in the order reported by the session.
    attr_reader :inputs, :outputs

    # Loads a model and creates a session.
    #
    # path_or_bytes - an object responding to #read (its contents are the
    #   serialized model), a String with BINARY encoding (serialized model
    #   bytes), or any other object responding to #to_str (a model file path).
    #
    # Each keyword maps onto the OrtSessionOptions setter of the same name; a
    # nil or false value skips that setter.
    #
    # Raises ArgumentError for an unknown execution_mode or
    # graph_optimization_level, and OnnxRuntime::Error when a C API call fails.
    def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
      # session options
      session_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateSessionOptions].call(session_options)
      options = session_options.read_pointer

      # NOTE(review): the arena / memory-pattern setters only run when the flag
      # is true. If ONNX Runtime enables both by default (believed so - TODO
      # confirm), passing false has no effect; fixing that needs
      # DisableCpuMemArena / DisableMemPattern wired up in FFI::Api.
      check_status api[:EnableCpuMemArena].call(options) if enable_cpu_mem_arena
      check_status api[:EnableMemPattern].call(options) if enable_mem_pattern
      check_status api[:EnableProfiling].call(options, ort_string("onnxruntime_profile_")) if enable_profiling
      if execution_mode
        mode = {sequential: 0, parallel: 1}[execution_mode]
        raise ArgumentError, "Invalid execution mode" unless mode
        check_status api[:SetSessionExecutionMode].call(options, mode)
      end
      if graph_optimization_level
        level = {none: 0, basic: 1, extended: 2, all: 99}[graph_optimization_level]
        raise ArgumentError, "Invalid graph optimization level" unless level
        check_status api[:SetSessionGraphOptimizationLevel].call(options, level)
      end
      check_status api[:SetInterOpNumThreads].call(options, inter_op_num_threads) if inter_op_num_threads
      check_status api[:SetIntraOpNumThreads].call(options, intra_op_num_threads) if intra_op_num_threads
      check_status api[:SetSessionLogSeverityLevel].call(options, log_severity_level) if log_severity_level
      check_status api[:SetSessionLogVerbosityLevel].call(options, log_verbosity_level) if log_verbosity_level
      check_status api[:SetSessionLogId].call(options, logid) if logid
      check_status api[:SetOptimizedModelFilePath].call(options, ort_string(optimized_model_filepath)) if optimized_model_filepath

      # session
      @session = ::FFI::MemoryPointer.new(:pointer)
      from_memory =
        if path_or_bytes.respond_to?(:read)
          path_or_bytes = path_or_bytes.read
          true
        else
          path_or_bytes = path_or_bytes.to_str
          # BINARY-encoded strings hold the serialized model, others are paths
          path_or_bytes.encoding == Encoding::BINARY
        end

      if from_memory
        check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, options, @session)
      else
        check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), options, @session)
      end
      ObjectSpace.define_finalizer(self, self.class.finalize(@session))

      # allocator passed to the C API calls that hand back allocated strings/values
      @allocator = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:GetAllocatorWithDefaultOptions].call(@allocator)

      @inputs = load_node_info(:SessionGetInputCount, :SessionGetInputName, :SessionGetInputTypeInfo)
      @outputs = load_node_info(:SessionGetOutputCount, :SessionGetOutputName, :SessionGetOutputTypeInfo)
    ensure
      # NOTE(review): session_options is not released (the release call is
      # disabled), so each session leaks its OrtSessionOptions. Re-enable after
      # confirming FFI::Api exposes a usable ReleaseSessionOptions entry.
      # release :SessionOptions, session_options
    end

    # Runs the model.
    #
    # output_names - names of the outputs to fetch, or nil for all outputs.
    # input_feed   - Hash of input name => Array or Numo::NArray.
    # output_type  - :ruby (nested Arrays) or :numo (Numo arrays) for tensors.
    # log_severity_level, log_verbosity_level, logid (run tag) and terminate
    # are applied to the OrtRunOptions for this call.
    #
    # Returns an Array with one converted value per requested output.
    # Raises OnnxRuntime::Error for unknown inputs or C API failures and
    # ArgumentError for an invalid output_type.
    def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
      # Ruby-owned buffers referenced by native pointers; held here so they are
      # not garbage collected before Run returns
      refs = []

      input_tensor = create_input_tensor(input_feed, refs)

      output_names ||= @outputs.map { |v| v[:name] }

      # One slot per *requested* output. Sizing by the model's output count
      # could overflow when more names than outputs are requested.
      # Run writes an OrtValue pointer into each slot.
      output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
      input_node_names = create_node_names(input_feed.keys.map(&:to_s), refs)
      output_node_names = create_node_names(output_names.map(&:to_s), refs)

      # run options
      run_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateRunOptions].call(run_options)
      check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
      check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
      check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
      check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate

      check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

      output_names.size.times.map do |i|
        create_from_onnx_value(output_tensor[i].read_pointer, output_type)
      end
    ensure
      release :RunOptions, run_options
      input_feed.size.times { |i| release :Value, input_tensor[i] } if input_tensor
      # The converted results are Ruby copies of the data, so the output
      # OrtValues that Run created can be freed (slots left NULL are skipped
      # by the C API).
      output_names.size.times { |i| release :Value, output_tensor[i] } if output_tensor
    end

    # Reads the model's metadata. Returns a Hash with :custom_metadata_map
    # (String => String), :description, :domain, :graph_name, :producer_name
    # and :version.
    # NOTE(review): the strings and key array returned through @allocator are
    # not freed here.
    def modelmeta
      keys = ::FFI::MemoryPointer.new(:pointer)
      num_keys = ::FFI::MemoryPointer.new(:int64_t)
      description = ::FFI::MemoryPointer.new(:string)
      domain = ::FFI::MemoryPointer.new(:string)
      graph_name = ::FFI::MemoryPointer.new(:string)
      producer_name = ::FFI::MemoryPointer.new(:string)
      version = ::FFI::MemoryPointer.new(:int64_t)

      metadata = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)

      custom_metadata_map = {}
      check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
      num_keys.read(:int64_t).times do |i|
        key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
        value = ::FFI::MemoryPointer.new(:string)
        check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
        custom_metadata_map[key] = value.read_pointer.read_string
      end

      check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
      check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
      check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
      check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
      check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)

      {
        custom_metadata_map: custom_metadata_map,
        description: description.read_pointer.read_string,
        domain: domain.read_pointer.read_string,
        graph_name: graph_name.read_pointer.read_string,
        producer_name: producer_name.read_pointer.read_string,
        version: version.read(:int64_t)
      }
    ensure
      release :ModelMetadata, metadata
    end

    # Stops profiling and returns the profile file name. The name presumably
    # contains a double underscore, as in the Python API, because the prefix
    # set in #initialize already ends in "_".
    def end_profiling
      out = ::FFI::MemoryPointer.new(:string)
      check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
      out.read_pointer.read_string
    end

    # There is no way to set providers with the C API yet, so this returns
    # every provider available in the loaded library.
    def providers
      out_ptr = ::FFI::MemoryPointer.new(:pointer)
      length_ptr = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
      length = length_ptr.read_int
      providers = Array.new(length) do |i|
        out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
      end
      api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
      providers
    end

    private

    # Reads name and type information for every input or output of the
    # session. The arguments name the count, name and type-info functions in
    # the C API table for the side being read.
    # NOTE(review): the name strings are presumably allocated with @allocator
    # and are not freed here.
    def load_node_info(count_fn, name_fn, type_info_fn)
      count = ::FFI::MemoryPointer.new(:size_t)
      check_status api[count_fn].call(read_pointer, count)
      Array.new(count.read(:size_t)) do |i|
        name_ptr = ::FFI::MemoryPointer.new(:string)
        check_status api[name_fn].call(read_pointer, i, @allocator.read_pointer, name_ptr)
        typeinfo = ::FFI::MemoryPointer.new(:pointer)
        check_status api[type_info_fn].call(read_pointer, i, typeinfo)
        {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
      end
    end

    # Converts input_feed (input name => Array or Numo::NArray) into a native
    # array of OrtValue pointers, one per entry in input_feed order.
    # Ruby-owned buffers that the OrtValues point into are appended to refs so
    # the caller can keep them alive while the values are in use.
    # If conversion fails partway, the OrtValues already created are released
    # before the error is re-raised.
    def create_input_tensor(input_feed, refs)
      allocator_info = ::FFI::MemoryPointer.new(:pointer)
      # NOTE(review): the OrtMemoryInfo created here is never released
      check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
      input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)

      input_feed.each_with_index do |(input_name, input), idx|
        if numo_array?(input)
          shape = input.shape
        else
          input = input.to_a unless input.is_a?(Array)

          # infer the shape from the first element at each nesting level
          shape = []
          s = input
          while s.is_a?(Array)
            shape << s.size
            s = s.first
          end
        end

        # TODO support more types
        inp = @inputs.find { |i| i[:name] == input_name.to_s }
        raise Error, "Unknown input: #{input_name}" unless inp

        input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
        input_node_dims.write_array_of_int64(shape)

        if inp[:type] == "tensor(string)"
          str_ptrs =
            if numo_array?(input)
              input.size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) }
            else
              input.flatten.map { |v| ::FFI::MemoryPointer.from_string(v) }
            end

          input_tensor_values = ::FFI::MemoryPointer.new(:pointer, str_ptrs.size)
          input_tensor_values.write_array_of_pointer(str_ptrs)

          type_enum = FFI::TensorElementDataType[:string]
          check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
          check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, str_ptrs.size)

          refs << str_ptrs
        else
          tensor_type = tensor_types[inp[:type]]
          unsupported_type("input", inp[:type]) unless tensor_type

          if numo_array?(input)
            input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
          else
            flat_input = input.flatten.to_a
            input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
            if tensor_type == :bool
              input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
            else
              input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
            end
          end

          type_enum = FFI::TensorElementDataType[tensor_type]
          check_status api[:CreateTensorWithDataAsOrtValue].call(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])

          refs << input_node_dims
          refs << input_tensor_values
        end
      end

      refs << allocator_info

      input_tensor
    rescue StandardError
      input_feed.size.times { |i| release :Value, input_tensor[i] } if input_tensor
      raise
    end

    # Builds a native char*[] for names; the backing strings go into refs.
    def create_node_names(names, refs)
      str_ptrs = names.map { |v| ::FFI::MemoryPointer.from_string(v) }
      refs << str_ptrs

      ptr = ::FFI::MemoryPointer.new(:pointer, names.size)
      ptr.write_array_of_pointer(str_ptrs)
      ptr
    end

    # Converts an OrtValue (tensor, sequence or map) into Ruby objects.
    # Does not release out_ptr; ownership stays with the caller.
    def create_from_onnx_value(out_ptr, output_type)
      out_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetValueType].call(out_ptr, out_type)
      type = FFI::OnnxType[out_type.read_int]

      case type
      when :tensor
        create_tensor_from_onnx_value(out_ptr, output_type)
      when :sequence
        create_sequence_from_onnx_value(out_ptr, output_type)
      when :map
        create_map_from_onnx_value(out_ptr, output_type)
      else
        unsupported_type("ONNX", type)
      end
    end

    # Converts a tensor OrtValue into a nested Array (:ruby) or a Numo array
    # (:numo). Raises ArgumentError for any other output_type.
    def create_tensor_from_onnx_value(out_ptr, output_type)
      type, shape, output_tensor_size = tensor_value_info(out_ptr)

      tensor_data = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:GetTensorMutableData].call(out_ptr, tensor_data)
      # The data pointer may be NULL for zero-element tensors (TODO confirm),
      # and FFI raises on reads through NULL, so skip the read when empty.
      empty = output_tensor_size.zero?

      # TODO support more types
      case output_type
      when :numo
        if type == :string
          result = Numo::RObject.new(shape)
          result.allocate
          create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
        else
          numo_type = numo_types[type]
          unsupported_type("element", type) unless numo_type
          bytes = empty ? "".b : tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE)
          numo_type.from_binary(bytes, shape)
        end
      when :ruby
        arr =
          case type
          when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
            empty ? [] : tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
          when :bool
            empty ? [] : tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
          when :string
            create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
          else
            unsupported_type("element", type)
          end

        Utils.reshape(arr, shape)
      else
        raise ArgumentError, "Invalid output type: #{output_type}"
      end
    end

    # Converts a sequence OrtValue into an Array, releasing each element value
    # obtained from GetValue once it has been converted.
    def create_sequence_from_onnx_value(out_ptr, output_type)
      count = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetValueCount].call(out_ptr, count)

      Array.new(count.read(:size_t)) do |i|
        element = ::FFI::MemoryPointer.new(:pointer)
        begin
          check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, element)
          create_from_onnx_value(element.read_pointer, output_type)
        ensure
          release :Value, element
        end
      end
    end

    # Converts a map OrtValue into a Hash. Only int64 keys are supported.
    def create_map_from_onnx_value(out_ptr, output_type)
      map_keys = ::FFI::MemoryPointer.new(:pointer)
      map_values = ::FFI::MemoryPointer.new(:pointer)
      # index 0 yields the keys tensor, index 1 the values
      check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys)
      check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)

      # TODO support more types
      key_type, = tensor_value_info(map_keys.read_pointer)
      unsupported_type("element", key_type) unless key_type == :int64

      keys = create_from_onnx_value(map_keys.read_pointer, output_type)
      values = create_from_onnx_value(map_values.read_pointer, output_type)
      keys.zip(values).to_h
    ensure
      # the values returned by GetValue are released once converted
      release :Value, map_keys
      release :Value, map_values
    end

    # Returns [element type Symbol, shape Array, element count] for a tensor
    # OrtValue, releasing the temporary type/shape info.
    def tensor_value_info(value)
      typeinfo = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:GetTensorTypeAndShape].call(value, typeinfo)
      type, shape = tensor_type_and_shape(typeinfo)

      count = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, count)

      [FFI::TensorElementDataType[type], shape, count.read(:size_t)]
    ensure
      release :TensorTypeAndShapeInfo, typeinfo
    end

    # Copies the strings of a string tensor into result (an Array or
    # Numo::RObject) by flat index and returns result.
    def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
      len = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetStringTensorDataLength].call(out_ptr, len)

      s_len = len.read(:size_t)
      s = ::FFI::MemoryPointer.new(:uchar, s_len)
      offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
      check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)

      offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
      offsets << s_len
      output_tensor_size.times do |i|
        result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
      end
      result
    end

    # The native OrtSession pointer.
    def read_pointer
      @session.read_pointer
    end

    # Raises OnnxRuntime::Error with the status message (releasing the status
    # first) when a C API call returned a non-NULL OrtStatus.
    def check_status(status)
      return if status.null?

      message = api[:GetErrorMessage].call(status).read_string
      api[:ReleaseStatus].call(status)
      raise Error, message
    end

    # Describes a type info as {type:, shape:}, e.g. "tensor(float)",
    # "seq(...)" or "map(k,v)". Releases typeinfo when done.
    def node_info(typeinfo)
      onnx_type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type)

      type = FFI::OnnxType[onnx_type.read_int]
      case type
      when :tensor
        tensor_info = ::FFI::MemoryPointer.new(:pointer)
        # don't free tensor_info
        check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)

        type, shape = tensor_type_and_shape(tensor_info)
        {
          type: "tensor(#{FFI::TensorElementDataType[type]})",
          shape: shape
        }
      when :sequence
        sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
        nested_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
        v = node_info(nested_type_info)[:type]

        {
          type: "seq(#{v})",
          shape: []
        }
      when :map
        map_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)

        # key
        key_type = ::FFI::MemoryPointer.new(:int)
        check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
        k = FFI::TensorElementDataType[key_type.read_int]

        # value
        value_type_info = ::FFI::MemoryPointer.new(:pointer)
        check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
        v = node_info(value_type_info)[:type]

        {
          type: "map(#{k},#{v})",
          shape: []
        }
      else
        unsupported_type("ONNX", type)
      end
    ensure
      release :TypeInfo, typeinfo
    end

    # Returns [element type enum (Integer), dimensions Array] for tensor info.
    def tensor_type_and_shape(tensor_info)
      type = ::FFI::MemoryPointer.new(:int)
      check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type)

      num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
      check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
      num_dims = num_dims_ptr.read(:size_t)

      node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
      check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)

      [type.read_int, node_dims.read_array_of_int64(num_dims)]
    end

    def unsupported_type(name, type)
      raise Error, "Unsupported #{name} type: #{type}"
    end

    # "tensor(<type>)" => element type symbol for the supported numeric types
    def tensor_types
      @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
    end

    # true only when Numo is loaded and obj is a Numo::NArray
    def numo_array?(obj)
      defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
    end

    def numo_types
      @numo_types ||= {
        float: Numo::SFloat,
        uint8: Numo::UInt8,
        int8: Numo::Int8,
        uint16: Numo::UInt16,
        int16: Numo::Int16,
        int32: Numo::Int32,
        int64: Numo::Int64,
        bool: Numo::UInt8,
        double: Numo::DFloat,
        uint32: Numo::UInt32,
        uint64: Numo::UInt64
      }
    end

    def api
      self.class.api
    end

    def release(*args)
      self.class.release(*args)
    end

    # The C API function table (requested with API version 4).
    def self.api
      @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
    end

    # Calls Release<type> on the object stored in the pointer slot, if any.
    def self.release(type, pointer)
      api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
    end

    # Builds the finalizer that releases the native session. It is created in
    # a class method so the proc does not capture the instance, which would
    # keep the instance from ever being garbage collected.
    def self.finalize(session)
      # must use proc instead of stabby lambda: finalizers are called with the
      # object id, which a zero-arity lambda would reject
      proc { release :Session, session }
    end

    # Converts a String to the C API's ORTCHAR_T representation: wide string on
    # Windows, char string on Linux (see ORTCHAR_T in onnxruntime_c_api.h).
    #
    # On Windows the string is transcoded to NUL-terminated UTF-16LE with
    # String#encode rather than mbstowcs. mbstowcs interprets bytes according
    # to the C runtime locale, so UTF-8 paths with non-ASCII characters were
    # rejected by the length check, and characters outside the BMP did not
    # fit in a str.size + 1 buffer.
    def ort_string(str)
      return str unless Gem.win_platform?

      wide = str.encode(Encoding::UTF_16LE).b << "\0\0".b
      dest = ::FFI::MemoryPointer.new(:uchar, wide.bytesize)
      dest.put_bytes(0, wide)
      dest
    rescue EncodingError => e
      raise Error, "Unable to convert #{str.inspect} to a wide string: #{e.message}"
    end

    # Pointer slot holding the OrtEnv shared by all sessions (class variable),
    # created on first use under Utils.mutex and released in an at_exit hook.
    def env
      # use mutex for thread-safety
      Utils.mutex.synchronize do
        @@env ||= begin
          env = ::FFI::MemoryPointer.new(:pointer)
          check_status api[:CreateEnv].call(3, "Default", env)
          at_exit { release :Env, env }
          # disable telemetry
          # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
          # (pass the OrtEnv itself, not the slot holding it)
          check_status api[:DisableTelemetryEvents].call(env.read_pointer)
          env
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module OnnxRuntime
  # Convenience wrapper over InferenceSession that returns predictions keyed
  # by output name.
  class Model
    # Accepts the same arguments as InferenceSession.new.
    def initialize(path_or_bytes, **session_options)
      @session = InferenceSession.new(path_or_bytes, **session_options)
    end

    # Runs the model on input_feed and returns a Hash mapping each output name
    # (as a String) to its value. With output_names nil, every model output is
    # returned. Remaining keywords are passed through to InferenceSession#run.
    def predict(input_feed, output_names: nil, **run_options)
      predictions = @session.run(output_names, input_feed, **run_options)
      names = output_names || outputs.map { |output| output[:name] }
      names.map(&:to_s).zip(predictions).to_h
    end

    # Input metadata of the underlying session
    def inputs
      @session.inputs
    end

    # Output metadata of the underlying session
    def outputs
      @session.outputs
    end

    # Model metadata (see InferenceSession#modelmeta)
    def metadata
      @session.modelmeta
    end
  end
end
@@ -0,0 +1,16 @@
1
module OnnxRuntime
  module Utils
    class << self
      # Mutex guarding lazily initialized shared state (InferenceSession
      # synchronizes creation of its OrtEnv on it)
      attr_accessor :mutex
    end
    self.mutex = Mutex.new

    # Nests arr (flattened first) according to dims. The leading dimension is
    # implied by the element count, as before.
    #
    # Handles shapes the previous version crashed on:
    # - dims == [] (scalar tensor): returns the flat array instead of raising
    #   NoMethodError (`[][1..-1]` is nil).
    # - a zero-sized inner dimension: builds the nested empty arrays instead of
    #   calling each_slice(0), which raises ArgumentError.
    def self.reshape(arr, dims)
      arr = arr.flatten
      inner_dims = dims.drop(1)
      return empty_nested(dims) if inner_dims.include?(0)

      inner_dims.reverse_each do |dim|
        arr = arr.each_slice(dim)
      end
      arr.to_a
    end

    # Nested empty arrays for a shape that contains a zero dimension, e.g.
    # [2, 0] => [[], []].
    def self.empty_nested(dims)
      dim, *rest = dims
      return [] if dim.nil? || dim.zero?

      Array.new(dim) { empty_nested(rest) }
    end
    private_class_method :empty_nested
  end
end
@@ -0,0 +1,3 @@
1
module OnnxRuntime
  # Version of this gem (the bundled ONNX Runtime library reports its own
  # version via OnnxRuntime.lib_version). Frozen so callers cannot mutate it.
  VERSION = "0.7.0".freeze
end
@@ -0,0 +1,42 @@
1
# dependencies
require "ffi"

# modules
require "onnxruntime/datasets"
require "onnxruntime/inference_session"
require "onnxruntime/model"
require "onnxruntime/utils"
require "onnxruntime/version"

module OnnxRuntime
  # Raised for failures reported by ONNX Runtime or unsupported input
  class Error < StandardError; end

  class << self
    # Paths handed to the FFI loader for the ONNX Runtime shared library
    attr_accessor :ffi_lib
  end

  # Choose the vendored library for the current OS and CPU architecture
  darwin = RbConfig::CONFIG["host_os"] =~ /darwin/i
  arm = RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i
  lib_name =
    if Gem.win_platform?
      "onnxruntime.dll"
    elsif darwin
      arm ? "libonnxruntime.arm64.dylib" : "libonnxruntime.dylib"
    else
      arm ? "libonnxruntime.arm64.so" : "libonnxruntime.so"
    end
  self.ffi_lib = [File.expand_path("../vendor/#{lib_name}", __dir__)]

  # Version string reported by the loaded ONNX Runtime library
  def self.lib_version
    FFI.OrtGetApiBase[:GetVersionString].call.read_string
  end

  # FFI is autoloaded so the native library is only loaded on first use,
  # for a friendlier error message (and, presumably, so ffi_lib can still be
  # changed before loading)
  autoload :FFI, "onnxruntime/ffi"
end
data/vendor/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.