onnxruntime 0.2.1 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2405ad6cda55897be52cbb6afd3854d2cf2ed5ed01abc09bf78568557af446ff
4
- data.tar.gz: 1e04e516b4726f510bf489108c1f684e398f8d63718c9717106ac9659773d3f0
3
+ metadata.gz: 0b017fa8896c64bbeeda0ba6582ca9bfa626de3f13023ec9f3f91000031e88b8
4
+ data.tar.gz: 16344db9151ca3f388539e05772a7172129613f949dbff77d89bfb9c307d0de2
5
5
  SHA512:
6
- metadata.gz: a2a38af03a82eecc60e28b6fb7b012c28188af69af9a7521cb4b1e4c2519c99c6d328045f0f20fe509787942d304a15e92e43ad94483c55f895cf4bbaf21b53e
7
- data.tar.gz: c782475317186f24c7e34285584ee7761522cc1f792e092c896abe32a041cf593ef4e7f0736a12688ad3e7539101234123d8948ab8ebe0c7986bfb40f33ff236
6
+ metadata.gz: 21011ae8d875230858ff148a27ff29caf005210aae6f6e6a347ed9a02dfb85bceac6714957718d3f17872f07f0ace21c3ba7d7ba8884c1dff4b79fd08b43e40b
7
+ data.tar.gz: 000f4575890f92c2b1de308e052977f159ead243aef3984cd4190a5a873ecbe2868fe586fdbe14f2963f0518a75d8c4600d47ba2fbdd6a77c84cf3da1fc84292
@@ -1,3 +1,31 @@
1
+ ## 0.3.2 (2020-06-16)
2
+
3
+ - Fixed error with FFI 1.13.0+
4
+ - Added friendly graph optimization levels
5
+
6
+ ## 0.3.1 (2020-05-18)
7
+
8
+ - Updated ONNX Runtime to 1.3.0
9
+ - Added `custom_metadata_map` to model metadata
10
+
11
+ ## 0.3.0 (2020-03-11)
12
+
13
+ - Updated ONNX Runtime to 1.2.0
14
+ - Added model metadata
15
+ - Added `end_profiling` method
16
+ - Added support for loading from IO objects
17
+ - Improved `input` and `output` for `seq` and `map` types
18
+
19
+ ## 0.2.3 (2020-01-23)
20
+
21
+ - Updated ONNX Runtime to 1.1.1
22
+
23
+ ## 0.2.2 (2019-12-24)
24
+
25
+ - Added support for session options
26
+ - Added support for run options
27
+ - Added `Datasets` module
28
+
1
29
  ## 0.2.1 (2019-12-19)
2
30
 
3
31
  - Updated ONNX Runtime to 1.1.0
@@ -1,4 +1,5 @@
1
- Copyright (c) 2019 Andrew Kane
1
+ Copyright (c) 2019-2020 Andrew Kane
2
+ Datasets Copyright (c) Microsoft Corporation
2
3
 
3
4
  MIT License
4
5
 
data/README.md CHANGED
@@ -20,7 +20,7 @@ Load a model and make predictions
20
20
 
21
21
  ```ruby
22
22
  model = OnnxRuntime::Model.new("model.onnx")
23
- model.predict(x: [1, 2, 3])
23
+ model.predict({x: [1, 2, 3]})
24
24
  ```
25
25
 
26
26
  > Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
@@ -37,10 +37,16 @@ Get outputs
37
37
  model.outputs
38
38
  ```
39
39
 
40
+ Get metadata
41
+
42
+ ```ruby
43
+ model.metadata
44
+ ```
45
+
40
46
  Load a model from a string
41
47
 
42
48
  ```ruby
43
- byte_str = File.binread("model.onnx")
49
+ byte_str = StringIO.new("...")
44
50
  model = OnnxRuntime::Model.new(byte_str)
45
51
  ```
46
52
 
@@ -50,6 +56,35 @@ Get specific outputs
50
56
  model.predict({x: [1, 2, 3]}, output_names: ["label"])
51
57
  ```
52
58
 
