onnxruntime 0.2.3 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ec7b632678142610395254db0a898c300e23169f822d352e718076959cac20d1
4
- data.tar.gz: bdb36f003d7285fb95b8e962c168dd9d0a639811e118c67fd5d5c29179755d02
3
+ metadata.gz: 1a31036492f6ad4f0ab17a2534a69f17ca05058267447089868f2ac5bf2c5e79
4
+ data.tar.gz: 6cd8be229992a0ea4c9fdb8529a229e31a777321e0971f0123bc2680ca828be3
5
5
  SHA512:
6
- metadata.gz: 98bc00958bdf3f38a1c33a118850a57edfffc06d7da82ef0d2aa305a625ac1d0aae57b2c5314f10868d04a41c7de8adc981de03f3c8367d1b05c6ad308d12d79
7
- data.tar.gz: 948dec286c981fcd623e3252ef51ccf6b33a275c7019896a68869895643741d5c6f4bb534b505687acc3f717f5da5962f484c24f9ec2938821f003aa2d0e4174
6
+ metadata.gz: 9bb61d56c9c4ddb17e0e085381226ab0c91e72de592deb6519bf6840ffb1a00eaa0fb9382eb0cba1d563165fb9d0528034da1f2ea9b6b3cb76b780813750ad21
7
+ data.tar.gz: b94b0b2772728ead241a0d16c3ca2db099cbf7ab6f3994384e3107f69d9aa2cdaa31ed4ba08bd4d3d9fae2eef0b13ea84aa70823be0ff2d24b921e8fb390246b
@@ -1,3 +1,11 @@
1
+ ## 0.3.0 (2020-03-11)
2
+
3
+ - Updated ONNX Runtime to 1.2.0
4
+ - Added model metadata
5
+ - Added `end_profiling` method
6
+ - Added support for loading from IO objects
7
+ - Improved `input` and `output` for `seq` and `map` types
8
+
1
9
  ## 0.2.3 (2020-01-23)
2
10
 
3
11
  - Updated ONNX Runtime to 1.1.1
data/README.md CHANGED
@@ -37,10 +37,16 @@ Get outputs
37
37
  model.outputs
38
38
  ```
39
39
 
40
+ Get metadata
41
+
42
+ ```ruby
43
+ model.metadata
44
+ ```
45
+
40
46
  Load a model from a string
41
47
 
42
48
  ```ruby
43
- byte_str = File.binread("model.onnx")
49
+ byte_str = StringIO.new("...")
44
50
  model = OnnxRuntime::Model.new(byte_str)
