tensor_stream 1.0.0 → 1.0.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -58,4 +58,4 @@ module TensorStream
58
58
  @graph
59
59
  end
60
60
  end
61
- end
61
+ end
@@ -1,7 +1,7 @@
1
1
  module TensorStream
2
2
  class GraphKeys
3
- GLOBAL_VARIABLES = 'variables'.freeze
4
- TRAINABLE_VARIABLES = 'trainable_variables'.freeze
5
- GLOBAL_STEP = 'global_step'.freeze
3
+ GLOBAL_VARIABLES = "variables".freeze
4
+ TRAINABLE_VARIABLES = "trainable_variables".freeze
5
+ GLOBAL_STEP = "global_step".freeze
6
6
  end
7
7
  end
@@ -48,7 +48,7 @@ module TensorStream
48
48
  private
49
49
 
50
50
  def add_to_group(groups, name, arr_buf)
51
- name_parts = name.split('/')
51
+ name_parts = name.split("/")
52
52
  return false if name_parts.size < 2
53
53
 
54
54
  prefix = name_parts.shift
@@ -67,35 +67,35 @@ module TensorStream
67
67
  end
68
68
 
69
69
  def find_or_create_group(prefix, groups)
70
- if !groups[prefix]
71
- groups[prefix] = { buf: [], group: {} }
70
+ unless groups[prefix]
71
+ groups[prefix] = {buf: [], group: {}}
72
72
  end
73
73
 
74
- return groups[prefix]
74
+ groups[prefix]
75
75
  end
76
76
 
77
77
  def create_group(id, title, group)
78
78
  arr_buf = []
79
79
  arr_buf << "<node id=\"#{id}\" yfiles.foldertype=\"group\">"
80
80
  arr_buf << '<data key="d9">'
81
- arr_buf << '<y:ProxyAutoBoundsNode>'
81
+ arr_buf << "<y:ProxyAutoBoundsNode>"
82
82
  arr_buf << '<y:Realizers active="0">'
83
- arr_buf << '<y:GroupNode>'
83
+ arr_buf << "<y:GroupNode>"
84
84
  arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
85
85
  arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
86
- arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + '</y:NodeLabel>'
86
+ arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + "</y:NodeLabel>"
87
87
  arr_buf << '<y:Shape type="roundrectangle"/>'
88
- arr_buf << '</y:GroupNode>'
89
- arr_buf << '</y:Realizers>'
90
- arr_buf << '</y:ProxyAutoBoundsNode>'
91
- arr_buf << '</data>'
88
+ arr_buf << "</y:GroupNode>"
89
+ arr_buf << "</y:Realizers>"
90
+ arr_buf << "</y:ProxyAutoBoundsNode>"
91
+ arr_buf << "</data>"
92
92
  arr_buf << '<graph edgedefault="directed" id="n105:">'
93
93
  arr_buf << group[:buf]
94
94
  group[:group].each do |k, g|
95
95
  arr_buf << create_group(k, k, g)
96
96
  end
97
- arr_buf << '</graph>'
98
- arr_buf << '</node>'
97
+ arr_buf << "</graph>"
98
+ arr_buf << "</node>"
99
99
  arr_buf
100
100
  end
101
101
 
@@ -120,17 +120,17 @@ module TensorStream
120
120
  end
121
121
  node_buf << "<data key=\"d9\">"
122
122
  node_buf << "<y:ShapeNode>"
123
- if tensor.internal?
124
- node_buf << " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
123
+ node_buf << if tensor.internal?
124
+ " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
125
125
  else
126
- node_buf << " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
126
+ " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
127
127
  end
128
128
  node_buf << " <y:NodeLabel alignment=\"center\">#{tensor.operation}</y:NodeLabel>"
129
129
  node_buf << "</y:ShapeNode>"
130
130
  node_buf << "</data>"
131
131
  node_buf << "</node>"
132
132
 
133
- if !add_to_group(groups, tensor.name, node_buf)
133
+ unless add_to_group(groups, tensor.name, node_buf)
134
134
  add_to_group(groups, "program/#{tensor.name}", node_buf)
135
135
  end
136
136
 
@@ -176,10 +176,10 @@ module TensorStream
176
176
  input_buf << "<y:ShapeNode>"
177
177
 
178
178
  input_buf << if input.internal?
179
- " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
180
- else
181
- " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
182
- end
179
+ " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
180
+ else
181
+ " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
182
+ end
183
183
 
184
184
  input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
185
185
 
