onnxruntime 0.2.3 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ec7b632678142610395254db0a898c300e23169f822d352e718076959cac20d1
4
- data.tar.gz: bdb36f003d7285fb95b8e962c168dd9d0a639811e118c67fd5d5c29179755d02
3
+ metadata.gz: 54260e1a83f205da2a0a016cc5b6a8508aa8969b484c22f18db966429e6007b0
4
+ data.tar.gz: 44b4310ff48bb154057adf1ddc0622980148cfda47d7b973f01ae58f4e5e7416
5
5
  SHA512:
6
- metadata.gz: 98bc00958bdf3f38a1c33a118850a57edfffc06d7da82ef0d2aa305a625ac1d0aae57b2c5314f10868d04a41c7de8adc981de03f3c8367d1b05c6ad308d12d79
7
- data.tar.gz: 948dec286c981fcd623e3252ef51ccf6b33a275c7019896a68869895643741d5c6f4bb534b505687acc3f717f5da5962f484c24f9ec2938821f003aa2d0e4174
6
+ metadata.gz: d61177039a5314b80342b627d5f7f1f8352be51d95dd71568d9b3ef2a5d970f9b9763023e7ca653f86689aa0ade0188c2d3a28a6f2cd3c890c43f7358952f77a
7
+ data.tar.gz: 7c2a6c4c52e93fc3f3dc4a297b921eb69606b89e1c05d40ac736dd731d8e5ab2a51c5173c2cde7b081ec46f19d88a333adf96c01cef485ba6d73943e9160c640
@@ -1,3 +1,31 @@
1
+ ## 0.4.0 (2020-07-20)
2
+
3
+ - Updated ONNX Runtime to 1.4.0
4
+ - Added `providers` method
5
+ - Fixed errors on Windows
6
+
7
+ ## 0.3.3 (2020-06-17)
8
+
9
+ - Fixed segmentation fault on exit on Linux
10
+
11
+ ## 0.3.2 (2020-06-16)
12
+
13
+ - Fixed error with FFI 1.13.0+
14
+ - Added friendly graph optimization levels
15
+
16
+ ## 0.3.1 (2020-05-18)
17
+
18
+ - Updated ONNX Runtime to 1.3.0
19
+ - Added `custom_metadata_map` to model metadata
20
+
21
+ ## 0.3.0 (2020-03-11)
22
+
23
+ - Updated ONNX Runtime to 1.2.0
24
+ - Added model metadata
25
+ - Added `end_profiling` method
26
+ - Added support for loading from IO objects
27
+ - Improved `input` and `output` for `seq` and `map` types
28
+
1
29
  ## 0.2.3 (2020-01-23)
2
30
 
3
31
  - Updated ONNX Runtime to 1.1.1
data/README.md CHANGED
@@ -37,10 +37,16 @@ Get outputs
37
37
  model.outputs
38
38
  ```
39
39
 
40
+ Get metadata
41
+
42
+ ```ruby
43
+ model.metadata
44
+ ```
45
+
40
46
  Load a model from a string or an `IO` object (such as `StringIO`)
41
47
 
42
48
  ```ruby
43
- byte_str = File.binread("model.onnx")
49
+ byte_str = StringIO.new("...")
44
50
  model = OnnxRuntime::Model.new(byte_str)
