onnxruntime 0.3.3 → 0.4.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b7d22851572b35128d1e2bbcc041b4989851e02354cace5389afc25855674b11
4
- data.tar.gz: f1eee285e5dbff1fbf6de4e350560e3e386e1f4d26d53810f154cd16117e38e6
3
+ metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
4
+ data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
5
5
  SHA512:
6
- metadata.gz: f89ba13181bfcc8cdf35356efae175d0e2c7a0787a13af8d46788645046f057bce466e28566c0f2982f52893dbb7cc1553a741b1ca4923066e90c0a5fb230edf
7
- data.tar.gz: 3843b9c1e5a9432d3b72ebb33498aaef0fb6edf8942a9f0b6794bed36fb95df67290e6c4385e19927e5404510d89ad8965b15de0a96246b7f1770c4fb01216f6
6
+ metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
7
+ data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
@@ -1,3 +1,9 @@
1
+ ## 0.4.0 (2020-07-20)
2
+
3
+ - Updated ONNX Runtime to 1.4.0
4
+ - Added `providers` method
5
+ - Fixed errors on Windows
6
+
1
7
  ## 0.3.3 (2020-06-17)
2
8
 
3
9
  - Fixed segmentation fault on exit on Linux
@@ -2,12 +2,7 @@ module OnnxRuntime
2
2
  module FFI
3
3
  extend ::FFI::Library
4
4
 
5
- begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
7
- rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
10
- end
5
+ ffi_lib Array(OnnxRuntime.ffi_lib)
11
6
 
12
7
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
13
8
  # keep same order
@@ -25,14 +20,14 @@ module OnnxRuntime
25
20
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
21
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
22
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
23
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
24
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
25
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
26
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
27
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
28
  :CloneSessionOptions, callback(%i[], :pointer),
34
29
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
30
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
31
  :DisableProfiling, callback(%i[pointer], :pointer),
37
32
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
33
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -142,7 +137,9 @@ module OnnxRuntime
142
137
  :CreateThreadingOptions, callback(%i[], :pointer),
143
138
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
139
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
140
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
141
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
142
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
143
  end
147
144
 
148
145
  class ApiBase < ::FFI::Struct
@@ -154,5 +151,13 @@ module OnnxRuntime
154
151
  end
155
152
 
156
153
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
154
+
155
+ if Gem.win_platform?
156
+ class Libc
157
+ extend ::FFI::Library
158
+ ffi_lib ::FFI::Library::LIBC
159
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
160
+ end
161
+ end
157
162
  end
158
163
  end
@@ -8,7 +8,7 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
13
  execution_modes = {sequential: 0, parallel: 1}
14
14
  mode = execution_modes[execution_mode]
@@ -17,8 +17,8 @@ module OnnxRuntime
17
17
  end
18
18
  if graph_optimization_level
19
19
  optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
- # TODO raise error in 0.4.0
21
- level = optimization_levels[graph_optimization_level] || graph_optimization_level
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
22
  check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
23
  end
24
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
@@ -26,7 +26,7 @@ module OnnxRuntime
26
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
30
30
 
31
31
  # session
32
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -39,16 +39,10 @@ module OnnxRuntime
39
39
  path_or_bytes.encoding == Encoding::BINARY
40
40
  end
41
41
 
42
- # fix for Windows "File doesn't exist"
43
- if Gem.win_platform? && !from_memory
44
- path_or_bytes = File.binread(path_or_bytes)
45
- from_memory = true
46
- end
47
-
48
42
  if from_memory
49
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
50
44
  else
51
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
52
46
  end
53
47
  ObjectSpace.define_finalizer(self, self.class.finalize(@session))
54
48
 
@@ -63,7 +57,7 @@ module OnnxRuntime
63
57
  # input
64
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
65
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
66
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
67
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
68
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
69
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -74,7 +68,7 @@ module OnnxRuntime
74
68
  # output
75
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
76
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
77
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
78
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
79
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
80
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -156,12 +150,28 @@ module OnnxRuntime
156
150
  release :ModelMetadata, metadata