@@ -204,7 +204,7 @@ module TensorStream
204
204
  end
205
205
 
206
206
  def _gml_string(str)
207
- str.tr('/', '-')
207
+ str.tr("/", "-")
208
208
  end
209
209
 
210
210
  def output_edge(input, tensor, arr_buf, index = 0)
@@ -215,22 +215,22 @@ module TensorStream
215
215
  arr_buf << "<y:PolyLineEdge>"
216
216
  arr_buf << "<y:EdgeLabel >"
217
217
  arr_buf << if !@last_session_context.empty?
218
- "<![CDATA[ #{_val(input)} ]]>"
219
- elsif input.shape.shape.nil?
220
- "<![CDATA[ #{input.data_type} ? ]]>"
221
- else
222
- "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
223
- end
218
+ "<![CDATA[ #{_val(input)} ]]>"
219
+ elsif input.shape.shape.nil?
220
+ "<![CDATA[ #{input.data_type} ? ]]>"
221
+ else
222
+ "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? "scalar" : input.shape.shape.to_json} ]]>"
223
+ end
224
224
  arr_buf << "</y:EdgeLabel >"
225
225
  arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
226
226
  arr_buf << if index.zero?
227
- "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
228
- else
229
- "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
230
- end
227
+ "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
228
+ else
229
+ "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
230
+ end
231
231
  arr_buf << "</y:PolyLineEdge>"
232
232
  arr_buf << "</data>"
233
233
  arr_buf << "</edge>"
234
234
  end
235
235
  end
236
- end
236
+ end
@@ -1,4 +1,4 @@
1
- require 'base64'
1
+ require "base64"
2
2
 
3
3
  module TensorStream
4
4
  # Utility class to handle data type serialization
@@ -7,36 +7,36 @@ module TensorStream
7
7
  value = value.is_a?(Array) ? value.flatten : [value]
8
8
  byte_value = case data_type
9
9
  when :float64
10
- value.pack('d*')
10
+ value.pack("d*")
11
11
  when :float32, :float16, :float
12
- value.pack('f*')
12
+ value.pack("f*")
13
13
  when :uint32
14
- value.pack('L*')
14
+ value.pack("L*")
15
15
  when :int32, :int
16
- value.pack('l*')
16
+ value.pack("l*")
17
17
  when :int64
18
- value.pack('q*')
18
+ value.pack("q*")
19
19
  when :uint64
20
- value.pack('Q*')
20
+ value.pack("Q*")
21
21
  when :uint8
22
- value.pack('C*')
22
+ value.pack("C*")
23
23
  when :boolean
24
- value.map { |v| v ? 1 : 0 }.pack('C*')
24
+ value.map { |v| v ? 1 : 0 }.pack("C*")
25
25
  when :string
26
26
  if value.is_a?(Array)
27
- value.to_yaml
27
+ value.to_yaml
28
28
  else
29
- value
29
+ value
30
30
  end
31
31
  else
32
- raise "unknown type #{data_type}"
33
- end
32
+ raise "unknown type #{data_type}"
33
+ end
34
34
 
35
35
  byte_value
36
36
  end
37
37
 
38
38
  def self.pack_to_str(value, data_type)
39
- pack(value, data_type).bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
39
+ pack(value, data_type).bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join
40
40
  end
41
41
 
42
42
  def self.unpack_from_str(content, data_type)
@@ -47,22 +47,22 @@ module TensorStream
47
47
  def self.unpack(unpacked, data_type)
48
48
  case data_type
49
49
  when :float32, :float, :float16
50
- unpacked.unpack('f*')
50
+ unpacked.unpack("f*")
51
51
  when :float64
52
- unpacked.unpack('d*')
52
+ unpacked.unpack("d*")
53
53
  when :int32, :int
54
- unpacked.unpack('L*')
54
+ unpacked.unpack("L*")
55
55
  when :uint32
56
- unpacked.unpack('l*')
56
+ unpacked.unpack("l*")
57
57
  when :int64
58
- unpacked.unpack('q*')
58
+ unpacked.unpack("q*")
59
59
  when :uint64
60
- unpacked.unpack('Q*')
60
+ unpacked.unpack("Q*")
61
61
  when :uint8
62
- unpacked.unpack('C*')
62
+ unpacked.unpack("C*")
63
63
  when :boolean
64
- unpacked.unpack('C*').map { |v| v == 1 }
64
+ unpacked.unpack("C*").map { |v| v == 1 }
65
65
  end
66
66
  end
67
67
  end
68
- end
68
+ end
@@ -12,10 +12,10 @@ module TensorStream
12
12
 