45
51
  ```
46
52
 
@@ -119,7 +119,24 @@ module OnnxRuntime
119
119
  :ReleaseTypeInfo, callback(%i[pointer], :void),
120
120
  :ReleaseTensorTypeAndShapeInfo, callback(%i[pointer], :void),
121
121
  :ReleaseSessionOptions, callback(%i[pointer], :void),
122
- :ReleaseCustomOpDomain, callback(%i[pointer], :void)
122
+ :ReleaseCustomOpDomain, callback(%i[pointer], :void),
123
+ :GetDenotationFromTypeInfo, callback(%i[], :pointer),
124
+ :CastTypeInfoToMapTypeInfo, callback(%i[pointer pointer], :pointer),
125
+ :CastTypeInfoToSequenceTypeInfo, callback(%i[pointer pointer], :pointer),
126
+ :GetMapKeyType, callback(%i[pointer pointer], :pointer),
127
+ :GetMapValueType, callback(%i[pointer pointer], :pointer),
128
+ :GetSequenceElementType, callback(%i[pointer pointer], :pointer),
129
+ :ReleaseMapTypeInfo, callback(%i[pointer], :void),
130
+ :ReleaseSequenceTypeInfo, callback(%i[pointer], :void),
131
+ :SessionEndProfiling, callback(%i[pointer pointer pointer], :pointer),
132
+ :SessionGetModelMetadata, callback(%i[pointer pointer], :pointer),
133
+ :ModelMetadataGetProducerName, callback(%i[pointer pointer pointer], :pointer),
134
+ :ModelMetadataGetGraphName, callback(%i[pointer pointer pointer], :pointer),
135
+ :ModelMetadataGetDomain, callback(%i[pointer pointer pointer], :pointer),
136
+ :ModelMetadataGetDescription, callback(%i[pointer pointer pointer], :pointer),
137
+ :ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer),
138
+ :ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer),
139
+ :ReleaseModelMetadata, callback(%i[pointer], :void)
123
140
  end
124
141
 
125
142
  class ApiBase < ::FFI::Struct
@@ -31,14 +31,22 @@ module OnnxRuntime
31
31
 
32
32
  # session
33
33
  @session = ::FFI::MemoryPointer.new(:pointer)
34
- path_or_bytes = path_or_bytes.to_str
34
+ from_memory =
35
+ if path_or_bytes.respond_to?(:read)
36
+ path_or_bytes = path_or_bytes.read
37
+ true
38
+ else
39
+ path_or_bytes = path_or_bytes.to_str
40
+ path_or_bytes.encoding == Encoding::BINARY
41
+ end
35
42
 
36
43
  # fix for Windows "File doesn't exist"
37
- if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
44
+ if Gem.win_platform? && !from_memory
38
45
  path_or_bytes = File.binread(path_or_bytes)
46
+ from_memory = true
39
47
  end
40
48
 
41
- if path_or_bytes.encoding == Encoding::BINARY
49
+ if from_memory
42
50
  check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
43
51
  else
44
52
  check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
@@ -100,6 +108,40 @@ module OnnxRuntime
100
108
  end
101
109
  end
102
110
 
111
+ def modelmeta
112
+ description = ::FFI::MemoryPointer.new(:string)
113
+ domain = ::FFI::MemoryPointer.new(:string)
114
+ graph_name = ::FFI::MemoryPointer.new(:string)
115
+ producer_name = ::FFI::MemoryPointer.new(:string)
116
+ version = ::FFI::MemoryPointer.new(:int64_t)
117
+
118
+ metadata = ::FFI::MemoryPointer.new(:pointer)
119
+ check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
120
+ check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
121
+ check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
122
+ check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
123
+ check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
124
+ check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
125
+ api[:ReleaseModelMetadata].call(metadata.read_pointer)
126
+
127
+ # TODO add custom_metadata_map
128
+ # need a way to get keys
129
+
130
+ {
131
+ description: description.read_pointer.read_string,
132
+ domain: domain.read_pointer.read_string,
133
+ graph_name: graph_name.read_pointer.read_string,
134
+ producer_name: producer_name.read_pointer.read_string,
135
+ version: version.read(:int64_t)
136
+ }
137
+ end
138
+
139
+ def end_profiling
140
+ out = ::FFI::MemoryPointer.new(:string)
141
+ check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
142
+ out.read_pointer.read_string
143
+ end
144
+
103
145
  private
104
146
 
105
147
  def create_input_tensor(input_feed)
@@ -261,15 +303,32 @@ module OnnxRuntime
261
303
  shape: shape
262
304
  }
263
305
  when :sequence
264
- # TODO show nested
306
+ sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
307
+ check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
308
+ nested_type_info = ::FFI::MemoryPointer.new(:pointer)
309
+ check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
310
+ v = node_info(nested_type_info)[:type]
311
+
265
312
  {
266
- type: "seq",
313
+ type: "seq(#{v})",
267
314
  shape: []
268
315
  }
269
316
  when :map
270
- # TODO show nested
317
+ map_type_info = ::FFI::MemoryPointer.new(:pointer)
318
+ check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
319
+
320
+ # key
321
+ key_type = ::FFI::MemoryPointer.new(:int)
322
+ check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
323
+ k = FFI::TensorElementDataType[key_type.read_int]
324
+
325
+ # value
326
+ value_type_info = ::FFI::MemoryPointer.new(:pointer)
327
+ check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
328
+ v = node_info(value_type_info)[:type]
329
+
271
330
  {
272
- type: "map",
331
+ type: "map(#{k},#{v})",
273
332
  shape: []
274
333
  }
275
334
  else
@@ -22,5 +22,9 @@ module OnnxRuntime
22
22
  def outputs
23
23
  @session.outputs
24
24
  end
25
+
26
+ def metadata
27
+ @session.modelmeta
28
+ end
25
29
  end
26
30
  end
@@ -1,3 +1,3 @@
1
1
  module OnnxRuntime
2
- VERSION = "0.2.3"
2
+ VERSION = "0.3.0"
3
3
  end
Binary file
Binary file
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: onnxruntime
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-01-24 00:00:00.000000000 Z
11
+ date: 2020-03-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: ffi