157
151
  end
158
152
 
153
+ # return value has double underscore like Python
159
154
  def end_profiling
160
155
  out = ::FFI::MemoryPointer.new(:string)
161
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
162
157
  out.read_pointer.read_string
163
158
  end
164
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
165
175
  private
166
176
 
167
177
  def create_input_tensor(input_feed)
@@ -184,7 +194,7 @@ module OnnxRuntime
184
194
 
185
195
  # TODO support more types
186
196
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
187
- raise "Unknown input: #{input_name}" unless inp
197
+ raise Error, "Unknown input: #{input_name}" unless inp
188
198
 
189
199
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
190
200
  input_node_dims.write_array_of_int64(shape)
@@ -241,7 +251,7 @@ module OnnxRuntime
241
251
 
242
252
  out_size = ::FFI::MemoryPointer.new(:size_t)
243
253
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
244
- output_tensor_size = read_size_t(out_size)
254
+ output_tensor_size = out_size.read(:size_t)
245
255
 
246
256
  release :TensorTypeAndShapeInfo, typeinfo
247
257
 
@@ -262,7 +272,7 @@ module OnnxRuntime
262
272
  out = ::FFI::MemoryPointer.new(:size_t)
263
273
  check_status api[:GetValueCount].call(out_ptr, out)
264
274
 
265
- read_size_t(out).times.map do |i|
275
+ out.read(:size_t).times.map do |i|
266
276
  seq = ::FFI::MemoryPointer.new(:pointer)
267
277
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
268
278
  create_from_onnx_value(seq.read_pointer)
@@ -306,7 +316,7 @@ module OnnxRuntime
306
316
  unless status.null?
307
317
  message = api[:GetErrorMessage].call(status).read_string
308
318
  api[:ReleaseStatus].call(status)
309
- raise OnnxRuntime::Error, message
319
+ raise Error, message
310
320
  end
311
321
  end
312
322
 
@@ -368,7 +378,7 @@ module OnnxRuntime
368
378
 
369
379
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
370
380
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
371
- num_dims = read_size_t(num_dims_ptr)
381
+ num_dims = num_dims_ptr.read(:size_t)
372
382
 
373
383
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
374
384
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -377,16 +387,7 @@ module OnnxRuntime
377
387
  end
378
388
 
379
389
  def unsupported_type(name, type)
380
- raise "Unsupported #{name} type: #{type}"
381
- end
382
-
383
- # read(:size_t) not supported in FFI JRuby
384
- def read_size_t(ptr)
385
- if RUBY_PLATFORM == "java"
386
- ptr.read_long
387
- else
388
- ptr.read(:size_t)
389
- end
390
+ raise Error, "Unsupported #{name} type: #{type}"
390
391
  end
391
392
 
392
393
  def api
@@ -398,7 +399,7 @@ module OnnxRuntime
398
399
  end
399
400
 
400
401
  def self.api
401
- @api ||= FFI.OrtGetApiBase[:GetApi].call(3)
402
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
402
403
  end
403
404
 
404
405
  def self.release(type, pointer)
@@ -410,6 +411,21 @@ module OnnxRuntime
410
411
  proc { release :Session, session }
411
412
  end
412
413
 
414
+ # wide string on Windows
415
+ # char string on Linux
416
+ # see ORTCHAR_T in onnxruntime_c_api.h
417
+ def ort_string(str)
418
+ if Gem.win_platform?
419
+ max = str.size + 1 # for null byte
420
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
421
+ ret = FFI::Libc.mbstowcs(dest, str, max)
422
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
423
+ dest
424
+ else
425
+ str
426
+ end
427
+ end
428
+
413
429
  def env
414
430
  # use mutex for thread-safety
415
431
  Utils.mutex.synchronize do
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.3.3"
2
+ VERSION = "0.4.0"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.3
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-06-17 00:00:00.000000000 Z
11
+ date: 2020-07-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi