tensor_stream 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -0,0 +1,24 @@
+ # This file has ben automatically generated by stubgen
+ # DO NOT EDIT
+ #
+ module TensorStream
+ module OpStub
+ <% TensorStream::OpMaker.each_op do |op|%>
+ ##
+ <% op.description_lines.each do |line|%> # <%= line %>
+ <%end%> #
+ #<% if op.supports_broadcasting? %> This operation supports broadcasting
+ #<% end %>
+ # Params:
+ <% op.parameters.each do |param| %> # +<%= param[:name] %>+:: <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
+ <% end %> #
+ # Options:
+ <% op.options.each do |k, v| %> # +:<%= k %>+:: <%= v[:description]%><% if v[:default_value] != :nil %> default (<%= v[:default_value] %>)<%end%>
+ <%end%> def <%= op.operation.to_s %>(<%= (op.expand_params(true) + op.expand_options(true)).join(', ') %>)
+ <%= op.generate_body %>
+ end
+ <% op.aliases.each do |a|%>
+ alias_method :<%= a %>, :<%= op.operation %><%end%>
+ <% end %>
+ end
+ end
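For context, the ERB template above (data/lib/tensor_stream/generated_stub/stub_file.erb) is what data/bin/stubgen uses to emit one RDoc-documented method per operation into data/lib/tensor_stream/generated_stub/ops.rb. A minimal sketch of what a single generated stub might look like is shown below; the method body is an assumption inferred from the template and the _op helper in op_helper.rb, not copied from the released ops.rb.

    # Hypothetical output for an :add op definition (illustrative only).
    module TensorStream
      module OpStub
        ##
        # Returns x + y element-wise.
        #
        # This operation supports broadcasting
        #
        # Params:
        # +input_a+:: tensor X
        # +input_b+:: tensor Y
        #
        # Options:
        # +:name+:: Optional name
        def add(input_a, input_b, name: nil)
          # _op (from op_helper.rb) is assumed to append an operation node to the current graph.
          _op(:add, input_a, input_b, name: name)
        end
      end
    end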
@@ -12,7 +12,7 @@ module TensorStream
  @node_keys = []
  @collections = {
  :"#{GraphKeys::GLOBAL_VARIABLES}" => [],
- :"#{GraphKeys::TRAINABLE_VARIABLES}" => []
+ :"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
  }
  @constants = {}
  end
@@ -27,7 +27,7 @@ module TensorStream
  @node_keys = []
  @collections = {
  :"#{GraphKeys::GLOBAL_VARIABLES}" => [],
- :"#{GraphKeys::TRAINABLE_VARIABLES}" => []
+ :"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
  }
  @constants = {}
  end
@@ -85,14 +85,14 @@ module TensorStream
  end

  def add_node(node, name = nil)
- raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)
+ raise "Placeholder cannot be used when eager_execution is enabled" if @eager_execution && node.is_a?(Placeholder)

  if name.nil?
  node.name = if @nodes[node.name]
- uniqunify(node.name)
- else
- node.name
- end
+ uniqunify(node.name)
+ else
+ node.name
+ end
  end

  node.device = get_device_scope
@@ -129,10 +129,10 @@ module TensorStream

  def add_op(operation, *args)
  options = if args.last.is_a?(Hash)
- args.pop
- else
- {}
- end
+ args.pop
+ else
+ {}
+ end

  inputs = args.map { |i| TensorStream.convert_to_tensor(i) }.map { |i| i ? i.op : nil }

@@ -141,7 +141,7 @@
  new_op.operation = operation
  new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
  new_op.rank = new_op.shape.rank
- new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join('/')
+ new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join("/")
  new_op.internal = options[:internal]

  new_op.data_type = new_op.set_data_type(options[:data_type])
@@ -211,7 +211,7 @@
  def get_operation_counter
  @op_counter ||= 0

- name = @op_counter.zero? ? '' : "_#{@op_counter}"
+ name = @op_counter.zero? ? "" : "_#{@op_counter}"

  @op_counter += 1

@@ -222,7 +222,7 @@
  @placeholder_counter ||= 0
  @placeholder_counter += 1

- return '' if @placeholder_counter == 1
+ return "" if @placeholder_counter == 1

  "_#{@placeholder_counter}"
  end
@@ -231,14 +231,14 @@
  @var_counter ||= 0
  @var_counter += 1

- return '' if @var_counter == 1
+ return "" if @var_counter == 1
  "_#{@var_counter}"
  end

  def get_const_counter
  @const_counter ||= 0

- name = @const_counter.zero? ? '' : "_#{@const_counter}"
+ name = @const_counter.zero? ? "" : "_#{@const_counter}"

  @const_counter += 1
  name
@@ -248,7 +248,7 @@
  graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
  return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?

- graph_thread_storage[:current_scope].join('/')
+ graph_thread_storage[:current_scope].join("/")
  end

  def get_dependency_scope
@@ -279,7 +279,7 @@
  protected

  def _variable_scope
- return VariableScope.new(name: '', reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
+ return VariableScope.new(name: "", reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
  scope = Thread.current[:tensor_stream_variable_scope].last
  scope
  end
@@ -13,48 +13,48 @@ module TensorStream
  protobuf = TensorStream::Protobuf.new
  parsed_tree = protobuf.load_from_string(buffer)
  parsed_tree.each do |node|
- next unless node['type'] == 'node'
+ next unless node["type"] == "node"

  # puts "build #{node['name']}"
  options = protobuf.options_evaluator(node)
- options[:name] = node['name']
+ options[:name] = node["name"]
  options[:__graph] = @graph
- value = options.delete('value')
+ value = options.delete("value")
  options = symbolize_keys(options)
- case node['op']
- when 'Const'
+ case node["op"]
+ when "Const"
  dimension = shape_eval(value)
  rank = dimension.size
  options[:value] = value
  options[:const] = true
  TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
- when 'VariableV2'
+ when "VariableV2"
  # evaluate options
  shape = options[:shape]
  i_var(options[:dtype] || options[:T], nil, shape, nil, options)
- when 'Placeholder'
+ when "Placeholder"
  shape = options[:shape]
  TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
  else
- op = underscore(node['op']).to_sym
+ op = underscore(node["op"]).to_sym
  puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)

  # map input tensor
- inputs = node['input'].map do |input|
- input[0] = '' if input.start_with?('^')
+ inputs = node["input"].map { |input|
+ input[0] = "" if input.start_with?("^")

- input_indexed, index = input.split(':')
+ input_indexed, index = input.split(":")

  tensor = if index && index.to_i > 0
- @graph.get_tensor_by_name(input_indexed)[index.to_i]
- else
- @graph.get_tensor_by_name(input)
- end
+ @graph.get_tensor_by_name(input_indexed)[index.to_i]
+ else
+ @graph.get_tensor_by_name(input)
+ end

  raise "tensor not found by name #{input}" if tensor.nil?

  tensor
- end
+ }

  options[:data_type] = options.delete(:T)
  Graph.get_default_graph.add_op!(op, *inputs, options)
@@ -64,4 +64,4 @@ module TensorStream
  @graph
  end
  end
- end
+ end
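The hunk above is from data/lib/tensor_stream/graph_builder.rb: it walks the hash tree produced by TensorStream::Protobuf and recreates constants, variables, placeholders and generic ops on the current graph. A minimal usage sketch of the parser side is shown below; the model.pbtxt path is hypothetical, and only methods visible in this diff are used.

    require "tensor_stream"

    protobuf = TensorStream::Protobuf.new
    # load parses a pbtext file into an array of plain Ruby hashes
    parsed_tree = protobuf.load("model.pbtxt")

    parsed_tree.each do |node|
      next unless node["type"] == "node"

      # options_evaluator flattens the node's attr blocks into a key => value hash
      options = protobuf.options_evaluator(node)
      puts "#{node["name"]} (#{node["op"]}): #{options.keys.inspect}"
    end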
@@ -1,4 +1,4 @@
- require 'yaml'
+ require "yaml"

  module TensorStream
  # A .pb graph deserializer
@@ -14,7 +14,7 @@ module TensorStream
  # parsers a protobuf file and spits out
  # a ruby hash
  def load(pbfile)
- f = File.new(pbfile, 'r')
+ f = File.new(pbfile, "r")
  lines = []
  while !f.eof? && (str = f.readline.strip)
  lines << str
@@ -23,38 +23,38 @@ module TensorStream
  end

  def parse_value(value_node)
- return unless value_node['tensor']
+ return unless value_node["tensor"]

- evaluate_tensor_node(value_node['tensor'])
+ evaluate_tensor_node(value_node["tensor"])
  end

  def evaluate_tensor_node(node)
- if !node['shape'].empty? && node['tensor_content']
- content = node['tensor_content']
- unpacked = eval(%Q("#{content}"))
+ if !node["shape"].empty? && node["tensor_content"]
+ content = node["tensor_content"]
+ unpacked = eval(%("#{content}"))

- if node['dtype'] == 'DT_FLOAT'
- TensorShape.reshape(unpacked.unpack('f*'), node['shape'])
- elsif node['dtype'] == 'DT_INT32'
- TensorShape.reshape(unpacked.unpack('l*'), node['shape'])
- elsif node['dtype'] == 'DT_STRING'
- node['string_val']
+ if node["dtype"] == "DT_FLOAT"
+ TensorShape.reshape(unpacked.unpack("f*"), node["shape"])
+ elsif node["dtype"] == "DT_INT32"
+ TensorShape.reshape(unpacked.unpack("l*"), node["shape"])
+ elsif node["dtype"] == "DT_STRING"
+ node["string_val"]
  else
- raise "unknown dtype #{node['dtype']}"
+ raise "unknown dtype #{node["dtype"]}"
  end
  else

- val = if node['dtype'] == 'DT_FLOAT'
- node['float_val'] ? node['float_val'].to_f : []
- elsif node['dtype'] == 'DT_INT32'
- node['int_val'] ? node['int_val'].to_i : []
- elsif node['dtype'] == 'DT_STRING'
- node['string_val']
- else
- raise "unknown dtype #{node['dtype']}"
- end
+ val = if node["dtype"] == "DT_FLOAT"
+ node["float_val"] ? node["float_val"].to_f : []
+ elsif node["dtype"] == "DT_INT32"
+ node["int_val"] ? node["int_val"].to_i : []
+ elsif node["dtype"] == "DT_STRING"
+ node["string_val"]
+ else
+ raise "unknown dtype #{node["dtype"]}"
+ end

- if node['shape'] == [1]
+ if node["shape"] == [1]
  [val]
  else
  val
@@ -63,16 +63,16 @@ module TensorStream
  end

  def map_type_to_ts(attr_value)
- case(attr_value)
- when 'DT_FLOAT'
+ case attr_value
+ when "DT_FLOAT"
  :float32
- when 'DT_INT32'
+ when "DT_INT32"
  :int32
- when 'DT_INT64'
+ when "DT_INT64"
  :int64
- when 'DT_STRING'
+ when "DT_STRING"
  :string
- when 'DT_BOOL'
+ when "DT_BOOL"
  :boolean
  else
  raise "unknown type #{attr_value}"
@@ -80,21 +80,21 @@ module TensorStream
  end

  def options_evaluator(node)
- return {} if node['attributes'].nil?
+ return {} if node["attributes"].nil?

- node['attributes'].map do |attribute|
- attr_type, attr_value = attribute['value'].flat_map { |k, v| [k, v] }
+ node["attributes"].map { |attribute|
+ attr_type, attr_value = attribute["value"].flat_map { |k, v| [k, v] }

- if attr_type == 'tensor'
+ if attr_type == "tensor"
  attr_value = evaluate_tensor_node(attr_value)
- elsif attr_type == 'type'
+ elsif attr_type == "type"
  attr_value = map_type_to_ts(attr_value)
- elsif attr_type == 'b'
- attr_value = attr_value == 'true'
+ elsif attr_type == "b"
+ attr_value = attr_value == "true"
  end

- [attribute['key'], attr_value]
- end.to_h
+ [attribute["key"], attr_value]
+ }.to_h
  end

  protected
@@ -108,126 +108,126 @@ module TensorStream
  lines.each do |str|
  case state
  when :top
- node['type'] = parse_node_name(str)
+ node["type"] = parse_node_name(str)
  state = :node_context
  next
  when :node_context
- if str == 'attr {'
+ if str == "attr {"
  state = :attr_context
  node_attr = {}
- node['attributes'] ||= []
- node['attributes'] << node_attr
+ node["attributes"] ||= []
+ node["attributes"] << node_attr
  next
- elsif str == '}'
+ elsif str == "}"
  state = :top
  block << node
  node = {}
  next
  else
- key, value = str.split(':', 2)
- if key == 'input'
- node['input'] ||= []
- node['input'] << process_value(value.strip)
+ key, value = str.split(":", 2)
+ if key == "input"
+ node["input"] ||= []
+ node["input"] << process_value(value.strip)
  else
  node[key] = process_value(value.strip)
  end
  end
  when :attr_context
- if str == 'value {'
+ if str == "value {"
  state = :value_context
- node_attr['value'] = {}
+ node_attr["value"] = {}
  next
- elsif str == '}'
+ elsif str == "}"
  state = :node_context
  next
  else
- key, value = str.split(':', 2)
+ key, value = str.split(":", 2)
  node_attr[key] = process_value(value.strip)
  end
  when :value_context
- if str == 'list {'
+ if str == "list {"
  state = :list_context
- node_attr['value'] = []
+ node_attr["value"] = []
  next
- elsif str == 'shape {'
+ elsif str == "shape {"
  state = :shape_context
- node_attr['value']['shape'] = []
+ node_attr["value"]["shape"] = []
  next
- elsif str == 'tensor {'
+ elsif str == "tensor {"
  state = :tensor_context
- node_attr['value']['tensor'] = {}
+ node_attr["value"]["tensor"] = {}
  next
- elsif str == '}'
+ elsif str == "}"
  state = :attr_context
  next
  else
- key, value = str.split(':', 2)
- if key == 'dtype'
- node_attr['value']['dtype'] = value.strip
- elsif key == 'type'
- node_attr['value']['type'] = value.strip
+ key, value = str.split(":", 2)
+ if key == "dtype"
+ node_attr["value"]["dtype"] = value.strip
+ elsif key == "type"
+ node_attr["value"]["type"] = value.strip
  else
- node_attr['value'][key] = process_value(value.strip)
+ node_attr["value"][key] = process_value(value.strip)
  end
  end
  when :list_context
- if str == '}'
+ if str == "}"
  state = :value_context
  next
  else
- key, value = str.split(':', 2)
- node_attr['value'] << { key => value }
+ key, value = str.split(":", 2)
+ node_attr["value"] << {key => value}
  end
  when :tensor_context
- if str == 'tensor_shape {'
+ if str == "tensor_shape {"
  state = :tensor_shape_context
- node_attr['value']['tensor']['shape'] = []
+ node_attr["value"]["tensor"]["shape"] = []
  next
- elsif str == '}'
+ elsif str == "}"
  state = :value_context
  next
  else
- key, value = str.split(':', 2)
- if node_attr['value']['tensor'][key] && !node_attr['value']['tensor'][key].is_a?(Array)
- node_attr['value']['tensor'][key] = [node_attr['value']['tensor'][key]]
- node_attr['value']['tensor'][key] << process_value(value.strip)
- elsif node_attr['value']['tensor'][key]
- node_attr['value']['tensor'][key] << process_value(value.strip)
+ key, value = str.split(":", 2)
+ if node_attr["value"]["tensor"][key] && !node_attr["value"]["tensor"][key].is_a?(Array)
+ node_attr["value"]["tensor"][key] = [node_attr["value"]["tensor"][key]]
+ node_attr["value"]["tensor"][key] << process_value(value.strip)
+ elsif node_attr["value"]["tensor"][key]
+ node_attr["value"]["tensor"][key] << process_value(value.strip)
  else
- node_attr['value']['tensor'][key] = process_value(value.strip)
+ node_attr["value"]["tensor"][key] = process_value(value.strip)
  end
  end
  when :tensor_shape_context
- if str == 'dim {'
+ if str == "dim {"
  state = :tensor_shape_dim_context
  next
- elsif str == '}'
+ elsif str == "}"
  state = :tensor_context
  next
  end
  when :shape_context
- if str == '}'
+ if str == "}"
  state = :value_context
  next
- elsif str == 'dim {'
+ elsif str == "dim {"
  state = :shape_dim_context
  next
  end
  when :shape_dim_context
- if str == '}'
+ if str == "}"
  state = :shape_context
  next
  else
- _key, value = str.split(':', 2)
- node_attr['value']['shape'] << value.strip.to_i
+ _key, value = str.split(":", 2)
+ node_attr["value"]["shape"] << value.strip.to_i
  end
  when :tensor_shape_dim_context
- if str == '}'
+ if str == "}"
  state = :tensor_shape_context
  next
  else
- _key, value = str.split(':', 2)
- node_attr['value']['tensor']['shape'] << value.strip.to_i
+ _key, value = str.split(":", 2)
+ node_attr["value"]["tensor"]["shape"] << value.strip.to_i
  end
  end
  end
@@ -236,22 +236,22 @@ module TensorStream
  end

  def parse_node_name(str)
- str.split(' ')[0]
+ str.split(" ")[0]
  end

  def process_value(value)
  if value.start_with?('"')
- unescape(value.gsub!(/\A"|"\Z/, ''))
+ unescape(value.gsub!(/\A"|"\Z/, ""))
  else
  unescape(value)
  end
  end

  UNESCAPES = {
- 'a' => "\x07", 'b' => "\x08", 't' => "\x09",
- 'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
- 'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
- "\"" => "\x22", "'" => "\x27"
+ "a" => "\x07", "b" => "\x08", "t" => "\x09",
+ "n" => "\x0a", "v" => "\x0b", "f" => "\x0c",
+ "r" => "\x0d", "e" => "\x1b", "\\\\" => "\x5c",
+ "\"" => "\x22", "'" => "\x27",
  }.freeze

  def unescape(str)
@@ -260,11 +260,11 @@ module TensorStream
  if $1
  $1 == '\\' ? '\\' : UNESCAPES[$1]
  elsif $2 # escape \u0000 unicode
- ["#{$2}".hex].pack('U*')
+ [$2.to_s.hex].pack("U*")
  elsif $3 # escape \0xff or \xff
- [$3].pack('H2')
+ [$3].pack("H2")
  end
  end
  end
  end
- end
+ end
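The hunks above belong to TensorStream::Protobuf (data/lib/tensor_stream/graph_deserializers/protobuf.rb), whose state machine consumes a TensorFlow-style pbtext graph line by line, with the opening and closing braces driving the state transitions. A rough, hand-written fragment of the kind of input it expects is shown below; the values are illustrative, not taken from the release, and the example assumes load_from_string strips each line the same way load does.

    require "tensor_stream"

    # Illustrative pbtext fragment describing a single Const node.
    buffer = <<~PBTEXT
      node {
        name: "Const"
        op: "Const"
        attr {
          key: "value"
          value {
            tensor {
              dtype: DT_FLOAT
              tensor_shape {
                dim {
                  size: 1
                }
              }
              float_val: 1.0
            }
          }
        }
      }
    PBTEXT

    parsed = TensorStream::Protobuf.new.load_from_string(buffer)
    # parsed is roughly:
    # [{ "type" => "node", "name" => "Const", "op" => "Const",
    #    "attributes" => [{ "key" => "value",
    #                       "value" => { "tensor" => { "dtype" => "DT_FLOAT",
    #                                                  "shape" => [1],
    #                                                  "float_val" => "1.0" } } }] }]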