45
51
  ```
46
52
 
@@ -57,8 +63,8 @@ OnnxRuntime::Model.new(path_or_bytes, {
57
63
  enable_cpu_mem_arena: true,
58
64
  enable_mem_pattern: true,
59
65
  enable_profiling: false,
60
- execution_mode: :sequential,
61
- graph_optimization_level: nil,
66
+ execution_mode: :sequential, # :sequential or :parallel
67
+ graph_optimization_level: nil, # :none, :basic, :extended, or :all
62
68
  inter_op_num_threads: nil,
63
69
  intra_op_num_threads: nil,
64
70
  log_severity_level: 2,
@@ -19,7 +19,7 @@ module OnnxRuntime
19
19
  self.ffi_lib = [vendor_lib]
20
20
 
21
21
  def self.lib_version
22
- FFI.OrtGetApiBase[:GetVersionString].call
22
+ FFI.OrtGetApiBase[:GetVersionString].call.read_string
23
23
  end
24
24
 
25
25
  # friendlier error message
@@ -2,12 +2,7 @@ module OnnxRuntime
2
2
  module FFI
3
3
  extend ::FFI::Library
4
4
 
5
- begin
6
- ffi_lib Array(OnnxRuntime.ffi_lib)
7
- rescue LoadError => e
8
- raise e if ENV["ONNXRUNTIME_DEBUG"]
9
- raise LoadError, "Could not find ONNX Runtime"
10
- end
5
+ ffi_lib Array(OnnxRuntime.ffi_lib)
11
6
 
12
7
  # https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h
13
8
  # keep same order
@@ -20,19 +15,19 @@ module OnnxRuntime
20
15
  layout \
21
16
  :CreateStatus, callback(%i[int string], :pointer),
22
17
  :GetErrorCode, callback(%i[pointer], :pointer),
23
- :GetErrorMessage, callback(%i[pointer], :string),
18
+ :GetErrorMessage, callback(%i[pointer], :pointer),
24
19
  :CreateEnv, callback(%i[int string pointer], :pointer),
25
20
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
21
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
27
22
  :DisableTelemetryEvents, callback(%i[pointer], :pointer),
28
- :CreateSession, callback(%i[pointer string pointer pointer], :pointer),
23
+ :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer),
29
24
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
25
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
26
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
27
+ :SetOptimizedModelFilePath, callback(%i[pointer pointer], :pointer),
33
28
  :CloneSessionOptions, callback(%i[], :pointer),
34
29
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[pointer string], :pointer),
30
+ :EnableProfiling, callback(%i[pointer pointer], :pointer),
36
31
  :DisableProfiling, callback(%i[pointer], :pointer),
37
32
  :EnableMemPattern, callback(%i[pointer], :pointer),
38
33
  :DisableMemPattern, callback(%i[pointer], :pointer),
@@ -119,7 +114,32 @@ module OnnxRuntime
119
114
  :ReleaseTypeInfo, callback(%i[pointer], :void),
120
115
  :ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
121
116
  :ReleaseSessionOptions, callback(%i[pointer], :void),
122
- :ReleaseCustomOpDomain, callback(%i[pointer], :void)
117
+ :ReleaseCustomOpDomain, callback(%i[pointer], :void),
118
+ :GetDenotationFromTypeInfo, callback(%i[], :pointer),
119
+ :CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
120
+ :CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
121
+ :GetMapKeyType, callback(%i[pointer pointer], :pointer),
122
+ :GetMapValueType, callback(%i[pointer pointer], :pointer),
123
+ :GetSequenceElementType, callback(%i[pointer pointer], :pointer),
124
+ :ReleaseMapTypeInfo, callback(%i[pointer], :void),
125
+ :ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
126
+ :SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
127
+ :SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
128
+ :ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
129
+ :ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
130
+ :ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
131
+ :ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
132
+ :ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
133
+ :ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
134
+ :ReleaseModelMetadata, callback(%i[pointer], :void),
135
+ :CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
136
+ :DisablePerSessionThreads, callback(%i[], :pointer),
137
+ :CreateThreadingOptions, callback(%i[], :pointer),
138
+ :ReleaseThreadingOptions, callback(%i[], :pointer),
139
+ :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
140
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer),
141
+ :GetAvailableProviders, callback(%i[pointer pointer], :pointer),
142
+ :ReleaseAvailableProviders, callback(%i[pointer int], :pointer)
123
143
  end
124
144
 
125
145
  class ApiBase < ::FFI::Struct
@@ -127,9 +147,17 @@ module OnnxRuntime
127
147
  # to prevent "unable to resolve type" error on Ubuntu
128
148
  layout \
129
149
  :GetApi, callback(%i[uint32], Api.by_ref),
130
- :GetVersionString, callback(%i[], :string)
150
+ :GetVersionString, callback(%i[], :pointer)
131
151
  end
132
152
 
133
153
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
154
+
155
+ if Gem.win_platform?
156
+ class Libc
157
+ extend ::FFI::Library
158
+ ffi_lib ::FFI::Library::LIBC
159
+ attach_function :mbstowcs, %i[pointer string size_t], :size_t
160
+ end
161
+ end
134
162
  end
135
163
  end
@@ -8,41 +8,43 @@ module OnnxRuntime
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
9
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
10
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
12
12
  if execution_mode
13
- mode =
14
- case execution_mode
15
- when :sequential
16
- 0
17
- when :parallel
18
- 1
19
- else
20
- raise ArgumentError, "Invalid execution mode"
21
- end
13
+ execution_modes = {sequential: 0, parallel: 1}
14
+ mode = execution_modes[execution_mode]
15
+ raise ArgumentError, "Invalid execution mode" unless mode
22
16
  check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
23
17
  end
24
- check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
18
+ if graph_optimization_level
19
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
+ level = optimization_levels[graph_optimization_level]
21
+ raise ArgumentError, "Invalid graph optimization level" unless level
22
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
+ end
25
24
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
26
25
  check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
27
26
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
28
27
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
29
28
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
30
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
31
30
 
32
31
  # session
33
32
  @session = ::FFI::MemoryPointer.new(:pointer)
34
- path_or_bytes = path_or_bytes.to_str
35
-
36
- # fix for Windows "File doesn't exist"
37
- if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
38
- path_or_bytes = File.binread(path_or_bytes)
39
- end
33
+ from_memory =
34
+ if path_or_bytes.respond_to?(:read)
35
+ path_or_bytes = path_or_bytes.read
36
+ true
37
+ else
38
+ path_or_bytes = path_or_bytes.to_str
39
+ path_or_bytes.encoding == Encoding::BINARY
40
+ end
40
41
 
41
- if path_or_bytes.encoding == Encoding::BINARY
42
+ if from_memory
42
43
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
43
44
  else
44
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
45
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
45
46
  end
47
+ ObjectSpace.define_finalizer(self, self.class.finalize(@session))
46
48
 
47
49
  # input info
48
50
  allocator = ::FFI::MemoryPointer.new(:pointer)
@@ -55,7 +57,7 @@ module OnnxRuntime
55
57
  # input
56
58
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
57
59
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
58
- read_size_t(num_input_nodes).times do |i|
60
+ num_input_nodes.read(:size_t).times do |i|
59
61
  name_ptr = ::FFI::MemoryPointer.new(:string)
60
62
  check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
61
63
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
@@ -66,13 +68,15 @@ module OnnxRuntime
66
68
  # output
67
69
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
68
70
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
69
- read_size_t(num_output_nodes).times do |i|
71
+ num_output_nodes.read(:size_t).times do |i|
70
72
  name_ptr = ::FFI::MemoryPointer.new(:string)
71
73
  check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
72
74
  typeinfo = ::FFI::MemoryPointer.new(:pointer)
73
75
  check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
74
76
  @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
75
77
  end
78
+ ensure
79
+ # release :SessionOptions, session_options
76
80
  end
77
81
 
78
82
  # TODO support logid
@@ -98,6 +102,74 @@ module OnnxRuntime
98
102
  output_names.size.times.map do |i|
99
103
  create_from_onnx_value(output_tensor[i].read_pointer)
100
104
  end
105
+ ensure
106
+ release :RunOptions, run_options
107
+ if input_tensor
108
+ input_feed.size.times do |i|
109
+ release :Value, input_tensor[i]
110
+ end
111
+ end
112
+ end
113
+
114
# Reads this session's model metadata through the ONNX Runtime C API.
#
# Returns a Hash with:
#   :custom_metadata_map - Hash of custom metadata key => value (Strings)
#   :description, :domain, :graph_name, :producer_name - Strings
#   :version - Integer (read as int64)
# Every status-returning call goes through check_status, so any API failure
# raises instead of being ignored. The OrtModelMetadata handle is always released.
def modelmeta
  keys = ::FFI::MemoryPointer.new(:pointer)
  num_keys = ::FFI::MemoryPointer.new(:int64_t)
  description = ::FFI::MemoryPointer.new(:string)
  domain = ::FFI::MemoryPointer.new(:string)
  graph_name = ::FFI::MemoryPointer.new(:string)
  producer_name = ::FFI::MemoryPointer.new(:string)
  version = ::FFI::MemoryPointer.new(:int64_t)

  metadata = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)

  custom_metadata_map = {}
  # The returned status must be passed to check_status (not assigned to a local):
  # otherwise an error here is silently ignored and the OrtStatus is leaked.
  check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
  num_keys.read(:int64_t).times do |i|
    # keys receives an array of char* (presumably allocated via @allocator)
    key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
    value = ::FFI::MemoryPointer.new(:string)
    check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
    custom_metadata_map[key] = value.read_pointer.read_string
  end

  check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
  check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
  check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
  check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
  check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)

  {
    custom_metadata_map: custom_metadata_map,
    description: description.read_pointer.read_string,
    domain: domain.read_pointer.read_string,
    graph_name: graph_name.read_pointer.read_string,
    producer_name: producer_name.read_pointer.read_string,
    version: version.read(:int64_t)
  }
ensure
  release :ModelMetadata, metadata
end
152
+
153
# Stops profiling for this session (turned on by the enable_profiling option)
# and returns the profile file name reported by ONNX Runtime.
# NOTE: the returned name contains a double underscore, as in the Python API.
# This is presumably because the "onnxruntime_profile_" prefix already ends in "_".
def end_profiling
  profile_file = ::FFI::MemoryPointer.new(:string)
  status = api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, profile_file)
  check_status(status)
  profile_file.read_pointer.read_string
end
159
+
160
# The C API offers no way to choose execution providers yet, so this returns
# every provider reported by GetAvailableProviders for the loaded library.
#
# Returns an Array of provider names (Strings).
# Raises OnnxRuntime::Error (via check_status) if listing or releasing fails.
def providers
  out_ptr = ::FFI::MemoryPointer.new(:pointer)
  length_ptr = ::FFI::MemoryPointer.new(:int)
  check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
  length = length_ptr.read_int

  begin
    names = out_ptr.read_pointer
    length.times.map { |i| names[i * ::FFI::Pointer.size].read_pointer.read_string }
  ensure
    # Release the ORT-allocated array even if reading a name fails. Also check
    # the returned status instead of dropping (and leaking) it.
    check_status api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
  end
end
102
174
 
103
175
  private
@@ -122,7 +194,7 @@ module OnnxRuntime
122
194
 
123
195
  # TODO support more types
124
196
  inp = @inputs.find { |i| i[:name] == input_name.to_s }
125
- raise "Unknown input: #{input_name}" unless inp
197
+ raise Error, "Unknown input: #{input_name}" unless inp
126
198
 
127
199
  input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
128
200
  input_node_dims.write_array_of_int64(shape)
@@ -179,7 +251,9 @@ module OnnxRuntime
179
251
 
180
252
  out_size = ::FFI::MemoryPointer.new(:size_t)
181
253
  output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
182
- output_tensor_size = read_size_t(out_size)
254
+ output_tensor_size = out_size.read(:size_t)
255
+
256
+ release :TensorTypeAndShapeInfo, typeinfo
183
257
 
184
258
  # TODO support more types
185
259
  type = FFI::TensorElementDataType[type]
@@ -198,7 +272,7 @@ module OnnxRuntime
198
272
  out = ::FFI::MemoryPointer.new(:size_t)
199
273
  check_status api[:GetValueCount].call(out_ptr, out)
200
274
 
201
- read_size_t(out).times.map do |i|
275
+ out.read(:size_t).times.map do |i|
202
276
  seq = ::FFI::MemoryPointer.new(:pointer)
203
277
  check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
204
278
  create_from_onnx_value(seq.read_pointer)
@@ -213,6 +287,7 @@ module OnnxRuntime
213
287
  check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
214
288
  check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
215
289
  check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
290
+ release :TensorTypeAndShapeInfo, type_shape
216
291
 
217
292
  # TODO support more types
218
293
  elem_type = FFI::TensorElementDataType[elem_type.read_int]
@@ -239,9 +314,9 @@ module OnnxRuntime
239
314
 
240
315
  def check_status(status)
241
316
  unless status.null?
242
- message = api[:GetErrorMessage].call(status)
317
+ message = api[:GetErrorMessage].call(status).read_string
243
318
  api[:ReleaseStatus].call(status)
244
- raise OnnxRuntime::Error, message
319
+ raise Error, message
245
320
  end
246
321
  end
247
322
 
@@ -253,6 +328,7 @@ module OnnxRuntime
253
328
  case type
254
329
  when :tensor
255
330
  tensor_info = ::FFI::MemoryPointer.new(:pointer)
331
+ # don't free tensor_info
256
332
  check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
257
333
 
258
334
  type, shape = tensor_type_and_shape(tensor_info)
@@ -261,22 +337,39 @@ module OnnxRuntime
261
337
  shape: shape
262
338
  }
263
339
  when :sequence
264
- # TODO show nested
340
+ sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
341
+ check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
342
+ nested_type_info = ::FFI::MemoryPointer.new(:pointer)
343
+ check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
344
+ v = node_info(nested_type_info)[:type]
345
+
265
346
  {
266
- type: "seq",
347
+ type: "seq(#{v})",
267
348
  shape: []
268
349
  }
269
350
  when :map
270
- # TODO show nested
351
+ map_type_info = ::FFI::MemoryPointer.new(:pointer)
352
+ check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
353
+
354
+ # key
355
+ key_type = ::FFI::MemoryPointer.new(:int)
356
+ check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
357
+ k = FFI::TensorElementDataType[key_type.read_int]
358
+
359
+ # value
360
+ value_type_info = ::FFI::MemoryPointer.new(:pointer)
361
+ check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
362
+ v = node_info(value_type_info)[:type]
363
+
271
364
  {
272
- type: "map",
365
+ type: "map(#{k},#{v})",
273
366
  shape: []
274
367
  }
275
368
  else
276
369
  unsupported_type("ONNX", type)
277
370
  end
278
371
  ensure
279
- api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
372
+ release :TypeInfo, typeinfo
280
373
  end
281
374
 
282
375
  def tensor_type_and_shape(tensor_info)
@@ -285,7 +378,7 @@ module OnnxRuntime
285
378
 
286
379
  num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
287
380
  check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
288
- num_dims = read_size_t(num_dims_ptr)
381
+ num_dims = num_dims_ptr.read(:size_t)
289
382
 
290
383
  node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
291
384
  check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
@@ -294,20 +387,43 @@ module OnnxRuntime
294
387
  end
295
388
 
296
389
  def unsupported_type(name, type)
297
- raise "Unsupported #{name} type: #{type}"
390
+ raise Error, "Unsupported #{name} type: #{type}"
298
391
  end
299
392
 
300
- # read(:size_t) not supported in FFI JRuby
301
- def read_size_t(ptr)
302
- if RUBY_PLATFORM == "java"
303
- ptr.read_long
304
- else
305
- ptr.read(:size_t)
306
- end
393
# Instance-level shortcut to the class-wide OrtApi function table.
def api
  klass = self.class
  klass.api
end
308
396
 
309
- def api
310
- @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
397
# Instance-level shortcut to the class-level release helper.
def release(*args)
  self.class.release(*args)
end

class << self
  # Lazily fetches and memoizes (per class) the OrtApi function table,
  # requesting version 4 of the C API.
  def api
    @api ||= FFI.OrtGetApiBase[:GetApi].call(4)
  end

  # Calls the C API's Release<type> function on the handle stored in +pointer+.
  # Does nothing when +pointer+ is nil/false or a NULL pointer.
  def release(type, pointer)
    return if !pointer || pointer.null?

    api[:"Release#{type}"].call(pointer.read_pointer)
  end

  # Builds the ObjectSpace finalizer that releases a session handle.
  def finalize(session)
    # ObjectSpace passes the object id to the finalizer. A plain proc tolerates
    # that extra argument; a zero-arity stabby lambda would raise ArgumentError.
    proc { release :Session, session }
  end
end
413
+
414
# Converts +str+ into the ORTCHAR_T representation ONNX Runtime expects
# (see ORTCHAR_T in onnxruntime_c_api.h):
# - on Windows: a NUL-terminated wide (wchar_t, UTF-16LE) string in FFI memory
# - elsewhere: the Ruby String itself, which FFI passes as a char*
def ort_string(str)
  return str unless Gem.win_platform?

  # Encode explicitly to UTF-16LE instead of calling mbstowcs. The mbstowcs
  # result depends on the C runtime locale and, for multibyte (non-ASCII)
  # input, can differ from str.size, which made such paths raise.
  wide = str.encode(Encoding::UTF_16LE)
  # MemoryPointer zero-fills its memory, so the extra unit is the NUL terminator
  dest = ::FFI::MemoryPointer.new(:uint16, wide.bytesize / 2 + 1)
  dest.put_bytes(0, wide)
  dest
end
312
428
 
313
429
  def env
@@ -316,7 +432,7 @@ module OnnxRuntime
316
432
  @@env ||= begin
317
433
  env = ::FFI::MemoryPointer.new(:pointer)
318
434
  check_status api[:CreateEnv].call(3, "Default", env)
319
- at_exit { api[:ReleaseEnv].call(env.read_pointer) }
435
+ at_exit { release :Env, env }
320
436
  # disable telemetry
321
437
  # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
322
438
  check_status api[:DisableTelemetryEvents].call(env)
@@ -22,5 +22,9 @@ module OnnxRuntime
22
22
  def outputs
23
23
  @session.outputs
24
24
  end
25
+
26
# Returns the model metadata Hash built by the underlying session's
# modelmeta (custom metadata map, description, domain, graph name,
# producer name, version).
def metadata
  session = @session
  session.modelmeta
end
25
29
  end
26
30
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.2.3"
2
+ VERSION = "0.4.0"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-01-24 00:00:00.000000000 Z
11
+ date: 2020-07-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi