tensor_stream 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -58,4 +58,4 @@ module TensorStream
58
58
  @graph
59
59
  end
60
60
  end
61
- end
61
+ end
@@ -1,7 +1,7 @@
1
1
  module TensorStream
2
2
  class GraphKeys
3
- GLOBAL_VARIABLES = 'variables'.freeze
4
- TRAINABLE_VARIABLES = 'trainable_variables'.freeze
5
- GLOBAL_STEP = 'global_step'.freeze
3
+ GLOBAL_VARIABLES = "variables".freeze
4
+ TRAINABLE_VARIABLES = "trainable_variables".freeze
5
+ GLOBAL_STEP = "global_step".freeze
6
6
  end
7
7
  end
@@ -48,7 +48,7 @@ module TensorStream
48
48
  private
49
49
 
50
50
  def add_to_group(groups, name, arr_buf)
51
- name_parts = name.split('/')
51
+ name_parts = name.split("/")
52
52
  return false if name_parts.size < 2
53
53
 
54
54
  prefix = name_parts.shift
@@ -67,35 +67,35 @@ module TensorStream
67
67
  end
68
68
 
69
69
  def find_or_create_group(prefix, groups)
70
- if !groups[prefix]
71
- groups[prefix] = { buf: [], group: {} }
70
+ unless groups[prefix]
71
+ groups[prefix] = {buf: [], group: {}}
72
72
  end
73
73
 
74
- return groups[prefix]
74
+ groups[prefix]
75
75
  end
76
76
 
77
77
  def create_group(id, title, group)
78
78
  arr_buf = []
79
79
  arr_buf << "<node id=\"#{id}\" yfiles.foldertype=\"group\">"
80
80
  arr_buf << '<data key="d9">'
81
- arr_buf << '<y:ProxyAutoBoundsNode>'
81
+ arr_buf << "<y:ProxyAutoBoundsNode>"
82
82
  arr_buf << '<y:Realizers active="0">'
83
- arr_buf << '<y:GroupNode>'
83
+ arr_buf << "<y:GroupNode>"
84
84
  arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
85
85
  arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
86
- arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + '</y:NodeLabel>'
86
+ arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + "</y:NodeLabel>"
87
87
  arr_buf << '<y:Shape type="roundrectangle"/>'
88
- arr_buf << '</y:GroupNode>'
89
- arr_buf << '</y:Realizers>'
90
- arr_buf << '</y:ProxyAutoBoundsNode>'
91
- arr_buf << '</data>'
88
+ arr_buf << "</y:GroupNode>"
89
+ arr_buf << "</y:Realizers>"
90
+ arr_buf << "</y:ProxyAutoBoundsNode>"
91
+ arr_buf << "</data>"
92
92
  arr_buf << '<graph edgedefault="directed" id="n105:">'
93
93
  arr_buf << group[:buf]
94
94
  group[:group].each do |k, g|
95
95
  arr_buf << create_group(k, k, g)
96
96
  end
97
- arr_buf << '</graph>'
98
- arr_buf << '</node>'
97
+ arr_buf << "</graph>"
98
+ arr_buf << "</node>"
99
99
  arr_buf
100
100
  end
101
101
 
@@ -120,17 +120,17 @@ module TensorStream
120
120
  end
121
121
  node_buf << "<data key=\"d9\">"
122
122
  node_buf << "<y:ShapeNode>"
123
- if tensor.internal?
124
- node_buf << " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
123
+ node_buf << if tensor.internal?
124
+ " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
125
125
  else
126
- node_buf << " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
126
+ " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
127
127
  end
128
128
  node_buf << " <y:NodeLabel alignment=\"center\">#{tensor.operation}</y:NodeLabel>"
129
129
  node_buf << "</y:ShapeNode>"
130
130
  node_buf << "</data>"
131
131
  node_buf << "</node>"
132
132
 
133
- if !add_to_group(groups, tensor.name, node_buf)
133
+ unless add_to_group(groups, tensor.name, node_buf)
134
134
  add_to_group(groups, "program/#{tensor.name}", node_buf)
135
135
  end
136
136
 
@@ -176,10 +176,10 @@ module TensorStream
176
176
  input_buf << "<y:ShapeNode>"
177
177
 
178
178
  input_buf << if input.internal?
179
- " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
180
- else
181
- " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
182
- end
179
+ " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
180
+ else
181
+ " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
182
+ end
183
183
 
184
184
  input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
185
185
 
@@ -204,7 +204,7 @@ module TensorStream
204
204
  end
