onnxruntime 0.2.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2405ad6cda55897be52cbb6afd3854d2cf2ed5ed01abc09bf78568557af446ff
4
- data.tar.gz: 1e04e516b4726f510bf489108c1f684e398f8d63718c9717106ac9659773d3f0
3
+ metadata.gz: 0b017fa8896c64bbeeda0ba6582ca9bfa626de3f13023ec9f3f91000031e88b8
4
+ data.tar.gz: 16344db9151ca3f388539e05772a7172129613f949dbff77d89bfb9c307d0de2
5
5
  SHA512:
6
- metadata.gz: a2a38af03a82eecc60e28b6fb7b012c28188af69af9a7521cb4b1e4c2519c99c6d328045f0f20fe509787942d304a15e92e43ad94483c55f895cf4bbaf21b53e
7
- data.tar.gz: c782475317186f24c7e34285584ee7761522cc1f792e092c896abe32a041cf593ef4e7f0736a12688ad3e7539101234123d8948ab8ebe0c7986bfb40f33ff236
6
+ metadata.gz: 21011ae8d875230858ff148a27ff29caf005210aae6f6e6a347ed9a02dfb85bceac6714957718d3f17872f07f0ace21c3ba7d7ba8884c1dff4b79fd08b43e40b
7
+ data.tar.gz: 000f4575890f92c2b1de308e052977f159ead243aef3984cd4190a5a873ecbe2868fe586fdbe14f2963f0518a75d8c4600d47ba2fbdd6a77c84cf3da1fc84292
@@ -1,3 +1,31 @@
1
+ ## 0.3.2 (2020-06-16)
2
+
3
+ - Fixed error with FFI 1.13.0+
4
+ - Added friendly graph optimization levels
5
+
6
+ ## 0.3.1 (2020-05-18)
7
+
8
+ - Updated ONNX Runtime to 1.3.0
9
+ - Added `custom_metadata_map` to model metadata
10
+
11
+ ## 0.3.0 (2020-03-11)
12
+
13
+ - Updated ONNX Runtime to 1.2.0
14
+ - Added model metadata
15
+ - Added `end_profiling` method
16
+ - Added support for loading from IO objects
17
+ - Improved `input` and `output` for `seq` and `map` types
18
+
19
+ ## 0.2.3 (2020-01-23)
20
+
21
+ - Updated ONNX Runtime to 1.1.1
22
+
23
+ ## 0.2.2 (2019-12-24)
24
+
25
+ - Added support for session options
26
+ - Added support for run options
27
+ - Added `Datasets` module
28
+
1
29
  ## 0.2.1 (2019-12-19)
2
30
 
3
31
  - Updated ONNX Runtime to 1.1.0
@@ -1,4 +1,5 @@
1
- Copyright (c) 2019 Andrew Kane
1
+ Copyright (c) 2019-2020 Andrew Kane
2
+ Datasets Copyright (c) Microsoft Corporation
2
3
 
3
4
  MIT License
4
5
 
data/README.md CHANGED
@@ -20,7 +20,7 @@ Load a model and make predictions
20
20
 
21
21
  ```ruby
22
22
  model = OnnxRuntime::Model.new("model.onnx")
23
- model.predict(x: [1, 2, 3])
23
+ model.predict({x: [1, 2, 3]})
24
24
  ```
25
25
 
26
26
  > Download pre-trained models from the [ONNX Model Zoo](https://github.com/onnx/models)
@@ -37,10 +37,16 @@ Get outputs
37
37
  model.outputs
38
38
  ```
39
39
 
40
+ Get metadata
41
+
42
+ ```ruby
43
+ model.metadata
44
+ ```
45
+
40
46
  Load a model from a string
41
47
 
42
48
  ```ruby
43
- byte_str = File.binread("model.onnx")
49
+ byte_str = StringIO.new("...")
44
50
  model = OnnxRuntime::Model.new(byte_str)
45
51
  ```
46
52
 
@@ -50,6 +56,35 @@ Get specific outputs
50
56
  model.predict({x: [1, 2, 3]}, output_names: ["label"])
51
57
  ```
52
58
 
59
+ ## Session Options
60
+
61
+ ```ruby
62
+ OnnxRuntime::Model.new(path_or_bytes, {
63
+ enable_cpu_mem_arena: true,
64
+ enable_mem_pattern: true,
65
+ enable_profiling: false,
66
+ execution_mode: :sequential, # :sequential or :parallel
67
+ graph_optimization_level: nil, # :none, :basic, :extended, or :all
68
+ inter_op_num_threads: nil,
69
+ intra_op_num_threads: nil,
70
+ log_severity_level: 2,
71
+ log_verbosity_level: 0,
72
+ logid: nil,
73
+ optimized_model_filepath: nil
74
+ })
75
+ ```
76
+
77
+ ## Run Options
78
+
79
+ ```ruby
80
+ model.predict(input_feed, {
81
+ log_severity_level: 2,
82
+ log_verbosity_level: 0,
83
+ logid: nil,
84
+ terminate: false
85
+ })
86
+ ```
87
+
53
88
  ## Inference Session API
54
89
 
55
90
  You can also use the Inference Session API, which follows the [Python API](https://microsoft.github.io/onnxruntime/python/api_summary.html).
@@ -59,6 +94,20 @@ session = OnnxRuntime::InferenceSession.new("model.onnx")
59
94
  session.run(nil, {x: [1, 2, 3]})
60
95
  ```
61
96
 
97
+ The Python example models are included as well.
98
+
99
+ ```ruby
100
+ OnnxRuntime::Datasets.example("sigmoid.onnx")
101
+ ```
102
+
103
+ ## GPU Support
104
+
105
+ To enable GPU support on Linux and Windows, download the appropriate [GPU release](https://github.com/microsoft/onnxruntime/releases) and set:
106
+
107
+ ```ruby
108
+ OnnxRuntime.ffi_lib = "path/to/lib/libonnxruntime.so" # onnxruntime.dll for Windows
109
+ ```
110
+
62
111
  ## History
63
112
 
64
113
  View the [changelog](https://github.com/ankane/onnxruntime/blob/master/CHANGELOG.md)
@@ -2,6 +2,7 @@
2
2
  require "ffi"
3
3
 
4
4
  # modules
5
+ require "onnxruntime/datasets"
5
6
  require "onnxruntime/inference_session"
6
7
  require "onnxruntime/model"
7
8
  require "onnxruntime/utils"
@@ -18,7 +19,7 @@ module OnnxRuntime
18
19
  self.ffi_lib = [vendor_lib]
19
20
 
20
21
  def self.lib_version
21
- FFI.OrtGetApiBase[:GetVersionString].call
22
+ FFI.OrtGetApiBase[:GetVersionString].call.read_string
22
23
  end
23
24
 
24
25
  # friendlier error message
@@ -0,0 +1,10 @@
1
+ module OnnxRuntime
2
+ module Datasets
3
+ def self.example(name)
4
+ unless %w(logreg_iris.onnx mul_1.onnx sigmoid.onnx).include?(name)
5
+ raise ArgumentError, "Unable to find example '#{name}'"
6
+ end
7
+ File.expand_path("../../datasets/#{name}", __dir__)
8
+ end
9
+ end
10
+ end
@@ -3,7 +3,7 @@ module OnnxRuntime
3
3
  extend ::FFI::Library
4
4
 
5
5
  begin
6
- ffi_lib OnnxRuntime.ffi_lib
6
+ ffi_lib Array(OnnxRuntime.ffi_lib)
7
7
  rescue LoadError => e
8
8
  raise e if ENV["ONNXRUNTIME_DEBUG"]
9
9
  raise LoadError, "Could not find ONNX Runtime"
@@ -20,7 +20,7 @@ module OnnxRuntime
20
20
  layout \
21
21
  :CreateStatus, callback(%i[int string], :pointer),
22
22
  :GetErrorCode, callback(%i[pointer], :pointer),
23
- :GetErrorMessage, callback(%i[pointer], :string),
23
+ :GetErrorMessage, callback(%i[pointer], :pointer),
24
24
  :CreateEnv, callback(%i[int string pointer], :pointer),
25
25
  :CreateEnvWithCustomLogger, callback(%i[], :pointer),
26
26
  :EnableTelemetryEvents, callback(%i[pointer], :pointer),
@@ -29,21 +29,21 @@ module OnnxRuntime
29
29
  :CreateSessionFromArray, callback(%i[pointer pointer size_t pointer pointer], :pointer),
30
30
  :Run, callback(%i[pointer pointer pointer pointer size_t pointer size_t pointer], :pointer),
31
31
  :CreateSessionOptions, callback(%i[pointer], :pointer),
32
- :SetOptimizedModelFilePath, callback(%i[], :pointer),
32
+ :SetOptimizedModelFilePath, callback(%i[pointer string], :pointer),
33
33
  :CloneSessionOptions, callback(%i[], :pointer),
34
34
  :SetSessionExecutionMode, callback(%i[], :pointer),
35
- :EnableProfiling, callback(%i[], :pointer),
36
- :DisableProfiling, callback(%i[], :pointer),
37
- :EnableMemPattern, callback(%i[], :pointer),
38
- :DisableMemPattern, callback(%i[], :pointer),
39
- :EnableCpuMemArena, callback(%i[], :pointer),
40
- :DisableCpuMemArena, callback(%i[], :pointer),
41
- :SetSessionLogId, callback(%i[], :pointer),
42
- :SetSessionLogVerbosityLevel, callback(%i[], :pointer),
43
- :SetSessionLogSeverityLevel, callback(%i[], :pointer),
44
- :SetSessionGraphOptimizationLevel, callback(%i[], :pointer),
45
- :SetIntraOpNumThreads, callback(%i[], :pointer),
46
- :SetInterOpNumThreads, callback(%i[], :pointer),
35
+ :EnableProfiling, callback(%i[pointer string], :pointer),
36
+ :DisableProfiling, callback(%i[pointer], :pointer),
37
+ :EnableMemPattern, callback(%i[pointer], :pointer),
38
+ :DisableMemPattern, callback(%i[pointer], :pointer),
39
+ :EnableCpuMemArena, callback(%i[pointer], :pointer),
40
+ :DisableCpuMemArena, callback(%i[pointer], :pointer),
41
+ :SetSessionLogId, callback(%i[pointer string], :pointer),
42
+ :SetSessionLogVerbosityLevel, callback(%i[pointer int], :pointer),
43
+ :SetSessionLogSeverityLevel, callback(%i[pointer int], :pointer),
44
+ :SetSessionGraphOptimizationLevel, callback(%i[pointer int], :pointer),
45
+ :SetIntraOpNumThreads, callback(%i[pointer int], :pointer),
46
+ :SetInterOpNumThreads, callback(%i[pointer int], :pointer),
47
47
  :CreateCustomOpDomain, callback(%i[], :pointer),
48
48
  :CustomOpDomain_Add, callback(%i[], :pointer),
49
49
  :AddCustomOpDomain, callback(%i[], :pointer),
@@ -57,15 +57,15 @@ module OnnxRuntime
57
57
  :SessionGetInputName, callback(%i[pointer size_t pointer pointer], :pointer),
58
58
  :SessionGetOutputName, callback(%i[pointer size_t pointer pointer], :pointer),
59
59
  :SessionGetOverridableInitializerName, callback(%i[], :pointer),
60
- :CreateRunOptions, callback(%i[], :pointer),
61
- :RunOptionsSetRunLogVerbosityLevel, callback(%i[], :pointer),
62
- :RunOptionsSetRunLogSeverityLevel, callback(%i[], :pointer),
63
- :RunOptionsSetRunTag, callback(%i[], :pointer),
60
+ :CreateRunOptions, callback(%i[pointer], :pointer),
61
+ :RunOptionsSetRunLogVerbosityLevel, callback(%i[pointer int], :pointer),
62
+ :RunOptionsSetRunLogSeverityLevel, callback(%i[pointer int], :pointer),
63
+ :RunOptionsSetRunTag, callback(%i[pointer string], :pointer),
64
64
  :RunOptionsGetRunLogVerbosityLevel, callback(%i[], :pointer),
65
65
  :RunOptionsGetRunLogSeverityLevel, callback(%i[], :pointer),
66
66
  :RunOptionsGetRunTag, callback(%i[], :pointer),
67
- :RunOptionsSetTerminate, callback(%i[], :pointer),
68
- :RunOptionsUnsetTerminate, callback(%i[], :pointer),
67
+ :RunOptionsSetTerminate, callback(%i[pointer], :pointer),
68
+ :RunOptionsUnsetTerminate, callback(%i[pointer], :pointer),
69
69
  :CreateTensorAsOrtValue, callback(%i[pointer pointer size_t int pointer], :pointer),
70
70
  :CreateTensorWithDataAsOrtValue, callback(%i[pointer pointer size_t pointer size_t int pointer], :pointer),
71
71
  :IsTensor, callback(%i[], :pointer),
@@ -119,7 +119,30 @@ module OnnxRuntime
119
119
  :ReleaseTypeInfo, callback(%i[pointer], :void),
120
120
  :ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
121
121
  :ReleaseSessionOptions, callback(%i[pointer], :void),
122
- :ReleaseCustomOpDomain, callback(%i[pointer], :void)
122
+ :ReleaseCustomOpDomain, callback(%i[pointer], :void),
123
+ :GetDenotationFromTypeInfo, callback(%i[], :pointer),
124
+ :CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
125
+ :CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
126
+ :GetMapKeyType, callback(%i[pointer pointer], :pointer),
127
+ :GetMapValueType, callback(%i[pointer pointer], :pointer),
128
+ :GetSequenceElementType, callback(%i[pointer pointer], :pointer),
129
+ :ReleaseMapTypeInfo, callback(%i[pointer], :void),
130
+ :ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
131
+ :SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
132
+ :SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
133
+ :ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
134
+ :ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
135
+ :ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
136
+ :ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
137
+ :ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
138
+ :ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
139
+ :ReleaseModelMetadata, callback(%i[pointer], :void),
140
+ :CreateEnvWithGlobalThreadPools, callback(%i[], :pointer),
141
+ :DisablePerSessionThreads, callback(%i[], :pointer),
142
+ :CreateThreadingOptions, callback(%i[], :pointer),
143
+ :ReleaseThreadingOptions, callback(%i[], :pointer),
144
+ :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer),
145
+ :AddFreeDimensionOverrideByName, callback(%i[], :pointer)
123
146
  end
124
147
 
125
148
  class ApiBase < ::FFI::Struct
@@ -127,7 +150,7 @@ module OnnxRuntime
127
150
  # to prevent "unable to resolve type" error on Ubuntu
128
151
  layout \
129
152
  :GetApi, callback(%i[uint32], Api.by_ref),
130
- :GetVersionString, callback(%i[], :string)
153
+ :GetVersionString, callback(%i[], :pointer)
131
154
  end
132
155
 
133
156
  attach_function :OrtGetApiBase, %i[], ApiBase.by_ref
@@ -2,21 +2,50 @@ module OnnxRuntime
2
2
  class InferenceSession
3
3
  attr_reader :inputs, :outputs
4
4
 
5
- def initialize(path_or_bytes)
5
+ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
6
6
  # session options
7
7
  session_options = ::FFI::MemoryPointer.new(:pointer)
8
8
  check_status api[:CreateSessionOptions].call(session_options)
9
+ check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
10
+ check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
11
+ check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
12
+ if execution_mode
13
+ execution_modes = {sequential: 0, parallel: 1}
14
+ mode = execution_modes[execution_mode]
15
+ raise ArgumentError, "Invalid execution mode" unless mode
16
+ check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
17
+ end
18
+ if graph_optimization_level
19
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
20
+ # TODO raise error in 0.4.0
21
+ level = optimization_levels[graph_optimization_level] || graph_optimization_level
22
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
23
+ end
24
+ check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
25
+ check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
26
+ check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
27
+ check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
28
+ check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
29
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
9
30
 
10
31
  # session
11
32
  @session = ::FFI::MemoryPointer.new(:pointer)
12
- path_or_bytes = path_or_bytes.to_str
33
+ from_memory =
34
+ if path_or_bytes.respond_to?(:read)
35
+ path_or_bytes = path_or_bytes.read
36
+ true
37
+ else
38
+ path_or_bytes = path_or_bytes.to_str
39
+ path_or_bytes.encoding == Encoding::BINARY
40
+ end
13
41
 
14
42
  # fix for Windows "File doesn't exist"
15
- if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
43
+ if Gem.win_platform? && !from_memory
16
44
  path_or_bytes = File.binread(path_or_bytes)
45
+ from_memory = true
17
46
  end
18
47
 
19
- if path_or_bytes.encoding == Encoding::BINARY
48
+ if from_memory
20
49
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
21
50
  else
22
51
  check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
@@ -53,7 +82,8 @@ module OnnxRuntime
53
82
  end
54
83
  end
55
84
 
56
- def run(output_names, input_feed)
85
+ # TODO support logid
86
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
57
87
  input_tensor = create_input_tensor(input_feed)
58
88
 
59
89
  output_names ||= @outputs.map { |v| v[:name] }
@@ -61,14 +91,66 @@ module OnnxRuntime
61
91
  output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
62
92
  input_node_names = create_node_names(input_feed.keys.map(&:to_s))
63
93
  output_node_names = create_node_names(output_names.map(&:to_s))
64
- # TODO support run options
65
- check_status api[:Run].call(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
94
+
95
+ # run options
96
+ run_options = ::FFI::MemoryPointer.new(:pointer)
97
+ check_status api[:CreateRunOptions].call(run_options)
98
+ check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
99
+ check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
100
+ check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
101
+ check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate
102
+
103
+ check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
66
104
 
67
105
  output_names.size.times.map do |i|
68
106
  create_from_onnx_value(output_tensor[i].read_pointer)
69
107
  end
70
108
  end
71
109
 
110
+ def modelmeta
111
+ keys = ::FFI::MemoryPointer.new(:pointer)
112
+ num_keys = ::FFI::MemoryPointer.new(:int64_t)
113
+ description = ::FFI::MemoryPointer.new(:string)
114
+ domain = ::FFI::MemoryPointer.new(:string)
115
+ graph_name = ::FFI::MemoryPointer.new(:string)
116
+ producer_name = ::FFI::MemoryPointer.new(:string)
117
+ version = ::FFI::MemoryPointer.new(:int64_t)
118
+
119
+ metadata = ::FFI::MemoryPointer.new(:pointer)
120
+ check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
121
+
122
+ custom_metadata_map = {}
123
+ check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
124
+ num_keys.read(:int64_t).times do |i|
125
+ key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
126
+ value = ::FFI::MemoryPointer.new(:string)
127
+ check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
128
+ custom_metadata_map[key] = value.read_pointer.read_string
129
+ end
130
+
131
+ check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
132
+ check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
133
+ check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
134
+ check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
135
+ check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
136
+ api[:ReleaseModelMetadata].call(metadata.read_pointer)
137
+
138
+ {
139
+ custom_metadata_map: custom_metadata_map,
140
+ description: description.read_pointer.read_string,
141
+ domain: domain.read_pointer.read_string,
142
+ graph_name: graph_name.read_pointer.read_string,
143
+ producer_name: producer_name.read_pointer.read_string,
144
+ version: version.read(:int64_t)
145
+ }
146
+ end
147
+
148
+ def end_profiling
149
+ out = ::FFI::MemoryPointer.new(:string)
150
+ check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
151
+ out.read_pointer.read_string
152
+ end
153
+
72
154
  private
73
155
 
74
156
  def create_input_tensor(input_feed)
@@ -208,7 +290,7 @@ module OnnxRuntime
208
290
 
209
291
  def check_status(status)
210
292
  unless status.null?
211
- message = api[:GetErrorMessage].call(status)
293
+ message = api[:GetErrorMessage].call(status).read_string
212
294
  api[:ReleaseStatus].call(status)
213
295
  raise OnnxRuntime::Error, message
214
296
  end
@@ -230,15 +312,32 @@ module OnnxRuntime
230
312
  shape: shape
231
313
  }
232
314
  when :sequence
233
- # TODO show nested
315
+ sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
316
+ check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
317
+ nested_type_info = ::FFI::MemoryPointer.new(:pointer)
318
+ check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
319
+ v = node_info(nested_type_info)[:type]
320
+
234
321
  {
235
- type: "seq",
322
+ type: "seq(#{v})",
236
323
  shape: []
237
324
  }
238
325
  when :map
239
- # TODO show nested
326
+ map_type_info = ::FFI::MemoryPointer.new(:pointer)
327
+ check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
328
+
329
+ # key
330
+ key_type = ::FFI::MemoryPointer.new(:int)
331
+ check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
332
+ k = FFI::TensorElementDataType[key_type.read_int]
333
+
334
+ # value
335
+ value_type_info = ::FFI::MemoryPointer.new(:pointer)
336
+ check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
337
+ v = node_info(value_type_info)[:type]
338
+
240
339
  {
241
- type: "map",
340
+ type: "map(#{k},#{v})",
242
341
  shape: []
243
342
  }
244
343
  else
@@ -1,11 +1,11 @@
1
1
  module OnnxRuntime
2
2
  class Model
3
- def initialize(path_or_bytes)
4
- @session = InferenceSession.new(path_or_bytes)
3
+ def initialize(path_or_bytes, **session_options)
4
+ @session = InferenceSession.new(path_or_bytes, **session_options)
5
5
  end
6
6
 
7
- def predict(input_feed, output_names: nil)
8
- predictions = @session.run(output_names, input_feed)
7
+ def predict(input_feed, output_names: nil, **run_options)
8
+ predictions = @session.run(output_names, input_feed, **run_options)
9
9
  output_names ||= outputs.map { |o| o[:name] }
10
10
 
11
11
  result = {}
@@ -22,5 +22,9 @@ module OnnxRuntime
22
22
  def outputs
23
23
  @session.outputs
24
24
  end
25
+
26
+ def metadata
27
+ @session.modelmeta
28
+ end
25
29
  end
26
30
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.2.1"
2
+ VERSION = "0.3.2"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.3.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-12-20 00:00:00.000000000 Z
11
+ date: 2020-06-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi
@@ -76,6 +76,7 @@ files:
76
76
  - LICENSE.txt
77
77
  - README.md
78
78
  - lib/onnxruntime.rb
79
+ - lib/onnxruntime/datasets.rb
79
80
  - lib/onnxruntime/ffi.rb
80
81
  - lib/onnxruntime/inference_session.rb
81
82
  - lib/onnxruntime/model.rb
@@ -105,7 +106,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
105
106
  - !ruby/object:Gem::Version
106
107
  version: '0'
107
108
  requirements: []
108
- rubygems_version: 3.0.3
109
+ rubygems_version: 3.1.2
109
110
  signing_key:
110
111
  specification_version: 4
111
112
  summary: High performance scoring engine for ML models