59
+ ## Session Options
60
+
61
+ ```ruby
62
+ OnnxRuntime::Model.new(path_or_bytes, {
63
+ enable_cpu_mem_arena: true,
64
+ enable_mem_pattern: true,
65
+ enable_profiling: false,
66
+ execution_mode: :sequential, # :sequential or :parallel
67
+ graph_optimization_level: nil, # :none, :basic, :extended, or :all
68
+ inter_op_num_threads: nil,
69
+ intra_op_num_threads: nil,
70
+ log_severity_level: 2,
71
+ log_verbosity_level: 0,
72
+ logid: nil,
73
+ optimized_model_filepath: nil
74
+ })
75
+ ```
76
+
77
+ ## Run Options
78
+
79
+ ```ruby
80
+ model.predict(input_feed, {
81
+ log_severity_level: 2,
82
+ log_verbosity_level: 0,
83
+ logid: nil,
84
+ terminate: false
85
+ })
86
+ ```
87
+
53
88
  ## Inference Session API
54
89
 
55
90
  You can also use the Inference Session API, which follows the [Python API](https://microsoft.github.io/onnxruntime/python/api_summary.html).
@@ -59,6 +94,20 @@ session = OnnxRuntime::InferenceSession.new("model.onnx")
59
94
  session.run(nil, {x: [1, 2, 3]})
60
95
  ```
61
96
 
97
+ The Python example models are included as well.
98
+
99
+ ```ruby
100
+ OnnxRuntime::Datasets.example("sigmoid.onnx")
101
+ ```
102
+
103
+ ## GPU Support
104
+
105
+ To enable GPU support on Linux and Windows, download the appropriate [GPU release](https://github.com/microsoft/onnxruntime/releases) and set:
106
+
107
+ ```ruby
108
+ OnnxRuntime.ffi_lib = "path/to/lib/libonnxruntime.so" # onnxruntime.dll for Windows
109
+ ```
110
+
62
111
  ## History
63
112
 
64
113
  View the [changelog](https://github.com/ankane/onnxruntime/blob/master/CHANGELOG.md)
@@ -2,6 +2,7 @@
2
2
  require "ffi"
3
3
 
4
4
  # modules
5
+ require "onnxruntime/datasets"
5
6
  require "onnxruntime/inference_session"
6
7
  require "onnxruntime/model"
7
8
  require "onnxruntime/utils"
@@ -18,7 +19,7 @@ module OnnxRuntime
18
19
  self.ffi_lib = [vendor_lib]
19
20
 
20
21
  def self.lib_version
21
- FFI.OrtGetApiBase[:GetVersionString].call
22
+ FFI.OrtGetApiBase[:GetVersionString].call.read_string
22
23
  end
23
24
 
24
25
  # friendlier error message
@@ -0,0 +1,10 @@
1
+ module OnnxRuntime
2
+ module Datasets
3
+ def self.example(name)
4
+ unless %w(logreg_iris.onnx mul_1.onnx sigmoid.onnx).include?(name)
5
+ raise ArgumentError, "Unable to find example '#{name}'"
6
+ end
7
+ File.expand_path("../../datasets/#{name}", __dir__)
8
+ end
9
+ end
10
+ end
@@ -3,7 +3,7 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib OnnxRuntime.ffi_lib
6
+ ffi_lib Array(OnnxRuntime.ffi_lib)
7
7
  rescue LoadError => e
8
8
  raise e if ENV["ONNXRUNTIME_DEBUG"]
9
9
  raise LoadError, "Could not find ONNX Runtime"
@@ -20,7 +20,7 @@ module OnnxRuntime
20
20
  layout \
21
21
  :CreateStatus, callback(%i[int string], :pointer),
22
22
  :GetErrorCode, callback(%i[pointer], :pointer),
23
- :GetErrorMessage, callback(%i[pointer], :string),
23
+ :GetErrorMessage, callback(%i[pointer], :pointer),
24
24
  :CreateEnv, callback(%i[int string pointer], :pointer),
25
25
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
26
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
@@ -29,21 +29,21 @@ module OnnxRuntime
29
29
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
30
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
31
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[], :pointer),
32
+ :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
33
33
  :CloneSessionOptions, callback(%i[], :pointer),
34
34
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[], :pointer),
36
- :DisableProfiling, callback(%i[], :pointer),
37
- :EnableMemPattern, callback(%i[], :pointer),
38
- :DisableMemPattern, callback(%i[], :pointer),
39
- :EnableCpuMemArena, callback(%i[], :pointer),
40
- :DisableCpuMemArena, callback(%i[], :pointer),
41
- :SetSessionLogId, callback(%i[], :pointer),
42
- :SetSessionLogVerbosityLevel, callback(%i[], :pointer),
43
- :SetSessionLogSeverityLevel, callback(%i[], :pointer),
44
- :SetSessionGraphOptimizationLevel, callback(%i[], :pointer),
45
- :SetIntraOpNumThreads, callback(%i[], :pointer),
46
- :SetInterOpNumThreads, callback(%i[], :pointer),
35
+ :EnableProfiling, callback(%i[pointer string], :pointer),
36
+ :DisableProfiling, callback(%i[pointer], :pointer),
37
+ :EnableMemPattern, callback(%i[pointer], :pointer),
38
+ :DisableMemPattern, callback(%i[pointer], :pointer),
39
+ :EnableCpuMemArena, callback(%i[pointer], :pointer),
40
+ :DisableCpuMemArena, callback(%i[pointer], :pointer),
41
+ :SetSessionLogId, callback(%i[pointer string], :pointer),
42
+ :SetSessionLogVerbosityLevel, callback(%i[pointer int], :pointer),
43
+ :SetSessionLogSeverityLevel, callback(%i[pointer int], :pointer),
44
+ :SetSessionGraphOptimizationLevel, callback(%i[pointer int], :pointer),
45
+ :SetIntraOpNumThreads, callback(%i[pointer int], :pointer),
46
+ :SetInterOpNumThreads, callback(%i[pointer int], :pointer),
47
47
  :CreateCustomOpDomain, callback(%i[], :pointer),
48
48
  :CustomOpDomain_Add, callback(%i[], :pointer),
49
49
  :AddCustomOpDomain, callback(%i[], :pointer),
@@ -57,15 +57,15 @@ module OnnxRuntime
57
57
  :SessionGetInputName, callback(%i[pointer size_t pointer pointer], :pointer),
58
58
  :SessionGetOutputName, callback(%i[pointer size_t pointer pointer], :pointer),
59
59
  :SessionGetOverridableInitializerName, callback(%i[], :pointer),
60
- :CreateRunOptions, callback(%i[], :pointer),
61
- :RunOptionsSetRunLogVerbosityLevel, callback(%i[], :pointer),
62
- :RunOptionsSetRunLogSeverityLevel, callback(%i[], :pointer),
63
- :RunOptionsSetRunTag, callback(%i[], :pointer),
60
+ :CreateRunOptions, callback(%i[pointer], :pointer),
61
+ :RunOptionsSetRunLogVerbosityLevel, callback(%i[pointer int], :pointer),
62
+ :RunOptionsSetRunLogSeverityLevel, callback(%i[pointer int], :pointer),
63
+ :RunOptionsSetRunTag, callback(%i[pointer string], :pointer),
64
64
  :RunOptionsGetRunLogVerbosityLevel, callback(%i[], :pointer),
65
65
  :RunOptionsGetRunLogSeverityLevel, callback(%i[], :pointer),
66
66
  :RunOptionsGetRunTag, callback(%i[], :pointer),
67
- :RunOptionsSetTerminate, callback(%i[], :pointer),
68
- :RunOptionsUnsetTerminate, callback(%i[], :pointer),
67
+ :RunOptionsSetTerminate, callback(%i[pointer], :pointer),
68
+ :RunOptionsUnsetTerminate, callback(%i[pointer], :pointer),
69
69
  :CreateTensorAsOrtValue, callback(%i[pointer pointer size_t int pointer], :pointer),
70
70
  :CreateTensorWithDataAsOrtValue, callback(%i[pointer pointer size_t pointer size_t int pointer], :pointer),
71
71
  :IsTensor, callback(%i[], :pointer),
@@ -119,7 +119,30 @@ module OnnxRuntime
119
119
  :ReleaseTypeInfo, callback(%i[pointer], :void),
120
120
  :ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
121
121
  :ReleaseSessionOptions, callback(%i[pointer], :void),
122
- :ReleaseCustomOpDomain, callback(%i[pointer], :void)
122
+ :ReleaseCustomOpDomain, callback(%i[pointer], :void),
123
+ :GetDenotationFromTypeInfo, callback(%i[], :pointer),
124
+ :CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
125
+ :CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
126
+ :GetMapKeyType, callback(%i[pointer pointer], :pointer),
127
+ :GetMapValueType, callback(%i[pointer pointer], :pointer),
128
+ :GetSequenceElementType, callback(%i[pointer pointer], :pointer),
129
+ :ReleaseMapTypeInfo, callback(%i[pointer], :void),
130
+ :ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
131
+ :SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
132
+ :SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
133
+ :ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
134
+ :ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
135
+ :ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
136
+ :ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
137
+ :ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
138
+ :ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
139
+ :ReleaseModelMetadata, callback(%i[pointer], :void),
140
+ :CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
141
+ :DisablePerSessionThreads, callback(%i[], :pointer),
142
+ :CreateThreadingOptions, callback(%i[], :pointer),
143
+ :ReleaseThreadingOptions, callback(%i[], :pointer),
144
+ :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
123
146
  end
124
147
 
125
148
  class ApiBase < ::FFI::Struct
@@ -127,7 +150,7 @@ module OnnxRuntime
127
150
  # to prevent "unable to resolve type" error on Ubuntu
128
151
  layout \
129
152
  :GetApi, callback(%i[uint32], Api.by_ref),
130
- :GetVersionString, callback(%i[], :string)
153
+ :GetVersionString, callback(%i[], :pointer)
131
154
  end
132
155
 
133
156
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
@@ -2,21 +2,50 @@ module OnnxRuntime
2
2
  class InferenceSession
3
3
  attr_reader :inputs, :outputs
4
4
 
5
- def initialize(path_or_bytes)
5
+ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
6
6
  # session options
7
7
  session_options = ::FFI::MemoryPointer.new(:pointer)
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
+ check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
+ check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
12
+ if execution_mode
13
+ execution_modes = {sequential: 0, parallel: 1}
14
+ mode = execution_modes[execution_mode]
15
+ raise ArgumentError, "Invalid execution mode" unless mode
16
+ check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
17
+ end
18
+ if graph_optimization_level
19
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
+ # TODO raise error in 0.4.0
21
+ level = optimization_levels[graph_optimization_level] || graph_optimization_level
22
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
+ end
24
+ check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
25
+ check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
26
+ check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
+ check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
+ check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
9
30
 
10
31
  # session
11
32
  @session = ::FFI::MemoryPointer.new(:pointer)
12
- path_or_bytes = path_or_bytes.to_str
33
+ from_memory =
34
+ if path_or_bytes.respond_to?(:read)
35
+ path_or_bytes = path_or_bytes.read
36
+ true
37
+ else
38
+ path_or_bytes = path_or_bytes.to_str
39
+ path_or_bytes.encoding == Encoding::BINARY
40
+ end
13
41
 
14
42
  # fix for Windows "File doesn't exist"
15
- if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
43
+ if Gem.win_platform? && !from_memory
16
44
  path_or_bytes = File.binread(path_or_bytes)
45
+ from_memory = true
17
46
  end
18
47
 
19
- if path_or_bytes.encoding == Encoding::BINARY
48
+ if from_memory
20
49
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
21
50
  else
22
51
  check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
@@ -53,7 +82,8 @@ module OnnxRuntime
53
82
  end
54
83
  end
55
84
 
56
- def run(output_names, input_feed)
85
+ # TODO support logid
86
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
57
87
  input_tensor = create_input_tensor(input_feed)
58
88
 
59
89
  output_names ||= @outputs.map { |v| v[:name] }
@@ -61,14 +91,66 @@ module OnnxRuntime
61
91
  output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
62
92
  input_node_names = create_node_names(input_feed.keys.map(&:to_s))
63
93
  output_node_names = create_node_names(output_names.map(&:to_s))
64
- # TODO support run options
65
- check_status api[:Run].call(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
94
+
95
+ # run options
96
+ run_options = ::FFI::MemoryPointer.new(:pointer)
97
+ check_status api[:CreateRunOptions].call(run_options)
98
+ check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
99
+ check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
100
+ check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
101
+ check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate
102
+
103
+ check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
66
104
 
67
105
  output_names.size.times.map do |i|
68
106
  create_from_onnx_value(output_tensor[i].read_pointer)
69
107
  end
70
108
  end
71
109
 
110
+ def modelmeta
111
+ keys = ::FFI::MemoryPointer.new(:pointer)
112
+ num_keys = ::FFI::MemoryPointer.new(:int64_t)
113
+ description = ::FFI::MemoryPointer.new(:string)
114
+ domain = ::FFI::MemoryPointer.new(:string)
115
+ graph_name = ::FFI::MemoryPointer.new(:string)
116
+ producer_name = ::FFI::MemoryPointer.new(:string)
117
+ version = ::FFI::MemoryPointer.new(:int64_t)
118
+
119
+ metadata = ::FFI::MemoryPointer.new(:pointer)
120
+ check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
121
+
122
+ custom_metadata_map = {}
123
+ check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
124
+ num_keys.read(:int64_t).times do |i|
125
+ key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
126
+ value = ::FFI::MemoryPointer.new(:string)
127
+ check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
128
+ custom_metadata_map[key] = value.read_pointer.read_string
129
+ end
130
+
131
+ check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
132
+ check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
133
+ check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
134
+ check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
135
+ check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
136
+ api[:ReleaseModelMetadata].call(metadata.read_pointer)
137
+
138
+ {
139
+ custom_metadata_map: custom_metadata_map,
140
+ description: description.read_pointer.read_string,
141
+ domain: domain.read_pointer.read_string,
142
+ graph_name: graph_name.read_pointer.read_string,
143
+ producer_name: producer_name.read_pointer.read_string,
144
+ version: version.read(:int64_t)
145
+ }
146
+ end
147
+
148
+ def end_profiling
149
+ out = ::FFI::MemoryPointer.new(:string)
150
+ check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
151
+ out.read_pointer.read_string
152
+ end
153
+
72
154
  private
73
155
 
74
156
  def create_input_tensor(input_feed)
@@ -208,7 +290,7 @@ module OnnxRuntime
208
290
 
209
291
  def check_status(status)
210
292
  unless status.null?
211
- message = api[:GetErrorMessage].call(status)
293
+ message = api[:GetErrorMessage].call(status).read_string
212
294
  api[:ReleaseStatus].call(status)
213
295
  raise OnnxRuntime::Error, message
214
296
  end
@@ -230,15 +312,32 @@ module OnnxRuntime
230
312
  shape: shape
231
313
  }
232
314
  when :sequence
233
- # TODO show nested
315
+ sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
316
+ check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
317
+ nested_type_info = ::FFI::MemoryPointer.new(:pointer)
318
+ check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
319
+ v = node_info(nested_type_info)[:type]
320
+
234
321
  {
235
- type: "seq",
322
+ type: "seq(#{v})",
236
323
  shape: []
237
324
  }
238
325
  when :map
239
- # TODO show nested
326
+ map_type_info = ::FFI::MemoryPointer.new(:pointer)
327
+ check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
328
+
329
+ # key
330
+ key_type = ::FFI::MemoryPointer.new(:int)
331
+ check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
332
+ k = FFI::TensorElementDataType[key_type.read_int]
333
+
334
+ # value
335
+ value_type_info = ::FFI::MemoryPointer.new(:pointer)
336
+ check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
337
+ v = node_info(value_type_info)[:type]
338
+
240
339
  {
241
- type: "map",
340
+ type: "map(#{k},#{v})",
242
341
  shape: []
243
342
  }
244
343
  else
@@ -1,11 +1,11 @@
1
1
  module OnnxRuntime
2
2
  class Model
3
- def initialize(path_or_bytes)
4
- @session = InferenceSession.new(path_or_bytes)
3
+ def initialize(path_or_bytes, **session_options)
4
+ @session = InferenceSession.new(path_or_bytes, **session_options)
5
5
  end
6
6
 
7
- def predict(input_feed, output_names: nil)
8
- predictions = @session.run(output_names, input_feed)
7
+ def predict(input_feed, output_names: nil, **run_options)
8
+ predictions = @session.run(output_names, input_feed, **run_options)
9
9
  output_names ||= outputs.map { |o| o[:name] }
10
10
 
11
11
  result = {}
@@ -22,5 +22,9 @@ module OnnxRuntime
22
22
  def outputs
23
23
  @session.outputs
24
24
  end
25
+
26
+ def metadata
27
+ @session.modelmeta
28
+ end
25
29
  end
26
30
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.2.1"
2
+ VERSION = "0.3.2"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.3.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-12-20 00:00:00.000000000 Z
11
+ date: 2020-06-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi
@@ -76,6 +76,7 @@ files:
76
76
  - LICENSE.txt
77
77
  - README.md
78
78
  - lib/onnxruntime.rb
79
+ - lib/onnxruntime/datasets.rb
79
80
  - lib/onnxruntime/ffi.rb
80
81
  - lib/onnxruntime/inference_session.rb
81
82
  - lib/onnxruntime/model.rb
@@ -105,7 +106,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
105
106
  - !ruby/object:Gem::Version
106
107
  version: '0'
107
108
  requirements: []
108
- rubygems_version: 3.0.3
109
+ rubygems_version: 3.1.2
109
110
  signing_key:
110
111
  specification_version: 4
111
112
  summary: High performance scoring engine for ML models