205
205
 
206
206
  def _gml_string(str)
207
- str.tr('/', '-')
207
+ str.tr("/", "-")
208
208
  end
209
209
 
210
210
  def output_edge(input, tensor, arr_buf, index = 0)
@@ -215,22 +215,22 @@ module TensorStream
215
215
  arr_buf << "<y:PolyLineEdge>"
216
216
  arr_buf << "<y:EdgeLabel >"
217
217
  arr_buf << if !@last_session_context.empty?
218
- "<![CDATA[ #{_val(input)} ]]>"
219
- elsif input.shape.shape.nil?
220
- "<![CDATA[ #{input.data_type} ? ]]>"
221
- else
222
- "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
223
- end
218
+ "<![CDATA[ #{_val(input)} ]]>"
219
+ elsif input.shape.shape.nil?
220
+ "<![CDATA[ #{input.data_type} ? ]]>"
221
+ else
222
+ "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? "scalar" : input.shape.shape.to_json} ]]>"
223
+ end
224
224
  arr_buf << "</y:EdgeLabel >"
225
225
  arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
226
226
  arr_buf << if index.zero?
227
- "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
228
- else
229
- "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
230
- end
227
+ "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
228
+ else
229
+ "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
230
+ end
231
231
  arr_buf << "</y:PolyLineEdge>"
232
232
  arr_buf << "</data>"
233
233
  arr_buf << "</edge>"
234
234
  end
235
235
  end
236
- end
236
+ end
@@ -1,4 +1,4 @@
1
- require 'base64'
1
+ require "base64"
2
2
 
3
3
  module TensorStream
4
4
  # Utility class to handle data type serialization
@@ -7,36 +7,36 @@ module TensorStream
7
7
  value = value.is_a?(Array) ? value.flatten : [value]
8
8
  byte_value = case data_type
9
9
  when :float64
10
- value.pack('d*')
10
+ value.pack("d*")
11
11
  when :float32, :float16, :float
12
- value.pack('f*')
12
+ value.pack("f*")
13
13
  when :uint32
14
- value.pack('L*')
14
+ value.pack("L*")
15
15
  when :int32, :int
16
- value.pack('l*')
16
+ value.pack("l*")
17
17
  when :int64
18
- value.pack('q*')
18
+ value.pack("q*")
19
19
  when :uint64
20
- value.pack('Q*')
20
+ value.pack("Q*")
21
21
  when :uint8
22
- value.pack('C*')
22
+ value.pack("C*")
23
23
  when :boolean
24
- value.map { |v| v ? 1 : 0 }.pack('C*')
24
+ value.map { |v| v ? 1 : 0 }.pack("C*")
25
25
  when :string
26
26
  if value.is_a?(Array)
27
- value.to_yaml
27
+ value.to_yaml
28
28
  else
29
- value
29
+ value
30
30
  end
31
31
  else
32
- raise "unknown type #{data_type}"
33
- end
32
+ raise "unknown type #{data_type}"
33
+ end
34
34
 
35
35
  byte_value
36
36
  end
37
37
 
38
38
  def self.pack_to_str(value, data_type)
39
- pack(value, data_type).bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
39
+ pack(value, data_type).bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join
40
40
  end
41
41
 
42
42
  def self.unpack_from_str(content, data_type)
@@ -47,22 +47,22 @@ module TensorStream
47
47
  def self.unpack(unpacked, data_type)
48
48
  case data_type
49
49
  when :float32, :float, :float16
50
- unpacked.unpack('f*')
50
+ unpacked.unpack("f*")
51
51
  when :float64
52
- unpacked.unpack('d*')
52
+ unpacked.unpack("d*")
53
53
  when :int32, :int
54
- unpacked.unpack('L*')
54
+ unpacked.unpack("L*")
55
55
  when :uint32
56
- unpacked.unpack('l*')
56
+ unpacked.unpack("l*")
57
57
  when :int64
58
- unpacked.unpack('q*')
58
+ unpacked.unpack("q*")
59
59
  when :uint64
60
- unpacked.unpack('Q*')
60
+ unpacked.unpack("Q*")
61
61
  when :uint8
62
- unpacked.unpack('C*')
62
+ unpacked.unpack("C*")
63
63
  when :boolean
64
- unpacked.unpack('C*').map { |v| v == 1 }
64
+ unpacked.unpack("C*").map { |v| v == 1 }
65
65
  end
66
66
  end
67
67
  end
68
- end
68
+ end
@@ -12,10 +12,10 @@ module TensorStream
12
12
 
