onnxruntime 0.3.3 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b7d22851572b35128d1e2bbcc041b4989851e02354cace5389afc25855674b11
4
- data.tar.gz: f1eee285e5dbff1fbf6de4e350560e3e386e1f4d26d53810f154cd16117e38e6
3
+ metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
4
+ data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
5
5
  SHA512:
6
- metadata.gz: f89ba13181bfcc8cdf35356efae175d0e2c7a0787a13af8d46788645046f057bce466e28566c0f2982f52893dbb7cc1553a741b1ca4923066e90c0a5fb230edf
7
- data.tar.gz: 3843b9c1e5a9432d3b72ebb33498aaef0fb6edf8942a9f0b6794bed36fb95df67290e6c4385e19927e5404510d89ad8965b15de0a96246b7f1770c4fb01216f6
6
+ metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
7
+ data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
@@ -1,3 +1,9 @@
1
+ ## 0.4.0 (2020-07-20)
2
+
3
+ - Updated ONNX Runtime to 1.4.0
4
+ - Added `providers` method
5
+ - Fixed errors on Windows
6
+
1
7
  ## 0.3.3 (2020-06-17)
2
8
 
3
9
  - Fixed segmentation fault on exit on Linux
@@ -2,12 +2,7 @@ module OnnxRuntime
2
2
  module FFI
3
3
  extend ::FFI::Library
4
4
 
5
- begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
7
- rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
10
- end
5
+ ffi_lib Array(OnnxRuntime.ffi_lib)
11
6
 
12
7
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
13
8
  # keep same order
@@ -25,14 +20,14 @@ module OnnxRuntime
25
20
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
21
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
22
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
23
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
24
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
25
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
26
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
27
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
28
  :CloneSessionOptions, callback(%i[], :pointer),
34
29
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
30
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
31
  :DisableProfiling, callback(%i[pointer], :pointer),
37
32
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
33
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -142,7 +137,9 @@ module OnnxRuntime
142
137
  :CreateThreadingOptions, callback(%i[], :pointer),
143
138
  :ReleaseThreadingOptions, callback(%i[], :pointer),
144
139
  :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
- :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
140
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
141
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
142
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
146
143
  end
147
144
 
148
145
  class ApiBase < ::FFI::Struct
@@ -154,5 +151,13 @@ module OnnxRuntime
154
151
  end
155
152
 
156
153
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
154
+
155
+ if Gem.win_platform?
156
+ class Libc
157
+ extend ::FFI::Library
158
+ ffi_lib ::FFI::Library::LIBC
159
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
160
+ end
161
+ end
157
162
  end
158
163
  end
@@ -8,7 +8,7 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
13
  execution_modes = {sequential: 0, parallel: 1}
14
14
  mode = execution_modes[execution_mode]
@@ -17,8 +17,8 @@ module OnnxRuntime
17
17
  end
18
18
  if graph_optimization_level
19
19
  optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
- # TODO raise error in 0.4.0
21
- level = optimization_levels[graph_optimization_level] || graph_optimization_level
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
22
  check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
23
  end
24
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
@@ -26,7 +26,7 @@ module OnnxRuntime
26
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
30
30
 
31
31
  # session
32
32
  @session = ::FFI::MemoryPointer.new(:pointer)
@@ -39,16 +39,10 @@ module OnnxRuntime
39
39
  path_or_bytes.encoding == Encoding::BINARY
40
40
  end
41
41
 
42
- # fix for Windows "File doesn't exist"
43
- if Gem.win_platform? && !from_memory
44
- path_or_bytes = File.binread(path_or_bytes)
45
- from_memory = true
46
- end
47
-
48
42
  if from_memory
49
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
50
44
  else
51
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
52
46
  end
53
47
  ObjectSpace.define_finalizer(self, self.class.finalize(@session))
54
48
 
@@ -63,7 +57,7 @@ module OnnxRuntime
63
57
  # input
64
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
65
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
66
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
67
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
68
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
69
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -74,7 +68,7 @@ module OnnxRuntime
74
68
  # output
