tensor_stream 1.0.0 → 1.0.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -0,0 +1,24 @@
1
+ # This file has been automatically generated by stubgen
2
+ # DO NOT EDIT
3
+ #
4
+ module TensorStream
5
+ module OpStub
6
+ <% TensorStream::OpMaker.each_op do |op|%>
7
+ ##
8
+ <% op.description_lines.each do |line|%> # <%= line %>
9
+ <%end%> #
10
+ #<% if op.supports_broadcasting? %> This operation supports broadcasting
11
+ #<% end %>
12
+ # Params:
13
+ <% op.parameters.each do |param| %> # +<%= param[:name] %>+:: <%= param[:description]%><%if param[:validate]%> (of type <%= param[:validate] %>)<%end%>
14
+ <% end %> #
15
+ # Options:
16
+ <% op.options.each do |k, v| %> # +:<%= k %>+:: <%= v[:description]%><% if v[:default_value] != :nil %> default (<%= v[:default_value] %>)<%end%>
17
+ <%end%> def <%= op.operation.to_s %>(<%= (op.expand_params(true) + op.expand_options(true)).join(', ') %>)
18
+ <%= op.generate_body %>
19
+ end
20
+ <% op.aliases.each do |a|%>
21
+ alias_method :<%= a %>, :<%= op.operation %><%end%>
22
+ <% end %>
23
+ end
24
+ end
@@ -12,7 +12,7 @@ module TensorStream
12
12
  @node_keys = []
13
13
  @collections = {
14
14
  :"#{GraphKeys::GLOBAL_VARIABLES}" => [],
15
- :"#{GraphKeys::TRAINABLE_VARIABLES}" => []
15
+ :"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
16
16
  }
17
17
  @constants = {}
18
18
  end
@@ -27,7 +27,7 @@ module TensorStream
27
27
  @node_keys = []
28
28
  @collections = {
29
29
  :"#{GraphKeys::GLOBAL_VARIABLES}" => [],
30
- :"#{GraphKeys::TRAINABLE_VARIABLES}" => []
30
+ :"#{GraphKeys::TRAINABLE_VARIABLES}" => [],
31
31
  }
32
32
  @constants = {}
33
33
  end
@@ -85,14 +85,14 @@ module TensorStream
85
85
  end
86
86
 
87
87
  def add_node(node, name = nil)
88
- raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)
88
+ raise "Placeholder cannot be used when eager_execution is enabled" if @eager_execution && node.is_a?(Placeholder)
89
89
 
90
90
  if name.nil?
91
91
  node.name = if @nodes[node.name]
92
- uniqunify(node.name)
93
- else
94
- node.name
95
- end
92
+ uniqunify(node.name)
93
+ else
94
+ node.name
95
+ end
96
96
  end
97
97
 
98
98
  node.device = get_device_scope
@@ -129,10 +129,10 @@ module TensorStream
129
129
 
130
130
  def add_op(operation, *args)
131
131
  options = if args.last.is_a?(Hash)
132
- args.pop
133
- else
134
- {}
135
- end
132
+ args.pop
133
+ else
134
+ {}
135
+ end
136
136
 
137
137
  inputs = args.map { |i| TensorStream.convert_to_tensor(i) }.map { |i| i ? i.op : nil }
138
138
 
@@ -141,7 +141,7 @@ module TensorStream
141
141
  new_op.operation = operation
142
142
  new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
143
143
  new_op.rank = new_op.shape.rank
144
- new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join('/')
144
+ new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join("/")
145
145
  new_op.internal = options[:internal]
146
146
 
147
147
  new_op.data_type = new_op.set_data_type(options[:data_type])
@@ -211,7 +211,7 @@ module TensorStream
211
211
  def get_operation_counter
212
212
  @op_counter ||= 0
213
213
 
214
- name = @op_counter.zero? ? '' : "_#{@op_counter}"
214
+ name = @op_counter.zero? ? "" : "_#{@op_counter}"
215
215
 
216
216
  @op_counter += 1
217
217
 
@@ -222,7 +222,7 @@ module TensorStream
222
222
  @placeholder_counter ||= 0
223
223
  @placeholder_counter += 1
224
224
 
225
- return '' if @placeholder_counter == 1
225
+ return "" if @placeholder_counter == 1
226
226
 
227
227
  "_#{@placeholder_counter}"
228
228
  end
@@ -231,14 +231,14 @@ module TensorStream
231
231
  @var_counter ||= 0
232
232
  @var_counter += 1
233
233
 
234
- return '' if @var_counter == 1
234
+ return "" if @var_counter == 1
235
235
  "_#{@var_counter}"
236
236
  end
237
237
 
238
238
  def get_const_counter
239
239
  @const_counter ||= 0
240
240
 
241
- name = @const_counter.zero? ? '' : "_#{@const_counter}"
241
+ name = @const_counter.zero? ? "" : "_#{@const_counter}"
242
242
 
243
243
  @const_counter += 1
244
244
  name
@@ -248,7 +248,7 @@ module TensorStream
248
248
  graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
249
249
  return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?
250
250
 
251
- graph_thread_storage[:current_scope].join('/')
251
+ graph_thread_storage[:current_scope].join("/")
252
252
  end
253
253
 
254
254
  def get_dependency_scope
@@ -279,7 +279,7 @@ module TensorStream
279
279
  protected
280
280
 
281
281
  def _variable_scope
282
- return VariableScope.new(name: '', reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
282
+ return VariableScope.new(name: "", reuse: false, initializer: nil) if Thread.current[:tensor_stream_variable_scope].nil? || Thread.current[:tensor_stream_variable_scope].empty?
283
283
  scope = Thread.current[:tensor_stream_variable_scope].last
284
284
  scope
285
285
  end
@@ -13,48 +13,48 @@ module TensorStream
13
13
  protobuf = TensorStream::Protobuf.new
14
14
  parsed_tree = protobuf.load_from_string(buffer)
15
15
  parsed_tree.each do |node|
16
- next unless node['type'] == 'node'
16
+ next unless node["type"] == "node"
17
17
 
18
18
  # puts "build #{node['name']}"
19
19
  options = protobuf.options_evaluator(node)
20
- options[:name] = node['name']
20
+ options[:name] = node["name"]
21
21
  options[:__graph] = @graph
22
- value = options.delete('value')
22
+ value = options.delete("value")
23
23
  options = symbolize_keys(options)
24
- case node['op']
25
- when 'Const'
24
+ case node["op"]
25
+ when "Const"
26
26
  dimension = shape_eval(value)
27
27
  rank = dimension.size
28
28
  options[:value] = value
29
29
  options[:const] = true
30
30
  TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
31
- when 'VariableV2'
31
+ when "VariableV2"
32
32
  # evaluate options
33
33
  shape = options[:shape]
34
34
  i_var(options[:dtype] || options[:T], nil, shape, nil, options)
35
- when 'Placeholder'
35
+ when "Placeholder"
36
36
  shape = options[:shape]
37
37
  TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
38
38
  else
39
- op = underscore(node['op']).to_sym
39
+ op = underscore(node["op"]).to_sym
40
40
  puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
41
41
 
42
42
  # map input tensor
43
- inputs = node['input'].map do |input|
44
- input[0] = '' if input.start_with?('^')
43
+ inputs = node["input"].map { |input|
44
+ input[0] = "" if input.start_with?("^")
45
45
 
46
- input_indexed, index = input.split(':')
46
+ input_indexed, index = input.split(":")
47
47
 
48
48
  tensor = if index && index.to_i > 0
49
- @graph.get_tensor_by_name(input_indexed)[index.to_i]
50
- else
51
- @graph.get_tensor_by_name(input)
52
- end
49
+ @graph.get_tensor_by_name(input_indexed)[index.to_i]
50
+ else
51
+ @graph.get_tensor_by_name(input)
52
+ end
53
53
 
54
54
  raise "tensor not found by name #{input}" if tensor.nil?
55
55
 
56
56
  tensor
57
- end
57
+ }
58
58
 
59
59
  options[:data_type] = options.delete(:T)
60
60
  Graph.get_default_graph.add_op!(op, *inputs, options)
@@ -64,4 +64,4 @@ module TensorStream
64
64
  @graph
65
65
  end
66
66
  end
67
- end
67
+ end
@@ -1,4 +1,4 @@
1
- require 'yaml'
1
+ require "yaml"
2
2
 
3
3
  module TensorStream
4
4
  # A .pb graph deserializer
@@ -14,7 +14,7 @@ module TensorStream
14
14
  # parsers a protobuf file and spits out
15
15
  # a ruby hash
16
16
  def load(pbfile)
17
- f = File.new(pbfile, 'r')
17
+ f = File.new(pbfile, "r")
18
18
  lines = []
19
19
  while !f.eof? && (str = f.readline.strip)
20
20
  lines << str
@@ -23,38 +23,38 @@ module TensorStream
23
23
  end
24
24
 
25
25
  def parse_value(value_node)
26
- return unless value_node['tensor']
26
+ return unless value_node["tensor"]
27
27
 
28
- evaluate_tensor_node(value_node['tensor'])
28
+ evaluate_tensor_node(value_node["tensor"])
29
29
  end
30
30
 
31
31
  def evaluate_tensor_node(node)
32
- if !node['shape'].empty? && node['tensor_content']
33
- content = node['tensor_content']
34
- unpacked = eval(%Q("#{content}"))
32
+ if !node["shape"].empty? && node["tensor_content"]
33
+ content = node["tensor_content"]
34
+ unpacked = eval(%("#{content}"))
35
35
 
36
- if node['dtype'] == 'DT_FLOAT'
37
- TensorShape.reshape(unpacked.unpack('f*'), node['shape'])
38
- elsif node['dtype'] == 'DT_INT32'
39
- TensorShape.reshape(unpacked.unpack('l*'), node['shape'])
40
- elsif node['dtype'] == 'DT_STRING'
41
- node['string_val']
36
+ if node["dtype"] == "DT_FLOAT"
37
+ TensorShape.reshape(unpacked.unpack("f*"), node["shape"])
38
+ elsif node["dtype"] == "DT_INT32"
39
+ TensorShape.reshape(unpacked.unpack("l*"), node["shape"])
40
+ elsif node["dtype"] == "DT_STRING"
41
+ node["string_val"]
42
42
  else
43
- raise "unknown dtype #{node['dtype']}"
43
+ raise "unknown dtype #{node["dtype"]}"
44
44
  end
45
45
  else
46
46
 
47
- val = if node['dtype'] == 'DT_FLOAT'
48
- node['float_val'] ? node['float_val'].to_f : []
49
- elsif node['dtype'] == 'DT_INT32'
50
- node['int_val'] ? node['int_val'].to_i : []
51
- elsif node['dtype'] == 'DT_STRING'
52
- node['string_val']
53
- else
54
- raise "unknown dtype #{node['dtype']}"
55
- end
47
+ val = if node["dtype"] == "DT_FLOAT"
48
+ node["float_val"] ? node["float_val"].to_f : []
49
+ elsif node["dtype"] == "DT_INT32"
50
+ node["int_val"] ? node["int_val"].to_i : []
51
+ elsif node["dtype"] == "DT_STRING"
52
+ node["string_val"]
53
+ else
54
+ raise "unknown dtype #{node["dtype"]}"
55
+ end
56
56
 
57
- if node['shape'] == [1]
57
+ if node["shape"] == [1]
58
58
  [val]
59
59
  else
60
60
  val
@@ -63,16 +63,16 @@ module TensorStream
63
63
  end
64
64
 
65
65
  def map_type_to_ts(attr_value)
66
- case(attr_value)
67
- when 'DT_FLOAT'
66
+ case attr_value
67
+ when "DT_FLOAT"
68
68
  :float32
69
- when 'DT_INT32'
69
+ when "DT_INT32"
70
70
  :int32
71
- when 'DT_INT64'
71
+ when "DT_INT64"
72
72
  :int64
73
- when 'DT_STRING'
73
+ when "DT_STRING"
74
74
  :string
75
- when 'DT_BOOL'
75
+ when "DT_BOOL"
76
76
  :boolean
77
77
  else
78
78
  raise "unknown type #{attr_value}"
@@ -80,21 +80,21 @@ module TensorStream
80
80
  end
81
81
 
82
82
  def options_evaluator(node)
83
- return {} if node['attributes'].nil?
83
+ return {} if node["attributes"].nil?
84
84
 
85
- node['attributes'].map do |attribute|
86
- attr_type, attr_value = attribute['value'].flat_map { |k, v| [k, v] }
85
+ node["attributes"].map { |attribute|
86
+ attr_type, attr_value = attribute["value"].flat_map { |k, v| [k, v] }
87
87
 
88
- if attr_type == 'tensor'
88
+ if attr_type == "tensor"
89
89
  attr_value = evaluate_tensor_node(attr_value)
90
- elsif attr_type == 'type'
90
+ elsif attr_type == "type"
91
91
  attr_value = map_type_to_ts(attr_value)
92
- elsif attr_type == 'b'
93
- attr_value = attr_value == 'true'
92
+ elsif attr_type == "b"
93
+ attr_value = attr_value == "true"
94
94
  end
95
95
 
96
- [attribute['key'], attr_value]
97
- end.to_h
96
+ [attribute["key"], attr_value]
97
+ }.to_h
98
98
  end
99
99
 
100
100
  protected
@@ -108,126 +108,126 @@ module TensorStream
108
108
  lines.each do |str|
109
109
  case state
110
110
  when :top
111
- node['type'] = parse_node_name(str)
111
+ node["type"] = parse_node_name(str)
112
112
  state = :node_context
113
113
  next
114
114
  when :node_context
115
- if str == 'attr {'
115
+ if str == "attr {"
116
116
  state = :attr_context
117
117
  node_attr = {}
118
- node['attributes'] ||= []
119
- node['attributes'] << node_attr
118
+ node["attributes"] ||= []
119
+ node["attributes"] << node_attr
120
120
  next
121
- elsif str == '}'
121
+ elsif str == "}"
122
122
  state = :top
123
123
  block << node
124
124
  node = {}
125
125
  next
126
126
  else
127
- key, value = str.split(':', 2)
128
- if key == 'input'
129
- node['input'] ||= []
130
- node['input'] << process_value(value.strip)
127
+ key, value = str.split(":", 2)
128
+ if key == "input"
129
+ node["input"] ||= []
130
+ node["input"] << process_value(value.strip)
131
131
  else
132
132
  node[key] = process_value(value.strip)
133
133
  end
134
134
  end
135
135
  when :attr_context
136
- if str == 'value {'
136
+ if str == "value {"
137
137
  state = :value_context
138
- node_attr['value'] = {}
138
+ node_attr["value"] = {}
139
139
  next
140
- elsif str == '}'
140
+ elsif str == "}"
141
141
  state = :node_context
142
142
  next
143
143
  else
144
- key, value = str.split(':', 2)
144
+ key, value = str.split(":", 2)
145
145
  node_attr[key] = process_value(value.strip)
146
146
  end
147
147
  when :value_context
148
- if str == 'list {'
148
+ if str == "list {"
149
149
  state = :list_context
150
- node_attr['value'] = []
150
+ node_attr["value"] = []
151
151
  next
152
- elsif str == 'shape {'
152
+ elsif str == "shape {"
153
153
  state = :shape_context
154
- node_attr['value']['shape'] = []
154
+ node_attr["value"]["shape"] = []
155
155
  next
156
- elsif str == 'tensor {'
156
+ elsif str == "tensor {"
157
157
  state = :tensor_context
158
- node_attr['value']['tensor'] = {}
158
+ node_attr["value"]["tensor"] = {}
159
159
  next
160
- elsif str == '}'
160
+ elsif str == "}"
161
161
  state = :attr_context
162
162
  next
163
163
  else
164
- key, value = str.split(':', 2)
165
- if key == 'dtype'
166
- node_attr['value']['dtype'] = value.strip
167
- elsif key == 'type'
168
- node_attr['value']['type'] = value.strip
164
+ key, value = str.split(":", 2)
165
+ if key == "dtype"
166
+ node_attr["value"]["dtype"] = value.strip
167
+ elsif key == "type"
168
+ node_attr["value"]["type"] = value.strip
169
169
  else
170
- node_attr['value'][key] = process_value(value.strip)
170
+ node_attr["value"][key] = process_value(value.strip)
171
171
  end
172
172
  end
173
173
  when :list_context
174
- if str == '}'
174
+ if str == "}"
175
175
  state = :value_context
176
176
  next
177
177
  else
178
- key, value = str.split(':', 2)
179
- node_attr['value'] << { key => value }
178
+ key, value = str.split(":", 2)
179
+ node_attr["value"] << {key => value}
180
180
  end
181
181
  when :tensor_context
182
- if str == 'tensor_shape {'
182
+ if str == "tensor_shape {"
183
183
  state = :tensor_shape_context
184
- node_attr['value']['tensor']['shape'] = []
184
+ node_attr["value"]["tensor"]["shape"] = []
185
185
  next
186
- elsif str == '}'
186
+ elsif str == "}"
187
187
  state = :value_context
188
188
  next
189
189
  else
190
- key, value = str.split(':', 2)
191
- if node_attr['value']['tensor'][key] && !node_attr['value']['tensor'][key].is_a?(Array)
192
- node_attr['value']['tensor'][key] = [node_attr['value']['tensor'][key]]
193
- node_attr['value']['tensor'][key] << process_value(value.strip)
194
- elsif node_attr['value']['tensor'][key]
195
- node_attr['value']['tensor'][key] << process_value(value.strip)
190
+ key, value = str.split(":", 2)
191
+ if node_attr["value"]["tensor"][key] && !node_attr["value"]["tensor"][key].is_a?(Array)
192
+ node_attr["value"]["tensor"][key] = [node_attr["value"]["tensor"][key]]
193
+ node_attr["value"]["tensor"][key] << process_value(value.strip)
194
+ elsif node_attr["value"]["tensor"][key]
195
+ node_attr["value"]["tensor"][key] << process_value(value.strip)
196
196
  else
197
- node_attr['value']['tensor'][key] = process_value(value.strip)
197
+ node_attr["value"]["tensor"][key] = process_value(value.strip)
198
198
  end
199
199
  end
200
200
  when :tensor_shape_context
201
- if str == 'dim {'
201
+ if str == "dim {"
202
202
  state = :tensor_shape_dim_context
203
203
  next
204
- elsif str == '}'
204
+ elsif str == "}"
205
205
  state = :tensor_context
206
206
  next
207
207
  end
208
208
  when :shape_context
209
- if str == '}'
209
+ if str == "}"
210
210
  state = :value_context
211
211
  next
212
- elsif str == 'dim {'
212
+ elsif str == "dim {"
213
213
  state = :shape_dim_context
214
214
  next
215
215
  end
216
216
  when :shape_dim_context
217
- if str == '}'
217
+ if str == "}"
218
218
  state = :shape_context
219
219
  next
220
220
  else
221
- _key, value = str.split(':', 2)
222
- node_attr['value']['shape'] << value.strip.to_i
221
+ _key, value = str.split(":", 2)
222
+ node_attr["value"]["shape"] << value.strip.to_i
223
223
  end
224
224
  when :tensor_shape_dim_context
225
- if str == '}'
225
+ if str == "}"
226
226
  state = :tensor_shape_context
227
227
  next
228
228
  else
229
- _key, value = str.split(':', 2)
230
- node_attr['value']['tensor']['shape'] << value.strip.to_i
229
+ _key, value = str.split(":", 2)
230
+ node_attr["value"]["tensor"]["shape"] << value.strip.to_i
231
231
  end
232
232
  end
233
233
  end
@@ -236,22 +236,22 @@ module TensorStream
236
236
  end
237
237
 
238
238
  def parse_node_name(str)
239
- str.split(' ')[0]
239
+ str.split(" ")[0]
240
240
  end
241
241
 
242
242
  def process_value(value)
243
243
  if value.start_with?('"')
244
- unescape(value.gsub!(/\A"|"\Z/, ''))
244
+ unescape(value.gsub!(/\A"|"\Z/, ""))
245
245
  else
246
246
  unescape(value)
247
247
  end
248
248
  end
249
249
 
250
250
  UNESCAPES = {
251
- 'a' => "\x07", 'b' => "\x08", 't' => "\x09",
252
- 'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
253
- 'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
254
- "\"" => "\x22", "'" => "\x27"
251
+ "a" => "\x07", "b" => "\x08", "t" => "\x09",
252
+ "n" => "\x0a", "v" => "\x0b", "f" => "\x0c",
253
+ "r" => "\x0d", "e" => "\x1b", "\\\\" => "\x5c",
254
+ "\"" => "\x22", "'" => "\x27",
255
255
  }.freeze
256
256
 
257
257
  def unescape(str)
@@ -260,11 +260,11 @@ module TensorStream
260
260
  if $1
261
261
  $1 == '\\' ? '\\' : UNESCAPES[$1]
262
262
  elsif $2 # escape \u0000 unicode
263
- ["#{$2}".hex].pack('U*')
263
+ [$2.to_s.hex].pack("U*")
264
264
  elsif $3 # escape \0xff or \xff
265
- [$3].pack('H2')
265
+ [$3].pack("H2")
266
266
  end
267
267
  end
268
268
  end
269
269
  end
270
- end
270
+ end