13
13
  node_keys.each do |k|
14
14
  node = if block_given?
15
- yield graph, k
16
- else
17
- graph.get_tensor_by_name(k)
18
- end
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
19
 
20
20
  @lines << "node {"
21
21
  @lines << " name: #{node.name.to_json}"
@@ -27,13 +27,13 @@ module TensorStream
27
27
  @lines << " input: #{input.name.to_json}"
28
28
  end
29
29
  # type
30
- pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
30
+ pb_attr("T", "type: #{sym_to_protobuf_type(node.data_type)}")
31
31
 
32
32
  case node.operation.to_s
33
- when 'const'
34
- pb_attr('value', tensor_value(node))
35
- when 'variable_v2'
36
- pb_attr('shape', shape_buf(node, 'shape'))
33
+ when "const"
34
+ pb_attr("value", tensor_value(node))
35
+ when "variable_v2"
36
+ pb_attr("shape", shape_buf(node, "shape"))
37
37
  end
38
38
  process_options(node)
39
39
  end
@@ -50,7 +50,7 @@ module TensorStream
50
50
  def process_options(node)
51
51
  return if node.options.nil?
52
52
  node.options.reject { |_k, v| v.nil? }.each do |k, v|
53
- next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?('__')
53
+ next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?("__")
54
54
  @lines << " attr {"
55
55
  @lines << " key: \"#{k}\""
56
56
  @lines << " value {"
@@ -63,55 +63,51 @@ module TensorStream
63
63
  def attr_value(val, indent = 0)
64
64
  spaces = " " * indent
65
65
  case val.class.to_s
66
- when 'TrueClass', 'FalseClass'
66
+ when "TrueClass", "FalseClass"
67
67
  @lines << "#{spaces}b: #{val}"
68
- when 'Integer'
68
+ when "Integer"
69
69
  @lines << "#{spaces}i: #{val}"
70
- when 'String',
70
+ when "String",
71
71
  @lines << "#{spaces}s: #{val}"
72
- when 'Float'
72
+ when "Float"
73
73
  @lines << "#{spaces}f: #{val}"
74
- when 'Symbol'
74
+ when "Symbol"
75
75
  @lines << "#{spaces}sym: #{val}"
76
- when 'Array'
76
+ when "Array"
77
77
  @lines << "#{spaces}list {"
78
78
  val.each do |v_item|
79
79
  attr_value(v_item, indent + 2)
80
80
  end
81
81
  @lines << "#{spaces}}"
82
- when 'TensorStream::TensorShape'
82
+ when "TensorStream::TensorShape"
83
83
  @lines << "#{spaces}shape {"
84
- if val.shape
85
- val.shape.each do |dim|
86
- @lines << "#{spaces} dim {"
87
- @lines << "#{spaces} size: #{dim}"
88
- @lines << "#{spaces} }"
89
- end
84
+ val.shape&.each do |dim|
85
+ @lines << "#{spaces} dim {"
86
+ @lines << "#{spaces} size: #{dim}"
87
+ @lines << "#{spaces} }"
90
88
  end
91
89
  @lines << "#{spaces}}"
92
- when 'TensorStream::Variable'
90
+ when "TensorStream::Variable"
93
91
  else
94
92
  raise "unknown type #{val.class}"
95
93
  end
96
94
  end
97
95
 
98
96
  def pack_arr_float(float_arr)
99
- float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
97
+ float_arr.flatten.pack("f*").bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join
100
98
  end
101
99
 
102
100
  def pack_arr_int(int_arr)
103
- int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
101
+ int_arr.flatten.pack("l*").bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join
104
102
  end
105
103
 
106
- def shape_buf(tensor, shape_type = 'tensor_shape')
104
+ def shape_buf(tensor, shape_type = "tensor_shape")
107
105
  arr = []
108
106
  arr << " #{shape_type} {"
109
- if tensor.shape.shape
110
- tensor.shape.shape.each do |dim|
111
- arr << " dim {"
112
- arr << " size: #{dim}"
113
- arr << " }"
114
- end
107
+ tensor.shape.shape&.each do |dim|
108
+ arr << " dim {"
109
+ arr << " size: #{dim}"
110
+ arr << " }"
115
111
  end
116
112
  arr << " }"
117
113
  arr
@@ -140,14 +136,14 @@ module TensorStream
140
136
  end
141
137
  else
142
138
  val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
143
- "int_val"
144
- elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
145
- "float_val"
146
- elsif tensor.data_type == :string
147
- "string_val"
148
- else
149
- "val"
150
- end
139
+ "int_val"
140
+ elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
141
+ "float_val"
142
+ elsif tensor.data_type == :string
143
+ "string_val"
144
+ else
145
+ "val"
146
+ end
151
147
  arr << " #{val_type}: #{tensor.const_value.to_json}"
152
148
  end
153
149
  arr << "}"
@@ -186,4 +182,4 @@ module TensorStream
186
182
  @lines << " }"
187
183
  end
188
184
  end
189
- end
185
+ end
@@ -4,6 +4,7 @@ module TensorStream
4
4
  File.write(filename, get_string(tensor, session, graph_keys = nil))
5
5
  end
6
6
 
7
- def get_string(tensor, session = nil); end
7
+ def get_string(tensor, session = nil)
8
+ end
8
9
  end
9
- end
10
+ end
@@ -12,10 +12,10 @@ module TensorStream
12
12
 
13
13
  node_keys.each do |k|
14
14
  node = if block_given?
15
- yield graph, k
16
- else
17
- graph.get_tensor_by_name(k)
18
- end
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
19
  next unless node.is_a?(Operation)
20
20
 
21
21
  serialized_arr << node.to_h
@@ -24,4 +24,4 @@ module TensorStream
24
24
  serialized_arr.to_yaml
25
25
  end
26
26
  end
27
- end
27
+ end
@@ -1,4 +1,4 @@
1
- require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
1
+ require "tensor_stream/evaluator/operation_helpers/array_ops_helper"
2
2
  module TensorStream
3
3
  ##
4
4
  # Convenience class for guessing the shape of a tensor
@@ -9,18 +9,35 @@ module TensorStream
9
9
 
10
10
  def self.infer_shape(tensor)
11
11
  case tensor.operation
12
+ when :assign
13
+ possible_shape = if tensor.inputs[0]&.shape&.shape
14
+ tensor.inputs[0].shape.shape
15
+ else
16
+ tensor.inputs[1].shape.shape
17
+ end
18
+
19
+ possible_shape
20
+ when :const
21
+ shape_eval(tensor.options[:value])
22
+ when :variable_v2
23
+ tensor.shape ? tensor.shape.shape : nil
24
+ when :placeholder
25
+ return nil if tensor.inputs[0].nil?
26
+ return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
27
+
28
+ TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
12
29
  when :case, :case_grad
13
- tensor.inputs[2].shape.shape if tensor.inputs[2]
30
+ tensor.inputs[2]&.shape&.shape
14
31
  when :const
15
32
  shape_eval(tensor.options[:value])
16
33
  when :variable_v2
17
34
  tensor.shape ? tensor.shape.shape : nil
18
35
  when :assign
19
- possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
20
- tensor.inputs[0].shape.shape
21
- else
22
- tensor.inputs[1].shape.shape
23
- end
36
+ possible_shape = if tensor.inputs[0]&.shape&.shape
37
+ tensor.inputs[0].shape.shape
38
+ else
39
+ tensor.inputs[1].shape.shape
40
+ end
24
41
 
25
42
  possible_shape
26
43
  when :index
@@ -39,11 +56,11 @@ module TensorStream
39
56
 
40
57
  axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value
41
58
  new_shape = tensor.inputs[0].shape.shape
42
- new_shape.each_with_index.collect do |shape, index|
59
+ new_shape.each_with_index.collect { |shape, index|
43
60
  next nil if index == axis
44
61
 
45
62
  shape
46
- end.compact
63
+ }.compact
47
64
  when :mean, :prod, :sum, :arg_max
48
65
  return [] if tensor.inputs[1].nil?
49
66
  return nil if tensor.inputs[0].nil?
@@ -58,24 +75,14 @@ module TensorStream
58
75
  axis = [axis] unless axis.is_a?(Array)
59
76
  axis = axis.map { |a| a < 0 ? rank - a.abs : a }
60
77
 
61
- input_shape.each_with_index.map do |item, index|
78
+ input_shape.each_with_index.map { |item, index|
62
79
  if axis.include?(index)
63
80
  next 1 if tensor.options[:keepdims]
64
81
 
65
82
  next nil
66
83
  end
67
84
  item
68
- end.compact
69
- when :reshape
70
- new_shape = tensor.inputs[1] && tensor.inputs[1].const_value ? tensor.inputs[1].const_value : nil
71
- return nil if new_shape.nil?
72
- return nil if tensor.inputs[0].shape.nil?
73
-
74
- input_shape = tensor.inputs[0].shape.shape
75
- return new_shape if input_shape.nil? && !new_shape.include?(-1) && !new_shape.include?(nil)
76
- return nil if input_shape.nil? || input_shape.include?(nil)
77
-
78
- TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
85
+ }.compact
79
86
  when :flow_group