75
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
76
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
77
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
78
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
79
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
80
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -156,12 +150,28 @@ module OnnxRuntime
156
150
  release :ModelMetadata, metadata
157
151
  end
158
152
 
153
+ # return value has double underscore like Python
159
154
  def end_profiling
160
155
  out = ::FFI::MemoryPointer.new(:string)
161
156
  check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
162
157
  out.read_pointer.read_string
163
158
  end
164
159
 
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
173
+ end
174
+
165
175
  private
166
176
 
167
177
  def create_input_tensor(input_feed)
@@ -184,7 +194,7 @@ module OnnxRuntime
184
194
 
185
195
  # TODO support more types
186
196
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
187
- raise "Unknown input: #{input_name}" unless inp
197
+ raise Error, "Unknown input: #{input_name}" unless inp
188
198
 
189
199
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
190
200
  input_node_dims.write_array_of_int64(shape)
@@ -241,7 +251,7 @@ module OnnxRuntime
241
251
 
242
252
  out_size = ::FFI::MemoryPointer.new(:size_t)
243
253
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
244
- output_tensor_size = read_size_t(out_size)
254
+ output_tensor_size = out_size.read(:size_t)
245
255
 
246
256
  release :TensorTypeAndShapeInfo, typeinfo
247
257
 
@@ -262,7 +272,7 @@ module OnnxRuntime
262
272
  out = ::FFI::MemoryPointer.new(:size_t)
263
273
  check_status api[:GetValueCount].call(out_ptr, out)
264
274
 
265
- read_size_t(out).times.map do |i|
275
+ out.read(:size_t).times.map do |i|
266
276
  seq = ::FFI::MemoryPointer.new(:pointer)
267
277
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
268
278
  create_from_onnx_value(seq.read_pointer)
@@ -306,7 +316,7 @@ module OnnxRuntime
306
316
  unless status.null?
307
317
  message = api[:GetErrorMessage].call(status).read_string
308
318
  api[:ReleaseStatus].call(status)
309
- raise OnnxRuntime::Error, message
319
+ raise Error, message
310
320
  end
311
321
  end
312
322
 
@@ -368,7 +378,7 @@ module OnnxRuntime
368
378
 
369
379
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
370
380
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
371
- num_dims = read_size_t(num_dims_ptr)
381
+ num_dims = num_dims_ptr.read(:size_t)
372
382
 
373
383
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
374
384
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -377,16 +387,7 @@ module OnnxRuntime
377
387
  end
378
388
 
379
389
  def unsupported_type(name, type)
380
- raise "Unsupported #{name} type: #{type}"
381
- end
382
-
383
- # read(:size_t) not supported in FFI JRuby
384
- def read_size_t(ptr)
385
- if RUBY_PLATFORM == "java"
386
- ptr.read_long
387
- else
388
- ptr.read(:size_t)
389
- end
390
+ raise Error, "Unsupported #{name} type: #{type}"
390
391
  end
391
392
 
392
393
  def api
@@ -398,7 +399,7 @@ module OnnxRuntime
398
399
  end
399
400
 
400
401
  def self.api
401
- @api ||= FFI.OrtGetApiBase[:GetApi].call(3)
402
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
402
403
  end
403
404
 
404
405
  def self.release(type, pointer)
@@ -410,6 +411,21 @@ module OnnxRuntime
410
411
  proc { release :Session, session }
411
412
  end
412
413
 
414
+ # wide string on Windows
415
+ # char string on Linux
416
+ # see ORTCHAR_T in onnxruntime_c_api.h
417
+ def ort_string(str)
418
+ if Gem.win_platform?
419
+ max = str.size + 1 # for null byte
420
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
421
+ ret = FFI::Libc.mbstowcs(dest, str, max)
422
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
423
+ dest
424
+ else
425
+ str
426
+ end
427
+ end
428
+
413
429
  def env
414
430
  # use mutex for thread-safety
415
431
  Utils.mutex.synchronize do
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.3.3"
2
+ VERSION = "0.4.0"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.3
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-06-17 00:00:00.000000000 Z
11
+ date: 2020-07-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi