onnxruntime 0.2.3 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ec7b632678142610395254db0a898c300e23169f822d352e718076959cac20d1
4
- data.tar.gz: bdb36f003d7285fb95b8e962c168dd9d0a639811e118c67fd5d5c29179755d02
3
+ metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
4
+ data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
5
5
  SHA512:
6
- metadata.gz: 98bc00958bdf3f38a1c33a118850a57edfffc06d7da82ef0d2aa305a625ac1d0aae57b2c5314f10868d04a41c7de8adc981de03f3c8367d1b05c6ad308d12d79
7
- data.tar.gz: 948dec286c981fcd623e3252ef51ccf6b33a275c7019896a68869895643741d5c6f4bb534b505687acc3f717f5da5962f484c24f9ec2938821f003aa2d0e4174
6
+ metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
7
+ data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
@@ -1,3 +1,31 @@
1
+ ## 0.4.0 (2020-07-20)
2
+
3
+ - Updated ONNX Runtime to 1.4.0
4
+ - Added `providers` method
5
+ - Fixed errors on Windows
6
+
7
+ ## 0.3.3 (2020-06-17)
8
+
9
+ - Fixed segmentation fault on exit on Linux
10
+
11
+ ## 0.3.2 (2020-06-16)
12
+
13
+ - Fixed error with FFI 1.13.0+
14
+ - Added friendly graph optimization levels
15
+
16
+ ## 0.3.1 (2020-05-18)
17
+
18
+ - Updated ONNX Runtime to 1.3.0
19
+ - Added `custom_metadata_map` to model metadata
20
+
21
+ ## 0.3.0 (2020-03-11)
22
+
23
+ - Updated ONNX Runtime to 1.2.0
24
+ - Added model metadata
25
+ - Added `end_profiling` method
26
+ - Added support for loading from IO objects
27
+ - Improved `input` and `output` for `seq` and `map` types
28
+
1
29
  ## 0.2.3 (2020-01-23)
2
30
 
3
31
  - Updated ONNX Runtime to 1.1.1
data/README.md CHANGED
@@ -37,10 +37,16 @@ Get outputs
37
37
  model.outputs
38
38
  ```
39
39
 
40
+ Get metadata
41
+
42
+ ```ruby
43
+ model.metadata
44
+ ```
45
+
40
46
  Load a model from a string
41
47
 
42
48
  ```ruby
43
- byte_str = File.binread("model.onnx")
49
+ byte_str = StringIO.new("...")
44
50
  model = OnnxRuntime::Model.new(byte_str)