80
87
  []
81
88
  when :zeros, :ones, :fill, :random_standard_normal, :random_uniform, :truncated_normal
@@ -94,28 +101,6 @@ module TensorStream
94
101
  size = tensor.inputs[0].shape.shape.reduce(:*) || 1
95
102
  dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
96
103
  shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value))
97
- when :mat_mul
98
- return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
99
- return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
100
- return nil if tensor.inputs[0].shape.shape.size != 2 || tensor.inputs[1].shape.shape.size != 2
101
-
102
- shape1, m = if tensor.options[:transpose_a]
103
- [tensor.inputs[0].shape.shape[0], tensor.inputs[0].shape.shape[1]]
104
- else
105
- [tensor.inputs[0].shape.shape[1], tensor.inputs[0].shape.shape[0]]
106
- end
107
-
108
- shape2, n = if tensor.options[:transpose_b]
109
- [tensor.inputs[1].shape.shape[1], tensor.inputs[1].shape.shape[0]]
110
- else
111
- [tensor.inputs[1].shape.shape[0], tensor.inputs[1].shape.shape[1]]
112
- end
113
-
114
- return nil if shape1.nil? || shape2.nil? || shape1 < 0 || shape2 < 0
115
-
116
- raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{shape1} != #{shape2}) #{tensor.inputs[0].shape.shape} vs #{tensor.inputs[1].shape.shape}" if shape1 != shape2
117
-
118
- [m, n]
119
104
  when :transpose
120
105
  return nil unless shape_full_specified(tensor.inputs[0])
121
106
  return nil if tensor.inputs[1].is_a?(Tensor)
@@ -152,10 +137,6 @@ module TensorStream
152
137
  new_shape
153
138
  when :slice, :squeeze
154
139
  nil
155
- when :tile
156
- nil
157
- when :expand_dims
158
- nil
159
140
  when :broadcast_gradient_args
160
141
  nil
161
142
  when :no_op
@@ -168,8 +149,6 @@ module TensorStream
168
149
  return [tensor.inputs[0].const_value, tensor.inputs[1].const_value] if tensor.inputs[0].const_value && tensor.inputs[1].const_value
169
150
 
170
151
  nil
171
- when :size
172
- []
173
152
  when :unstack
174
153
  return nil unless tensor.inputs[0].shape.known?
175
154
 
@@ -190,10 +169,10 @@ module TensorStream
190
169
  strides = tensor.options[:strides]
191
170
 
192
171
  case tensor.options[:padding]
193
- when 'SAME'
172
+ when "SAME"
194
173
  new_shape[1] /= strides[1]
195
174
  new_shape[2] /= strides[2]
196
- when 'VALID'
175
+ when "VALID"
197
176
  new_shape[1] = (new_shape[1] - tensor.inputs[1].shape.shape[0]) / strides[1] + 1
198
177
  new_shape[2] = (new_shape[2] - tensor.inputs[1].shape.shape[1]) / strides[2] + 1
199
178
  else
@@ -206,11 +185,32 @@ module TensorStream
206
185
 
207
186
  tensor.inputs[0].const_value
208
187
  else
209
- return nil if tensor.inputs[0].nil?
210
- return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
211
-
212
- TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
188
+ TensorStream::OpMaker.infer_shape(self, tensor)
213
189
  end
214
190
  end
191
+
192
+ def self._infer_reduction_op_shape(tensor)
193
+ return [] if tensor.inputs[1].nil?
194
+ return nil if tensor.inputs[0].nil?
195
+ return nil unless tensor.inputs[0].shape.known?
196
+
197
+ input_shape = tensor.inputs[0].shape.shape
198
+ rank = input_shape.size
199
+
200
+ axis = tensor.inputs[1].const_value
201
+ return nil if axis.nil?
202
+
203
+ axis = [axis] unless axis.is_a?(Array)
204
+ axis = axis.map { |a| a < 0 ? rank - a.abs : a }
205
+
206
+ input_shape.each_with_index.map { |item, index|
207
+ if axis.include?(index)
208
+ next 1 if tensor.options[:keepdims]
209
+
210
+ next nil
211
+ end
212
+ item
213
+ }.compact
214
+ end
215
215
  end
216
- end
216
+ end