13
13
  node_keys.each do |k|
14
14
  node = if block_given?
15
- yield graph, k
16
- else
17
- graph.get_tensor_by_name(k)
18
- end
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
19
 
20
20
  @lines << "node {"
21
21
  @lines << " name: #{node.name.to_json}"
@@ -27,13 +27,13 @@ module TensorStream
27
27
  @lines << " input: #{input.name.to_json}"
28
28
  end
29
29
  # type
30
- pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
30
+ pb_attr("T", "type: #{sym_to_protobuf_type(node.data_type)}")
31
31
 
32
32
  case node.operation.to_s
33
- when 'const'
34
- pb_attr('value', tensor_value(node))
35
- when 'variable_v2'
36
- pb_attr('shape', shape_buf(node, 'shape'))
33
+ when "const"
34
+ pb_attr("value", tensor_value(node))
35
+ when "variable_v2"
36
+ pb_attr("shape", shape_buf(node, "shape"))
37
37
  end
38
38
  process_options(node)
39
39
  end
@@ -50,7 +50,7 @@ module TensorStream
50
50
  def process_options(node)
51
51
  return if node.options.nil?
52
52
  node.options.reject { |_k, v| v.nil? }.each do |k, v|
53
- next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?('__')
53
+ next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?("__")
54
54
  @lines << " attr {"
55
55
  @lines << " key: \"#{k}\""
56
56
  @lines << " value {"
@@ -63,55 +63,51 @@ module TensorStream
63
63
  def attr_value(val, indent = 0)
64
64
  spaces = " " * indent
65
65
  case val.class.to_s
66
- when 'TrueClass', 'FalseClass'
66
+ when "TrueClass", "FalseClass"
67
67
  @lines << "#{spaces}b: #{val}"
68
- when 'Integer'
68
+ when "Integer"
69
69
  @lines << "#{spaces}i: #{val}"
70
- when 'String',
70
+ when "String",
71
71
  @lines << "#{spaces}s: #{val}"
72
- when 'Float'
72
+ when "Float"
73
73
  @lines << "#{spaces}f: #{val}"
74
- when 'Symbol'
74
+ when "Symbol"
75
75
  @lines << "#{spaces}sym: #{val}"
76
- when 'Array'
76
+ when "Array"
77
77
  @lines << "#{spaces}list {"
78
78
  val.each do |v_item|
79
79
  attr_value(v_item, indent + 2)
80
80
  end
81
81
  @lines << "#{spaces}}"
82
- when 'TensorStream::TensorShape'
82
+ when "TensorStream::TensorShape"
83
83
  @lines << "#{spaces}shape {"
84
- if val.shape
85
- val.shape.each do |dim|
86
- @lines << "#{spaces} dim {"
87
- @lines << "#{spaces} size: #{dim}"
88
- @lines << "#{spaces} }"
89
- end
84
+ val.shape&.each do |dim|
85
+ @lines << "#{spaces} dim {"
86
+ @lines << "#{spaces} size: #{dim}"
87
+ @lines << "#{spaces} }"
90
88
  end
91
89
  @lines << "#{spaces}}"
92
- when 'TensorStream::Variable'
90
+ when "TensorStream::Variable"
93
91
  else
94
92
  raise "unknown type #{val.class}"
95
93
  end
96
94
  end
97
95
 
98
96
  def pack_arr_float(float_arr)
99
- float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
97
+ float_arr.flatten.pack("f*").bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join
100
98
  end
101
99
 
102
100
  def pack_arr_int(int_arr)
103
- int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
101
+ int_arr.flatten.pack("l*").bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join
104
102
  end
105
103
 
106
- def shape_buf(tensor, shape_type = 'tensor_shape')
104
+ def shape_buf(tensor, shape_type = "tensor_shape")
107
105
  arr = []
108
106
  arr << " #{shape_type} {"
109
- if tensor.shape.shape
110
- tensor.shape.shape.each do |dim|
111
- arr << " dim {"
112
- arr << " size: #{dim}"
113
- arr << " }"
114
- end
107
+ tensor.shape.shape&.each do |dim|
108
+ arr << " dim {"
109
+ arr << " size: #{dim}"
110
+ arr << " }"
115
111
  end
116
112
  arr << " }"
117
113
  arr
@@ -140,14 +136,14 @@ module TensorStream
140
136
  end
141
137
  else
142
138
  val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
143
- "int_val"
144
- elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
145
- "float_val"
146
- elsif tensor.data_type == :string
147
- "string_val"
148
- else
149
- "val"
150
- end
139
+ "int_val"
140
+ elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
141
+ "float_val"
142
+ elsif tensor.data_type == :string
143
+ "string_val"
144
+ else
145
+ "val"
146
+ end
151
147
  arr << " #{val_type}: #{tensor.const_value.to_json}"
152
148
  end
153
149
  arr << "}"
@@ -186,4 +182,4 @@ module TensorStream
186
182
  @lines << " }"
187
183
  end
188
184
  end
189
- end
185
+ end
@@ -4,6 +4,7 @@ module TensorStream
4
4
  File.write(filename, get_string(tensor, session, graph_keys = nil))
5
5
  end
6
6
 
7
- def get_string(tensor, session = nil); end
7
+ def get_string(tensor, session = nil)
8
+ end
8
9
  end
9
- end
10
+ end
@@ -12,10 +12,10 @@ module TensorStream
12
12
 
13
13
  node_keys.each do |k|
14
14
  node = if block_given?
15
- yield graph, k
16
- else
17
- graph.get_tensor_by_name(k)
18
- end
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
19
  next unless node.is_a?(Operation)
20
20
 
21
21
  serialized_arr << node.to_h
@@ -24,4 +24,4 @@ module TensorStream
24
24
  serialized_arr.to_yaml
25
25
  end
26
26
  end
27
- end
27
+ end
@@ -1,4 +1,4 @@
1
- require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
1
+ require "tensor_stream/evaluator/operation_helpers/array_ops_helper"
2
2
  module TensorStream
3
3
  ##
4
4
  # Convenience class for guessing the shape of a tensor
@@ -9,18 +9,35 @@ module TensorStream
9
9
 
10
10
  def self.infer_shape(tensor)
11
11
  case tensor.operation
12
+ when :assign
13
+ possible_shape = if tensor.inputs[0]&.shape&.shape
14
+ tensor.inputs[0].shape.shape
15
+ else
16
+ tensor.inputs[1].shape.shape
17
+ end
18
+
19
+ possible_shape
20
+ when :const
21
+ shape_eval(tensor.options[:value])
22
+ when :variable_v2
23
+ tensor.shape ? tensor.shape.shape : nil
24
+ when :placeholder
25
+ return nil if tensor.inputs[0].nil?
26
+ return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
27
+
28
+ TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
12
29
  when :case, :case_grad
13
- tensor.inputs[2].shape.shape if tensor.inputs[2]
30
+ tensor.inputs[2]&.shape&.shape
14
31
  when :const
15
32
  shape_eval(tensor.options[:value])
16
33
  when :variable_v2
17
34
  tensor.shape ? tensor.shape.shape : nil
18
35
  when :assign
19
- possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
20
- tensor.inputs[0].shape.shape
21
- else
22
- tensor.inputs[1].shape.shape
23
- end
36
+ possible_shape = if tensor.inputs[0]&.shape&.shape
37
+ tensor.inputs[0].shape.shape
38
+ else
39
+ tensor.inputs[1].shape.shape
40
+ end
24
41
 
25
42
  possible_shape
26
43
  when :index
@@ -39,11 +56,11 @@ module TensorStream
39
56
 
40
57
  axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value
41
58
  new_shape = tensor.inputs[0].shape.shape
42
- new_shape.each_with_index.collect do |shape, index|
59
+ new_shape.each_with_index.collect { |shape, index|
43
60
  next nil if index == axis
44
61
 
45
62
  shape
46
- end.compact
63
+ }.compact
47
64
  when :mean, :prod, :sum, :arg_max
48
65
  return [] if tensor.inputs[1].nil?
49
66
  return nil if tensor.inputs[0].nil?
@@ -58,24 +75,14 @@ module TensorStream
58
75
  axis = [axis] unless axis.is_a?(Array)
59
76
  axis = axis.map { |a| a < 0 ? rank - a.abs : a }
60
77
 
61
- input_shape.each_with_index.map do |item, index|
78
+ input_shape.each_with_index.map { |item, index|
62
79
  if axis.include?(index)
63
80
  next 1 if tensor.options[:keepdims]
64
81
 
65
82
  next nil
66
83
  end
67
84
  item
68
- end.compact
69
- when :reshape
70
- new_shape = tensor.inputs[1] && tensor.inputs[1].const_value ? tensor.inputs[1].const_value : nil
71
- return nil if new_shape.nil?
72
- return nil if tensor.inputs[0].shape.nil?
73
-
74
- input_shape = tensor.inputs[0].shape.shape
75
- return new_shape if input_shape.nil? && !new_shape.include?(-1) && !new_shape.include?(nil)
76
- return nil if input_shape.nil? || input_shape.include?(nil)
77
-
78
- TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
85
+ }.compact
79
86
  when :flow_group