45
51
  ```
46
52
 
@@ -57,8 +63,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
57
63
  enable_cpu_mem_arena: true,
58
64
  enable_mem_pattern: true,
59
65
  enable_profiling: false,
60
- execution_mode: :sequential,
61
- graph_optimization_level: nil,
66
+ execution_mode: :sequential, # :sequential or :parallel
67
+ graph_optimization_level: nil, # :none, :basic, :extended, or :all
62
68
  inter_op_num_threads: nil,
63
69
  intra_op_num_threads: nil,
64
70
  log_severity_level: 2,
@@ -19,7 +19,7 @@ module OnnxRuntime
19
19
  self.ffi_lib = [vendor_lib]
20
20
 
21
21
  def self.lib_version
22
- FFI.OrtGetApiBase[:GetVersionString].call
22
+ FFI.OrtGetApiBase[:GetVersionString].call.read_string
23
23
  end
24
24
 
25
25
  # friendlier error message
@@ -2,12 +2,7 @@ module OnnxRuntime
2
2
  module FFI
3
3
  extend ::FFI::Library
4
4
 
5
- begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
7
- rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
10
- end
5
+ ffi_lib Array(OnnxRuntime.ffi_lib)
11
6
 
12
7
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
13
8
  # keep same order
@@ -20,19 +15,19 @@ module OnnxRuntime
20
15
  layout \
21
16
  :CreateStatus, callback(%i[int string], :pointer),
22
17
  :GetErrorCode, callback(%i[pointer], :pointer),
23
- :GetErrorMessage, callback(%i[pointer], :string),
18
+ :GetErrorMessage, callback(%i[pointer], :pointer),
24
19
  :CreateEnv, callback(%i[int string pointer], :pointer),
25
20
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
21
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
22
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
23
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
24
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
25
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
26
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
27
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
28
  :CloneSessionOptions, callback(%i[], :pointer),
34
29
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
30
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
31
  :DisableProfiling, callback(%i[pointer], :pointer),
37
32
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
33
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -119,7 +114,32 @@ module OnnxRuntime
119
114
  :ReleaseTypeInfo, callback(%i[pointer], :void),
120
115
  :ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
121
116
  :ReleaseSessionOptions, callback(%i[pointer], :void),
122
- :ReleaseCustomOpDomain, callback(%i[pointer], :void)
117
+ :ReleaseCustomOpDomain, callback(%i[pointer], :void),
118
+ :GetDenotationFromTypeInfo, callback(%i[], :pointer),
119
+ :CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
120
+ :CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
121
+ :GetMapKeyType, callback(%i[pointer pointer], :pointer),
122
+ :GetMapValueType, callback(%i[pointer pointer], :pointer),
123
+ :GetSequenceElementType, callback(%i[pointer pointer], :pointer),
124
+ :ReleaseMapTypeInfo, callback(%i[pointer], :void),
125
+ :ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
126
+ :SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
127
+ :SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
128
+ :ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
129
+ :ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
130
+ :ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
131
+ :ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
132
+ :ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
133
+ :ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
134
+ :ReleaseModelMetadata, callback(%i[pointer], :void),
135
+ :CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
136
+ :DisablePerSessionThreads, callback(%i[], :pointer),
137
+ :CreateThreadingOptions, callback(%i[], :pointer),
138
+ :ReleaseThreadingOptions, callback(%i[], :pointer),
139
+ :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
140
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
141
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
142
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
123
143
  end
124
144
 
125
145
  class ApiBase < ::FFI::Struct
@@ -127,9 +147,17 @@ module OnnxRuntime
127
147
  # to prevent "unable to resolve type" error on Ubuntu
128
148
  layout \
129
149
  :GetApi, callback(%i[uint32], Api.by_ref),
130
- :GetVersionString, callback(%i[], :string)
150
+ :GetVersionString, callback(%i[], :pointer)
131
151
  end
132
152
 
133
153
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
154
+
155
+ if Gem.win_platform?
156
+ class Libc
157
+ extend ::FFI::Library
158
+ ffi_lib ::FFI::Library::LIBC
159
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
160
+ end
161
+ end
134
162
  end
135
163
  end
@@ -8,41 +8,43 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
- mode =
14
- case execution_mode
15
- when :sequential
16
- 0
17
- when :parallel
18
- 1
19
- else
20
- raise ArgumentError, "Invalid execution mode"
21
- end
13
+ execution_modes = {sequential: 0, parallel: 1}
14
+ mode = execution_modes[execution_mode]
15
+ raise ArgumentError, "Invalid execution mode" unless mode
22
16
  check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
23
17
  end
24
- check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
18
+ if graph_optimization_level
19
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
+ end
25
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
26
25
  check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
27
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
28
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
29
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
30
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
31
30
 
32
31
  # session
33
32
  @session = ::FFI::MemoryPointer.new(:pointer)
34
- path_or_bytes = path_or_bytes.to_str
35
-
36
- # fix for Windows "File doesn't exist"
37
- if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
38
- path_or_bytes = File.binread(path_or_bytes)
39
- end
33
+ from_memory =
34
+ if path_or_bytes.respond_to?(:read)
35
+ path_or_bytes = path_or_bytes.read
36
+ true
37
+ else
38
+ path_or_bytes = path_or_bytes.to_str
39
+ path_or_bytes.encoding == Encoding::BINARY
40
+ end
40
41
 
41
- if path_or_bytes.encoding == Encoding::BINARY
42
+ if from_memory
42
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
43
44
  else
44
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
45
46
  end
47
+ ObjectSpace.define_finalizer(self, self.class.finalize(@session))
46
48
 
47
49
  # input info
48
50
  allocator = ::FFI::MemoryPointer.new(:pointer)
@@ -55,7 +57,7 @@ module OnnxRuntime
55
57
  # input
56
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
57
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
58
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
59
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
60
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
61
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -66,13 +68,15 @@ module OnnxRuntime
66
68
  # output
67
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
68
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
69
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
70
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
71
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
72
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
73
75
  check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
74
76
  @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
75
77
  end
78
+ ensure
79
+ # release :SessionOptions, session_options
76
80
  end
77
81
 
78
82
  # TODO support logid
@@ -98,6 +102,74 @@ module OnnxRuntime
98
102
  output_names.size.times.map do |i|
99
103
  create_from_onnx_value(output_tensor[i].read_pointer)
100
104
  end
105
+ ensure
106
+ release :RunOptions, run_options
107
+ if input_tensor
108
+ input_feed.size.times do |i|
109
+ release :Value, input_tensor[i]
110
+ end
111
+ end
112
+ end
113
+
114
+ def modelmeta
115
+ keys = ::FFI::MemoryPointer.new(:pointer)
116
+ num_keys = ::FFI::MemoryPointer.new(:int64_t)
117
+ description = ::FFI::MemoryPointer.new(:string)
118
+ domain = ::FFI::MemoryPointer.new(:string)
119
+ graph_name = ::FFI::MemoryPointer.new(:string)
120
+ producer_name = ::FFI::MemoryPointer.new(:string)
121
+ version = ::FFI::MemoryPointer.new(:int64_t)
122
+
123
+ metadata = ::FFI::MemoryPointer.new(:pointer)
124
+ check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
125
+
126
+ custom_metadata_map = {}
127
+ check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
128
+ num_keys.read(:int64_t).times do |i|
129
+ key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
130
+ value = ::FFI::MemoryPointer.new(:string)
131
+ check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
132
+ custom_metadata_map[key] = value.read_pointer.read_string
133
+ end
134
+
135
+ check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
136
+ check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
137
+ check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
138
+ check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
139
+ check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
140
+
141
+ {
142
+ custom_metadata_map: custom_metadata_map,
143
+ description: description.read_pointer.read_string,
144
+ domain: domain.read_pointer.read_string,
145
+ graph_name: graph_name.read_pointer.read_string,
146
+ producer_name: producer_name.read_pointer.read_string,
147
+ version: version.read(:int64_t)
148
+ }
149
+ ensure
150
+ release :ModelMetadata, metadata
151
+ end
152
+
153
+ # return value has double underscore like Python
154
+ def end_profiling
155
+ out = ::FFI::MemoryPointer.new(:string)
156
+ check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
157
+ out.read_pointer.read_string
158
+ end
159
+
160
+ # no way to set providers with C API yet
161
+ # so we can return all available providers
162
+ def providers
163
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
164
+ length_ptr = ::FFI::MemoryPointer.new(:int)
165
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
166
+ length = length_ptr.read_int
167
+ providers = []
168
+ length.times do |i|
169
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
170
+ end
171
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
172
+ providers
101
173
  end
102
174
 
103
175
  private
@@ -122,7 +194,7 @@ module OnnxRuntime
122
194
 
123
195
  # TODO support more types
124
196
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
125
- raise "Unknown input: #{input_name}" unless inp
197
+ raise Error, "Unknown input: #{input_name}" unless inp
126
198
 
127
199
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
128
200
  input_node_dims.write_array_of_int64(shape)
@@ -179,7 +251,9 @@ module OnnxRuntime
179
251
 
180
252
  out_size = ::FFI::MemoryPointer.new(:size_t)
181
253
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
182
- output_tensor_size = read_size_t(out_size)
254
+ output_tensor_size = out_size.read(:size_t)
255
+
256
+ release :TensorTypeAndShapeInfo, typeinfo
183
257
 
184
258
  # TODO support more types
185
259
  type = FFI::TensorElementDataType[type]
@@ -198,7 +272,7 @@ module OnnxRuntime
198
272
  out = ::FFI::MemoryPointer.new(:size_t)
199
273
  check_status api[:GetValueCount].call(out_ptr, out)
200
274
 
201
- read_size_t(out).times.map do |i|
275
+ out.read(:size_t).times.map do |i|
202
276
  seq = ::FFI::MemoryPointer.new(:pointer)
203
277
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
204
278
  create_from_onnx_value(seq.read_pointer)
@@ -213,6 +287,7 @@ module OnnxRuntime
213
287
  check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
214
288
  check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
215
289
  check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
290
+ release :TensorTypeAndShapeInfo, type_shape
216
291
 
217
292
  # TODO support more types
218
293
  elem_type = FFI::TensorElementDataType[elem_type.read_int]
@@ -239,9 +314,9 @@ module OnnxRuntime
239
314
 
240
315
  def check_status(status)
241
316
  unless status.null?
242
- message = api[:GetErrorMessage].call(status)
317
+ message = api[:GetErrorMessage].call(status).read_string
243
318
  api[:ReleaseStatus].call(status)
244
- raise OnnxRuntime::Error, message
319
+ raise Error, message
245
320
  end
246
321
  end
247
322
 
@@ -253,6 +328,7 @@ module OnnxRuntime
253
328
  case type
254
329
  when :tensor
255
330
  tensor_info = ::FFI::MemoryPointer.new(:pointer)
331
+ # don't free tensor_info
256
332
  check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
257
333
 
258
334
  type, shape = tensor_type_and_shape(tensor_info)
@@ -261,22 +337,39 @@ module OnnxRuntime
261
337
  shape: shape
262
338
  }
263
339
  when :sequence
264
- # TODO show nested
340
+ sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
341
+ check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
342
+ nested_type_info = ::FFI::MemoryPointer.new(:pointer)
343
+ check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
344
+ v = node_info(nested_type_info)[:type]
345
+
265
346
  {
266
- type: "seq",
347
+ type: "seq(#{v})",
267
348
  shape: []
268
349
  }
269
350
  when :map
270
- # TODO show nested
351
+ map_type_info = ::FFI::MemoryPointer.new(:pointer)
352
+ check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
353
+
354
+ # key
355
+ key_type = ::FFI::MemoryPointer.new(:int)
356
+ check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
357
+ k = FFI::TensorElementDataType[key_type.read_int]
358
+
359
+ # value
360
+ value_type_info = ::FFI::MemoryPointer.new(:pointer)
361
+ check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
362
+ v = node_info(value_type_info)[:type]
363
+
271
364
  {
272
- type: "map",
365
+ type: "map(#{k},#{v})",
273
366
  shape: []
274
367
  }
275
368
  else
276
369
  unsupported_type("ONNX", type)
277
370
  end
278
371
  ensure
279
- api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
372
+ release :TypeInfo, typeinfo
280
373
  end
281
374
 
282
375
  def tensor_type_and_shape(tensor_info)
@@ -285,7 +378,7 @@ module OnnxRuntime
285
378
 
286
379
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
287
380
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
288
- num_dims = read_size_t(num_dims_ptr)
381
+ num_dims = num_dims_ptr.read(:size_t)
289
382
 
290
383
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
291
384
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -294,20 +387,43 @@ module OnnxRuntime
294
387
  end
295
388
 
296
389
  def unsupported_type(name, type)
297
- raise "Unsupported #{name} type: #{type}"
390
+ raise Error, "Unsupported #{name} type: #{type}"
298
391
  end
299
392
 
300
- # read(:size_t) not supported in FFI JRuby
301
- def read_size_t(ptr)
302
- if RUBY_PLATFORM == "java"
303
- ptr.read_long
304
- else
305
- ptr.read(:size_t)
306
- end
393
+ def api
394
+ self.class.api
307
395
  end
308
396
 
309
- def api
310
- @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
397
+ def release(*args)
398
+ self.class.release(*args)
399
+ end
400
+
401
+ def self.api
402
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
403
+ end
404
+
405
+ def self.release(type, pointer)
406
+ api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
407
+ end
408
+
409
+ def self.finalize(session)
410
+ # must use proc instead of stabby lambda
411
+ proc { release :Session, session }
412
+ end
413
+
414
+ # wide string on Windows
415
+ # char string on Linux
416
+ # see ORTCHAR_T in onnxruntime_c_api.h
417
+ def ort_string(str)
418
+ if Gem.win_platform?
419
+ max = str.size + 1 # for null byte
420
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
421
+ ret = FFI::Libc.mbstowcs(dest, str, max)
422
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
423
+ dest
424
+ else
425
+ str
426
+ end
311
427
  end
312
428
 
313
429
  def env
@@ -316,7 +432,7 @@ module OnnxRuntime
316
432
  @@env ||= begin
317
433
  env = ::FFI::MemoryPointer.new(:pointer)
318
434
  check_status api[:CreateEnv].call(3, "Default", env)
319
- at_exit { api[:ReleaseEnv].call(env.read_pointer) }
435
+ at_exit { release :Env, env }
320
436
  # disable telemetry
321
437
  # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
322
438
  check_status api[:DisableTelemetryEvents].call(env)
@@ -22,5 +22,9 @@ module OnnxRuntime
22
22
  def outputs
23
23
  @session.outputs
24
24
  end
25
+
26
+ def metadata
27
+ @session.modelmeta
28
+ end
25
29
  end
26
30
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.2.3"
2
+ VERSION = "0.4.0"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-01-24 00:00:00.000000000 Z
11
+ date: 2020-07-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi