tensor_stream 0.7.0 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. checksums.yaml +5 -5
  2. data/.rubocop.yml +6 -1
  3. data/CHANGELOG.md +10 -0
  4. data/README.md +35 -0
  5. data/lib/tensor_stream.rb +2 -2
  6. data/lib/tensor_stream/debugging/debugging.rb +2 -1
  7. data/lib/tensor_stream/dynamic_stitch.rb +23 -24
  8. data/lib/tensor_stream/evaluator/base_evaluator.rb +27 -18
  9. data/lib/tensor_stream/evaluator/opencl/kernels/apply_momentum.cl +16 -0
  10. data/lib/tensor_stream/evaluator/opencl/kernels/pack.cl +24 -0
  11. data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +6 -1
  12. data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +6 -6
  13. data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +237 -107
  14. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +97 -7
  15. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +230 -123
  16. data/lib/tensor_stream/exceptions.rb +1 -0
  17. data/lib/tensor_stream/graph_builder.rb +2 -3
  18. data/lib/tensor_stream/graph_deserializers/protobuf.rb +22 -23
  19. data/lib/tensor_stream/graph_serializers/graphml.rb +26 -29
  20. data/lib/tensor_stream/graph_serializers/pbtext.rb +22 -19
  21. data/lib/tensor_stream/helpers/string_helper.rb +4 -5
  22. data/lib/tensor_stream/math_gradients.rb +141 -77
  23. data/lib/tensor_stream/nn/nn_ops.rb +4 -6
  24. data/lib/tensor_stream/operation.rb +139 -120
  25. data/lib/tensor_stream/ops.rb +36 -3
  26. data/lib/tensor_stream/session.rb +7 -11
  27. data/lib/tensor_stream/tensor.rb +3 -3
  28. data/lib/tensor_stream/tensor_shape.rb +5 -0
  29. data/lib/tensor_stream/train/gradient_descent_optimizer.rb +4 -37
  30. data/lib/tensor_stream/train/momentum_optimizer.rb +48 -0
  31. data/lib/tensor_stream/train/optimizer.rb +129 -0
  32. data/lib/tensor_stream/train/saver.rb +0 -1
  33. data/lib/tensor_stream/train/slot_creator.rb +62 -0
  34. data/lib/tensor_stream/train/utils.rb +11 -12
  35. data/lib/tensor_stream/trainer.rb +3 -0
  36. data/lib/tensor_stream/utils.rb +18 -11
  37. data/lib/tensor_stream/variable.rb +19 -12
  38. data/lib/tensor_stream/variable_scope.rb +1 -1
  39. data/lib/tensor_stream/version.rb +1 -1
  40. data/samples/iris.rb +2 -1
  41. data/samples/linear_regression.rb +3 -1
  42. data/samples/nearest_neighbor.rb +2 -0
  43. data/test_samples/neural_network_raw.py +101 -0
  44. data/test_samples/raw_neural_net_sample.rb +6 -4
  45. data/test_samples/test2.py +73 -27
  46. metadata +9 -3
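
The headline change in 0.8.0 is the new training infrastructure: train/optimizer.rb introduces a shared Optimizer base class, train/slot_creator.rb adds TensorFlow-style slot variables, and train/momentum_optimizer.rb builds on both (gradient_descent_optimizer.rb drops 37 lines in the refactor). A minimal sketch of how the new optimizer would plug into a training loop, assuming the constructor mirrors TensorFlow's (learning_rate, momentum) signature and that minimize follows the TensorFlow API; the model and data are illustrative only:

    require 'tensor_stream'

    ts = TensorStream
    x = ts.placeholder(:float32)
    y = ts.placeholder(:float32)
    w = ts.variable(0.5, name: 'w')
    b = ts.variable(0.0, name: 'b')

    loss = ts.reduce_mean(ts.square(w * x + b - y))
    # New in 0.8.0 alongside GradientDescentOptimizer.
    train_op = TensorStream::Train::MomentumOptimizer.new(0.01, 0.9).minimize(loss)

    sess = ts.session
    sess.run(ts.global_variables_initializer)
    10.times { sess.run(train_op, feed_dict: { x => [1.0, 2.0, 3.0], y => [2.0, 4.0, 6.0] }) }
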
data/lib/tensor_stream/exceptions.rb
@@ -3,4 +3,5 @@ module TensorStream
  class KeyError < TensorStreamError; end
  class ValueError < TensorStreamError; end
  class InvalidArgumentError < TensorStreamError; end
+ class NotImplementedError < TensorStreamError; end
  end
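
Since the new NotImplementedError subclasses TensorStreamError, existing rescue blocks that catch the base class continue to work. A tiny illustrative sketch (the raise site is hypothetical):

    begin
      raise TensorStream::NotImplementedError, 'op not supported by this evaluator'
    rescue TensorStream::TensorStreamError => e
      puts e.message  # catches the new subclass too
    end
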
data/lib/tensor_stream/graph_builder.rb
@@ -36,9 +36,8 @@ module TensorStream
  TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
  else
  op = underscore(node['op']).to_sym
- unless TensorStream::Evaluator::RubyEvaluator.ops.keys.include?(op)
- puts "warning unsupported op #{op}"
- end
+ puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
+
  # map input tensor
  inputs = node['input'].map do |input|
  input[0] = '' if input.start_with?('^')
data/lib/tensor_stream/graph_deserializers/protobuf.rb
@@ -23,15 +23,15 @@ module TensorStream
  end

  def parse_value(value_node)
- if value_node['tensor']
- evaluate_tensor_node(value_node['tensor'])
- end
+ return unless value_node['tensor']
+
+ evaluate_tensor_node(value_node['tensor'])
  end

  def evaluate_tensor_node(node)
  if !node['shape'].empty? && node['tensor_content']
  content = node['tensor_content']
- unpacked = eval(%Q{"#{content}"})
+ unpacked = eval(%Q("#{content}"))

  if node['dtype'] == 'DT_FLOAT'
  TensorShape.reshape(unpacked.unpack('f*'), node['shape'])
@@ -45,14 +45,14 @@ module TensorStream
  else

  val = if node['dtype'] == 'DT_FLOAT'
- node['float_val'] ? node['float_val'].to_f : []
- elsif node['dtype'] == 'DT_INT32'
- node['int_val'] ? node['int_val'].to_i : []
- elsif node['dtype'] == 'DT_STRING'
- node['string_val']
- else
- raise "unknown dtype #{node['dtype']}"
- end
+ node['float_val'] ? node['float_val'].to_f : []
+ elsif node['dtype'] == 'DT_INT32'
+ node['int_val'] ? node['int_val'].to_i : []
+ elsif node['dtype'] == 'DT_STRING'
+ node['string_val']
+ else
+ raise "unknown dtype #{node['dtype']}"
+ end

  if node['shape'] == [1]
  [val]
@@ -83,7 +83,7 @@ module TensorStream
  return {} if node['attributes'].nil?

  node['attributes'].map do |attribute|
- attr_type, attr_value = attribute['value'].collect { |k, v| [k, v] }.flatten(1)
+ attr_type, attr_value = attribute['value'].flat_map { |k, v| [k, v] }

  if attr_type == 'tensor'
  attr_value = evaluate_tensor_node(attr_value)
@@ -103,11 +103,10 @@ module TensorStream
  block = []
  node = {}
  node_attr = {}
- dim = []
  state = :top

  lines.each do |str|
- case(state)
+ case state
  when :top
  node['type'] = parse_node_name(str)
  state = :node_context
@@ -177,7 +176,7 @@ module TensorStream
  next
  else
  key, value = str.split(':', 2)
- node_attr['value'] << { key => value}
+ node_attr['value'] << { key => value }
  end
  when :tensor_context
  if str == 'tensor_shape {'
@@ -219,7 +218,7 @@ module TensorStream
  state = :shape_context
  next
  else
- key, value = str.split(':', 2)
+ _key, value = str.split(':', 2)
  node_attr['value']['shape'] << value.strip.to_i
  end
  when :tensor_shape_dim_context
@@ -227,7 +226,7 @@ module TensorStream
  state = :tensor_shape_context
  next
  else
- key, value = str.split(':', 2)
+ _key, value = str.split(':', 2)
  node_attr['value']['tensor']['shape'] << value.strip.to_i
  end
  end
@@ -237,7 +236,7 @@ module TensorStream
  end

  def parse_node_name(str)
- name = str.split(' ')[0]
+ str.split(' ')[0]
  end

  def process_value(value)
@@ -253,19 +252,19 @@ module TensorStream
  'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
  'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
  "\"" => "\x22", "'" => "\x27"
- }
+ }.freeze

  def unescape(str)
  # Escape all the things
- str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) {
+ str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) do
  if $1
  $1 == '\\' ? '\\' : UNESCAPES[$1]
  elsif $2 # escape \u0000 unicode
- ["#$2".hex].pack('U*')
+ ["#{$2}".hex].pack('U*')
  elsif $3 # escape \0xff or \xff
  [$3].pack('H2')
  end
- }
+ end
  end
  end
  end
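
The unescape rewrite above only swaps the block braces for do...end and freezes the lookup table; the decoding behavior is unchanged. A self-contained sketch of that logic, with a trimmed-down escape table for illustration:

    UNESCAPES = { 'n' => "\x0a", 't' => "\x09", 'r' => "\x0d" }.freeze

    def unescape(str)
      str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) do
        if $1
          UNESCAPES[$1]        # named escape, e.g. \n
        elsif $2
          [$2.hex].pack('U*')  # unicode escape, e.g. \u0042 => "B"
        elsif $3
          [$3].pack('H2')      # hex byte escape, e.g. \x43 => "C"
        end
      end
    end

    unescape('A\\n\\u0042\\x43')  # => "A\nBC"
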
data/lib/tensor_stream/graph_serializers/graphml.rb
@@ -83,7 +83,7 @@ module TensorStream
  arr_buf << '<y:GroupNode>'
  arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
  arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
- arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+ arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + '</y:NodeLabel>'
  arr_buf << '<y:Shape type="roundrectangle"/>'
  arr_buf << '</y:GroupNode>'
  arr_buf << '</y:Realizers>'
@@ -146,9 +146,9 @@ module TensorStream
  input_buf << "<node id=\"#{_gml_string(input.name)}\">"
  input_buf << "<data key=\"d0\">#{input.name}</data>"
  input_buf << "<data key=\"d2\">green</data>"
- if @last_session_context[input.name]
- input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
- end
+
+ input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
  input_buf << "<data key=\"d9\">"
  input_buf << "<y:ShapeNode>"
  input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
@@ -164,9 +164,9 @@ module TensorStream
  input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
  input_buf << "</y:ShapeNode>"
  input_buf << "</data>"
- if @last_session_context[input.name]
- input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
- end
+
+ input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
  input_buf << "</node>"
  elsif input.is_a?(Tensor)
  input_buf << "<node id=\"#{_gml_string(input.name)}\">"
@@ -175,12 +175,11 @@ module TensorStream
  input_buf << "<data key=\"d9\">"
  input_buf << "<y:ShapeNode>"

- if input.internal?
- input_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
- else
- input_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
- end
-
+ input_buf << if input.internal?
+ " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+ else
+ " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+ end

  input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"

@@ -189,7 +188,7 @@ module TensorStream
  input_buf << "</node>"
  end

- if !add_to_group(groups, input.name, input_buf)
+ unless add_to_group(groups, input.name, input_buf)
  if input.is_a?(Variable)
  add_to_group(groups, "variable/#{input.name}", input_buf)
  else
@@ -205,7 +204,7 @@ module TensorStream
  end

  def _gml_string(str)
- str.gsub('/','-')
+ str.tr('/', '-')
  end

  def output_edge(input, tensor, arr_buf, index = 0)
@@ -215,22 +214,20 @@ module TensorStream

  arr_buf << "<y:PolyLineEdge>"
  arr_buf << "<y:EdgeLabel >"
- if !@last_session_context.empty?
- arr_buf << "<![CDATA[ #{_val(input)} ]]>"
- else
- if input.shape.shape.nil?
- arr_buf << "<![CDATA[ #{input.data_type.to_s} ? ]]>"
- else
- arr_buf << "<![CDATA[ #{input.data_type.to_s} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
- end
- end
+ arr_buf << if !@last_session_context.empty?
+ "<![CDATA[ #{_val(input)} ]]>"
+ elsif input.shape.shape.nil?
+ "<![CDATA[ #{input.data_type} ? ]]>"
+ else
+ "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
+ end
  arr_buf << "</y:EdgeLabel >"
  arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
- if index == 0
- arr_buf << "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
- else
- arr_buf << "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
- end
+ arr_buf << if index.zero?
+ "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+ else
+ "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+ end
  arr_buf << "</y:PolyLineEdge>"
  arr_buf << "</data>"
  arr_buf << "</edge>"
data/lib/tensor_stream/graph_serializers/pbtext.rb
@@ -1,4 +1,5 @@
  module TensorStream
+ # Parses pbtext files and loads it as a graph
  class Pbtext < TensorStream::Serializer
  include TensorStream::StringHelper
  include TensorStream::OpHelper
@@ -47,11 +48,11 @@ module TensorStream
  @lines << " attr {"
  @lines << " key: \"#{k}\""
  @lines << " value {"
- if (v.is_a?(TrueClass) || v.is_a?(FalseClass))
- @lines << " b: #{v.to_s}"
- elsif (v.is_a?(Integer))
+ if v.is_a?(TrueClass) || v.is_a?(FalseClass)
+ @lines << " b: #{v}"
+ elsif v.is_a?(Integer)
  @lines << " int_val: #{v}"
- elsif (v.is_a?(Float))
+ elsif v.is_a?(Float)
  @lines << " float_val: #{v}"
  end
  @lines << " }"
@@ -60,21 +61,23 @@ module TensorStream
  end

  def pack_arr_float(float_arr)
- float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
  end

  def pack_arr_int(int_arr)
- int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
  end

  def shape_buf(tensor, shape_type = 'tensor_shape')
  arr = []
  arr << " #{shape_type} {"
- tensor.shape.shape.each do |dim|
- arr << " dim {"
- arr << " size: #{dim}"
- arr << " }"
- end if tensor.shape.shape
+ if tensor.shape.shape
+ tensor.shape.shape.each do |dim|
+ arr << " dim {"
+ arr << " size: #{dim}"
+ arr << " }"
+ end
+ end
  arr << " }"
  arr
  end
@@ -102,14 +105,14 @@ module TensorStream
  end
  else
  val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
- "int_val"
- elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
- "float_val"
- elsif tensor.data_type == :string
- "string_val"
- else
- "val"
- end
+ "int_val"
+ elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+ "float_val"
+ elsif tensor.data_type == :string
+ "string_val"
+ else
+ "val"
+ end
  arr << " #{val_type}: #{tensor.value.to_json}"
  end
  arr << "}"
data/lib/tensor_stream/helpers/string_helper.rb
@@ -12,11 +12,10 @@ module TensorStream
  end

  def underscore(string)
- string.gsub(/::/, '/').
- gsub(/([A-Z]+)([A-Z][a-z])/,'\1_\2').
- gsub(/([a-z\d])([A-Z])/,'\1_\2').
- tr("-", "_").
- downcase
+ string.gsub(/::/, '/')
+ .gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
+ .gsub(/([a-z\d])([A-Z])/, '\1_\2')
+ .tr("-", "_").downcase
  end

  def symbolize_keys(hash)
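
underscore keeps the same behavior after the chaining cleanup; it is what maps CamelCase protobuf op names onto the gem's snake_case op symbols in the deserializers. For reference:

    underscore('TensorStream::MatMul')  # => "tensor_stream/mat_mul"
    underscore('Conv2D')                # => "conv2_d"
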
data/lib/tensor_stream/math_gradients.rb
@@ -1,9 +1,10 @@
  module TensorStream
  # Class that provides auto-differentiation
+ # Most gradients are ported over from tensorflow's math_grad.py
  class MathGradients
  extend TensorStream::OpHelper

- def self.tf
+ def self.ts
  TensorStream
  end

@@ -16,7 +17,7 @@ module TensorStream
  node.consumers.include?(tensor.name) || node.equal?(tensor)
  end.compact + [wrt_dx.name]

- grad = i_op(:fill, tf.shape(tensor), tf.constant(1, dtype: wrt_dx.data_type))
+ grad = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))

  _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || []) || i_op(:zeros_like, wrt_dx)
  end
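
The renamed ts helper is just the TensorStream module itself; derivative seeds backpropagation with a ones tensor shaped like the output (the i_op(:fill, ...) line above) and walks the graph backwards. From user code this machinery is reached through the gradients op, roughly like so (a sketch, assuming the gem's TensorFlow-style gradients API):

    ts = TensorStream
    x = ts.constant([2.0, 3.0])
    y = ts.reduce_sum(ts.square(x))

    grads = ts.gradients(y, [x])
    ts.session.run(grads)  # => [[4.0, 6.0]], since dy/dx = 2x
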
@@ -41,6 +42,7 @@ module TensorStream
  end
  end

+ #TODO: refactor and implement registerGradient
  def self._compute_derivative(node, grad)
  node.graph.name_scope("#{node.name}_grad") do
  x = node.inputs[0] if node.inputs[0]
@@ -51,116 +53,161 @@ module TensorStream
  return [grad] * node.inputs.size
  when :add
  return [grad, grad] if shapes_fully_specified_and_equal(x, y)
- sx = tf.shape(x, name: 'add/shape_x')
- sy = tf.shape(y, name: 'add/shape_y')
+ sx = ts.shape(x, name: 'add/shape_x')
+ sy = ts.shape(y, name: 'add/shape_y')
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
- tf.reshape(tf.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
+ [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
+ ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
  when :asin
- tf.control_dependencies([grad]) do
- x2 = tf.square(x)
- one = tf.constant(1, dtype: grad.data_type)
- den = tf.sqrt(tf.subtract(one, x2))
- inv = tf.reciprocal(den)
+ ts.control_dependencies([grad]) do
+ x2 = ts.square(x)
+ one = ts.constant(1, dtype: grad.data_type)
+ den = ts.sqrt(ts.subtract(one, x2))
+ inv = ts.reciprocal(den)
  grad * inv
  end
  when :acos
- tf.control_dependencies([grad]) do
- x2 = tf.square(x)
- one = tf.constant(1, dtype: grad.data_type)
- den = tf.sqrt(tf.subtract(one, x2))
- inv = tf.reciprocal(den)
+ ts.control_dependencies([grad]) do
+ x2 = ts.square(x)
+ one = ts.constant(1, dtype: grad.data_type)
+ den = ts.sqrt(ts.subtract(one, x2))
+ inv = ts.reciprocal(den)
  -grad * inv
  end
+ when :atan
+ ts.control_dependencies([grad]) do
+ x2 = ts.square(x)
+ one = ts.constant(1, dtype: grad.data_type)
+ inv = ts.reciprocal(ts.add(one, x2))
+ grad * inv
+ end
+ when :fill
+ [nil, ts.reduce_sum(grad)]
  when :sub
  return [grad, -grad] if shapes_fully_specified_and_equal(x, y)

- sx = tf.shape(x, name: 'sub/shape_x')
- sy = tf.shape(y, name: 'sub/shape_y')
+ sx = ts.shape(x, name: 'sub/shape_x')
+ sy = ts.shape(y, name: 'sub/shape_y')
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
- -tf.reshape(tf.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
+ [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
+ -ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
  when :mul
- sx = tf.shape(x)
- sy = tf.shape(y)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(tf.mul(grad, y), rx), sx),
- tf.reshape(tf.reduce_sum(tf.mul(x, grad), ry), sy)]
+ [ts.reshape(ts.reduce_sum(ts.mul(grad, y), rx), sx),
+ ts.reshape(ts.reduce_sum(ts.mul(x, grad), ry), sy)]
  when :div
  sx = i_op(:shape, x)
  sy = i_op(:shape, y)
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(tf.div(grad, y), rx), sx),
- tf.reshape(tf.reduce_sum(grad * tf.div(tf.div(-x, y), y), ry), sy)]
+ [ts.reshape(ts.reduce_sum(ts.div(grad, y), rx), sx),
+ ts.reshape(ts.reduce_sum(grad * ts.div(ts.div(-x, y), y), ry), sy)]
  when :mod
- sx = tf.shape(x)
- sy = tf.shape(y)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
  rx, ry = _broadcast_gradient_args(sx, sy)
- floor_xy = tf.floor_div(x, y)
- gx = tf.reshape(tf.reduce_sum(grad, rx), sx)
- gy = tf.reshape(tf.reduce_sum(grad * tf.negative(floor_xy), ry), sy)
+ floor_xy = ts.floor_div(x, y)
+ gx = ts.reshape(ts.reduce_sum(grad, rx), sx)
+ gy = ts.reshape(ts.reduce_sum(grad * ts.negative(floor_xy), ry), sy)

  [gx, gy]
+ when :prod
+ input_shape = ts.shape(x)
+ y = ts.range(0, ts.rank(x)) if y.nil?
+ reduction_indices = ts.reshape(y, [-1])
+
+ output_shape_kept_dims = ts.reduced_shape(input_shape, y)
+ tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
+ grad = ts.reshape(grad, output_shape_kept_dims)
+ grad = ts.tile(grad, tile_scaling)
+
+ perm, reduced_num, other_num = ts.device("/cpu:0") do
+ rank = ts.rank(x)
+ reduction_indices = (reduction_indices + rank) % rank
+ reduced = ts.cast(reduction_indices, :int32)
+ idx = ts.range(0, rank)
+ other, = ts.setdiff1d(idx, reduced)
+
+ [ts.concat([reduced, other], 0),
+ ts.reduce_prod(ts.gather(input_shape, reduced)),
+ ts.reduce_prod(ts.gather(input_shape, other))]
+ end
+
+ permuted = ts.transpose(x, perm)
+ permuted_shape = ts.shape(permuted)
+
+ reshaped = ts.reshape(permuted, [reduced_num, other_num])
+
+ # Calculate product, leaving out the current entry
+ left = ts.cumprod(reshaped, axis: 0, exclusive: true)
+ right = ts.cumprod(reshaped, axis: 0, exclusive: true, reverse: true)
+ y = ts.reshape(left * right, permuted_shape)
+
+ # Invert the transpose and reshape operations.
+ # Make sure to set the statically known shape information through a reshape.
+ out = grad * ts.transpose(y, ts.invert_permutation(perm))
+ [ts.reshape(out, input_shape, name: 'prod'), nil]
  when :squared_difference
  sx = i_op(:shape, x)
  sy = i_op(:shape, y)
  rx, ry = _broadcast_gradient_args(sx, sy)

- x_grad = tf.mul(2.0, grad) * (x - y)
+ x_grad = ts.mul(2.0, grad) * (x - y)

- [tf.reshape(tf.reduce_sum(x_grad, rx), sx),
- tf.reshape(-tf.reduce_sum(x_grad, ry), sy)]
+ [ts.reshape(ts.reduce_sum(x_grad, rx), sx),
+ ts.reshape(-ts.reduce_sum(x_grad, ry), sy)]
  when :mat_mul
  t_a = node.options[:transpose_a]
  t_b = node.options[:transpose_b]

  if !t_a && !t_b
- grad_a = tf.matmul(grad, y, transpose_b: true)
- grad_b = tf.matmul(x, grad, transpose_a: true)
+ grad_a = ts.matmul(grad, y, transpose_b: true)
+ grad_b = ts.matmul(x, grad, transpose_a: true)
  elsif !ta && tb
- grad_a = tf.matmul(grad, y)
- grad_b = tf.matmul(grad, x, transpose_a: true)
+ grad_a = ts.matmul(grad, y)
+ grad_b = ts.matmul(grad, x, transpose_a: true)
  elsif t_a && !t_b
- grad_a = tf.matmul(y, grad, transpose_b: true)
- grad_b = tf.matmul(x, grad)
+ grad_a = ts.matmul(y, grad, transpose_b: true)
+ grad_b = ts.matmul(x, grad)
  elsif t_a && t_b
- grad_a = tf.matmul(y, grad, transpose_a: true, transpose_b: true)
- grad_b = tf.matmul(grad, x, transpose_a: true, transpose_b: true)
+ grad_a = ts.matmul(y, grad, transpose_a: true, transpose_b: true)
+ grad_b = ts.matmul(grad, x, transpose_a: true, transpose_b: true)
  end

  [grad_a, grad_b]
  when :sin
- grad * tf.cos(x)
+ grad * ts.cos(x)
  when :tanh
  grad * i_op(:tanh_grad, x)
  when :pow
  z = node
- sx = tf.shape(x)
- sy = tf.shape(y)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
  rx, ry = _broadcast_gradient_args(sx, sy)
- gx = tf.reduce_sum(grad * y * tf.pow(x, y - 1), rx)
+ gx = ts.reduce_sum(grad * y * ts.pow(x, y - 1), rx)

- log_x = tf.where(x > 0, tf.log(x), tf.zeros_like(x))
- gy = tf.reduce_sum(grad * z * log_x, ry)
+ log_x = ts.where(x > 0, ts.log(x), ts.zeros_like(x))
+ gy = ts.reduce_sum(grad * z * log_x, ry)

  [gx, gy]
  when :abs
- grad * tf.sign(x)
+ grad * ts.sign(x)
  when :log
- grad * tf.reciprocal(x)
+ grad * ts.reciprocal(x)
  when :cos
- -grad * tf.sin(x)
+ -grad * ts.sin(x)
  when :max
- _min_or_max_grad(node.inputs, grad, ->(x, y) { tf.greater_equal(x, y) } )
+ _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.greater_equal(a, b) })
  when :min
- _min_or_max_grad(node.inputs, grad, ->(x, y) { tf.less_equal(x, y) } )
+ _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.less_equal(a, b) })
  when :tan
- secx = tf.reciprocal(tf.cos(x))
- secx2 = tf.square(secx)
+ secx = ts.reciprocal(ts.cos(x))
+ secx2 = ts.square(secx)
  grad * secx2
  when :negate
  -grad
169
216
  when :identity, :print
170
217
  grad
171
218
  when :sign
172
- tf.zeros(tf.shape(x), dtype: x.data_type)
219
+ ts.zeros(ts.shape(x), dtype: x.data_type)
220
+ when :tile
221
+ input_shape = ts.shape(x)
222
+ split_shape = ts.reshape(ts.transpose(ts.stack([y, input_shape])), [-1])
223
+ axes = ts.range(0, ts.size(split_shape), 2)
224
+ input_grad = ts.reduce_sum(ts.reshape(grad, split_shape), axes)
225
+
226
+ [input_grad, nil]
173
227
  when :sum
174
228
  _sum_grad(x, y, grad)
175
229
  when :reciprocal
176
- -grad * (tf.constant(1, dtype: x.dtype) / x**2)
230
+ -grad * (ts.constant(1, dtype: x.dtype) / x**2)
177
231
  when :sqrt
178
- tf.constant(1, dtype: x.dtype) / (tf.constant(2, dtype: x.dtype) * tf.sqrt(x)) * grad
232
+ ts.constant(1, dtype: x.dtype) / (ts.constant(2, dtype: x.dtype) * ts.sqrt(x)) * grad
179
233
  when :stop_gradient
180
- tf.zeros_like(grad)
234
+ ts.zeros_like(grad)
181
235
  when :square
182
- y = tf.constant(2.0, dtype: x.dtype)
183
- tf.multiply(grad, tf.multiply(x, y))
236
+ y = ts.constant(2.0, dtype: x.dtype)
237
+ ts.multiply(grad, ts.multiply(x, y))
184
238
  when :where
185
239
  x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
186
240
  y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
@@ -191,12 +245,12 @@ module TensorStream
191
245
  [x_cond * grad, y_cond * grad]
192
246
  when :mean
193
247
  sum_grad = _sum_grad(x, y, grad)[0]
194
- input_shape = tf.shape(x)
195
- output_shape = tf.shape(node)
196
- factor = _safe_shape_div(tf.reduce_prod(input_shape), tf.reduce_prod(output_shape))
197
- tf.div(sum_grad, tf.cast(factor, sum_grad.data_type))
248
+ input_shape = ts.shape(x)
249
+ output_shape = ts.shape(node)
250
+ factor = _safe_shape_div(ts.reduce_prod(input_shape), ts.reduce_prod(output_shape))
251
+ [ts.div(sum_grad, ts.cast(factor, sum_grad.data_type)), nil]
198
252
  when :log1p
199
- grad * tf.reciprocal(i_cons(1, dtype: grad.data_type) + x)
253
+ grad * ts.reciprocal(i_cons(1, dtype: grad.data_type) + x)
200
254
  when :sigmoid
201
255
  i_op(:sigmoid_grad, x, grad)
202
256
  when :sigmoid_grad
@@ -205,7 +259,8 @@ module TensorStream
205
259
  when :softmax
206
260
  i_op(:softmax_grad, x, grad)
207
261
  when :softmax_cross_entropy_with_logits_v2
208
- [i_op(:softmax_cross_entropy_with_logits_v2_grad, x, y, grad), nil]
262
+ output = node
263
+ [_broadcast_mul(grad, output[1]), nil]
209
264
  when :floor, :ceil
210
265
  # non differentiable
211
266
  nil
@@ -215,6 +270,10 @@ module TensorStream
215
270
  when :argmin, :argmax, :floor_div
216
271
  # non differentiable
217
272
  [nil, nil]
273
+ when :transpose
274
+ return [ts.transpose(grad, ts.invert_permutation(y)), nil]
275
+ when :index
276
+ grad
218
277
  else
219
278
  raise "no derivative op for #{node.operation}"
220
279
  end
@@ -231,12 +290,12 @@ module TensorStream
231
290
  end
232
291
 
233
292
  def self._safe_shape_div(arg_x, arg_y)
234
- _op(:floor_div, arg_x, tf.maximum(arg_y, 1))
293
+ _op(:floor_div, arg_x, ts.maximum(arg_y, 1))
235
294
  end
236
295
 
237
296
  def self._sum_grad(arg_x, arg_y, grad)
238
297
  input_shape = _op(:shape, arg_x)
239
- output_shape_kept_dims = tf.reduced_shape(input_shape, arg_y)
298
+ output_shape_kept_dims = ts.reduced_shape(input_shape, arg_y)
240
299
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
241
300
  new_grad = _op(:reshape, grad, output_shape_kept_dims)
242
301
 
@@ -254,19 +313,24 @@ module TensorStream
254
313
  x = inputs[0]
255
314
  y = inputs[1]
256
315
  gdtype = grad.data_type
257
- sx = tf.shape(x)
258
- sy = tf.shape(y)
259
- gradshape = tf.shape(grad)
260
- zeros = tf.zeros(gradshape, dtype: gdtype)
316
+ sx = ts.shape(x)
317
+ sy = ts.shape(y)
318
+ gradshape = ts.shape(grad)
319
+ zeros = ts.zeros(gradshape, dtype: gdtype)
261
320
  xmask = selector_op.call(x, y)
262
321
  rx, ry = _broadcast_gradient_args(sx, sy)
263
- xgrad = tf.where(xmask, grad, zeros, name: 'x')
264
- ygrad = tf.where(xmask, zeros, grad, name: 'y')
265
- gx = tf.reshape(tf.reduce_sum(xgrad, rx), sx)
266
- gy = tf.reshape(tf.reduce_sum(ygrad, ry), sy)
322
+ xgrad = ts.where(xmask, grad, zeros, name: 'x')
323
+ ygrad = ts.where(xmask, zeros, grad, name: 'y')
324
+ gx = ts.reshape(ts.reduce_sum(xgrad, rx), sx)
325
+ gy = ts.reshape(ts.reduce_sum(ygrad, ry), sy)
267
326
  [gx, gy]
268
327
  end
269
328
 
329
+ def self._broadcast_mul(vec, mat)
330
+ vec = ts.expand_dims(vec, -1)
331
+ vec * mat
332
+ end
333
+
270
334
  def self._include?(arr, obj)
271
335
  arr.each { |a| return true if a.equal?(obj) }
272
336
  false
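
The new _broadcast_mul helper turns a per-example gradient vector into a column via expand_dims(vec, -1) so broadcasting scales each row of the matrix; the reworked softmax_cross_entropy_with_logits_v2 gradient uses it to scale the op's second output. A small sketch of the semantics, assuming the Ruby evaluator broadcasts an [n, 1] tensor against [n, m] the way this code relies on:

    ts = TensorStream
    vec = ts.constant([1.0, 2.0])
    mat = ts.constant([[1.0, 1.0], [1.0, 1.0]])

    # expand_dims(vec, -1) => [[1.0], [2.0]]; the multiply then scales
    # each row of mat by the matching entry of vec.
    ts.session.run(ts.expand_dims(vec, -1) * mat)  # => [[1.0, 1.0], [2.0, 2.0]]
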