80
87
  []
81
88
  when :zeros, :ones, :fill, :random_standard_normal, :random_uniform, :truncated_normal
@@ -94,28 +101,6 @@ module TensorStream
94
101
  size = tensor.inputs[0].shape.shape.reduce(:*) || 1
95
102
  dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
96
103
  shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value))
97
- when :mat_mul
98
- return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
99
- return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
100
- return nil if tensor.inputs[0].shape.shape.size != 2 || tensor.inputs[1].shape.shape.size != 2
101
-
102
- shape1, m = if tensor.options[:transpose_a]
103
- [tensor.inputs[0].shape.shape[0], tensor.inputs[0].shape.shape[1]]
104
- else
105
- [tensor.inputs[0].shape.shape[1], tensor.inputs[0].shape.shape[0]]
106
- end
107
-
108
- shape2, n = if tensor.options[:transpose_b]
109
- [tensor.inputs[1].shape.shape[1], tensor.inputs[1].shape.shape[0]]
110
- else
111
- [tensor.inputs[1].shape.shape[0], tensor.inputs[1].shape.shape[1]]
112
- end
113
-
114
- return nil if shape1.nil? || shape2.nil? || shape1 < 0 || shape2 < 0
115
-
116
- raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{shape1} != #{shape2}) #{tensor.inputs[0].shape.shape} vs #{tensor.inputs[1].shape.shape}" if shape1 != shape2
117
-
118
- [m, n]
119
104
  when :transpose
120
105
  return nil unless shape_full_specified(tensor.inputs[0])
121
106
  return nil if tensor.inputs[1].is_a?(Tensor)
@@ -152,10 +137,6 @@ module TensorStream
152
137
  new_shape
153
138
  when :slice, :squeeze
154
139
  nil
155
- when :tile
156
- nil
157
- when :expand_dims
158
- nil
159
140
  when :broadcast_gradient_args
160
141
  nil
161
142
  when :no_op
@@ -168,8 +149,6 @@ module TensorStream
168
149
  return [tensor.inputs[0].const_value, tensor.inputs[1].const_value] if tensor.inputs[0].const_value && tensor.inputs[1].const_value
169
150
 
170
151
  nil
171
- when :size
172
- []
173
152
  when :unstack
174
153
  return nil unless tensor.inputs[0].shape.known?
175
154
 
@@ -190,10 +169,10 @@ module TensorStream
190
169
  strides = tensor.options[:strides]
191
170
 
192
171
  case tensor.options[:padding]
193
- when 'SAME'
172
+ when "SAME"
194
173
  new_shape[1] /= strides[1]
195
174
  new_shape[2] /= strides[2]
196
- when 'VALID'
175
+ when "VALID"
197
176
  new_shape[1] = (new_shape[1] - tensor.inputs[1].shape.shape[0]) / strides[1] + 1
198
177
  new_shape[2] = (new_shape[2] - tensor.inputs[1].shape.shape[1]) / strides[2] + 1
199
178
  else
@@ -206,11 +185,32 @@ module TensorStream
206
185
 
207
186
  tensor.inputs[0].const_value
208
187
  else
209
- return nil if tensor.inputs[0].nil?
210
- return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
211
-
212
- TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
188
+ TensorStream::OpMaker.infer_shape(self, tensor)
213
189
  end
214
190
  end
191
+
192
+ def self._infer_reduction_op_shape(tensor)
193
+ return [] if tensor.inputs[1].nil?
194
+ return nil if tensor.inputs[0].nil?
195
+ return nil unless tensor.inputs[0].shape.known?
196
+
197
+ input_shape = tensor.inputs[0].shape.shape
198
+ rank = input_shape.size
199
+
200
+ axis = tensor.inputs[1].const_value
201
+ return nil if axis.nil?
202
+
203
+ axis = [axis] unless axis.is_a?(Array)
204
+ axis = axis.map { |a| a < 0 ? rank - a.abs : a }
205
+
206
+ input_shape.each_with_index.map { |item, index|
207
+ if axis.include?(index)
208
+ next 1 if tensor.options[:keepdims]
209
+
210
+ next nil
211
+ end
212
+ item
213
+ }.compact
214
+ end
215
215
  end
216
- end
216
+ end