tensor_stream 0.7.0 → 0.8.0

Files changed (46)
  1. checksums.yaml +5 -5
  2. data/.rubocop.yml +6 -1
  3. data/CHANGELOG.md +10 -0
  4. data/README.md +35 -0
  5. data/lib/tensor_stream.rb +2 -2
  6. data/lib/tensor_stream/debugging/debugging.rb +2 -1
  7. data/lib/tensor_stream/dynamic_stitch.rb +23 -24
  8. data/lib/tensor_stream/evaluator/base_evaluator.rb +27 -18
  9. data/lib/tensor_stream/evaluator/opencl/kernels/apply_momentum.cl +16 -0
  10. data/lib/tensor_stream/evaluator/opencl/kernels/pack.cl +24 -0
  11. data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +6 -1
  12. data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +6 -6
  13. data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +237 -107
  14. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +97 -7
  15. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +230 -123
  16. data/lib/tensor_stream/exceptions.rb +1 -0
  17. data/lib/tensor_stream/graph_builder.rb +2 -3
  18. data/lib/tensor_stream/graph_deserializers/protobuf.rb +22 -23
  19. data/lib/tensor_stream/graph_serializers/graphml.rb +26 -29
  20. data/lib/tensor_stream/graph_serializers/pbtext.rb +22 -19
  21. data/lib/tensor_stream/helpers/string_helper.rb +4 -5
  22. data/lib/tensor_stream/math_gradients.rb +141 -77
  23. data/lib/tensor_stream/nn/nn_ops.rb +4 -6
  24. data/lib/tensor_stream/operation.rb +139 -120
  25. data/lib/tensor_stream/ops.rb +36 -3
  26. data/lib/tensor_stream/session.rb +7 -11
  27. data/lib/tensor_stream/tensor.rb +3 -3
  28. data/lib/tensor_stream/tensor_shape.rb +5 -0
  29. data/lib/tensor_stream/train/gradient_descent_optimizer.rb +4 -37
  30. data/lib/tensor_stream/train/momentum_optimizer.rb +48 -0 (usage sketch below)
  31. data/lib/tensor_stream/train/optimizer.rb +129 -0
  32. data/lib/tensor_stream/train/saver.rb +0 -1
  33. data/lib/tensor_stream/train/slot_creator.rb +62 -0
  34. data/lib/tensor_stream/train/utils.rb +11 -12
  35. data/lib/tensor_stream/trainer.rb +3 -0
  36. data/lib/tensor_stream/utils.rb +18 -11
  37. data/lib/tensor_stream/variable.rb +19 -12
  38. data/lib/tensor_stream/variable_scope.rb +1 -1
  39. data/lib/tensor_stream/version.rb +1 -1
  40. data/samples/iris.rb +2 -1
  41. data/samples/linear_regression.rb +3 -1
  42. data/samples/nearest_neighbor.rb +2 -0
  43. data/test_samples/neural_network_raw.py +101 -0
  44. data/test_samples/raw_neural_net_sample.rb +6 -4
  45. data/test_samples/test2.py +73 -27
  46. metadata +9 -3
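Items 29-31 above extract a shared Optimizer base class (train/optimizer.rb, +129 lines), add a MomentumOptimizer (train/momentum_optimizer.rb) backed by the new slot_creator.rb, and shrink GradientDescentOptimizer down to the common interface. A minimal usage sketch in Ruby; the (learning_rate, momentum) constructor is assumed to mirror TensorFlow's API, since the signature itself is not shown in this diff:

require 'tensor_stream'

ts = TensorStream
x = ts.variable(0.0, dtype: :float32)
loss = (x - 3.0) ** 2

# Assumed TensorFlow-style signature: MomentumOptimizer.new(learning_rate, momentum)
train_step = TensorStream::Train::MomentumOptimizer.new(0.1, 0.9).minimize(loss)

ts.session do |sess|
  sess.run(ts.global_variables_initializer)
  50.times { sess.run(train_step) }
  puts sess.run(x) # should converge toward 3.0
end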
data/lib/tensor_stream/exceptions.rb
@@ -3,4 +3,5 @@ module TensorStream
  class KeyError < TensorStreamError; end
  class ValueError < TensorStreamError; end
  class InvalidArgumentError < TensorStreamError; end
+ class NotImplementedError < TensorStreamError; end
  end
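One thing to watch: inside the TensorStream namespace this new constant shadows Ruby's built-in ::NotImplementedError, so external callers should qualify it. A small sketch of the rescue pattern, grounded only in the inheritance shown above:

begin
  raise TensorStream::NotImplementedError, 'op not supported by this evaluator'
rescue TensorStream::TensorStreamError => e
  # the new class inherits from TensorStreamError, so a single rescue
  # clause still catches every library-defined error
  puts e.message
end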
data/lib/tensor_stream/graph_builder.rb
@@ -36,9 +36,8 @@ module TensorStream
  TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
  else
  op = underscore(node['op']).to_sym
- unless TensorStream::Evaluator::RubyEvaluator.ops.keys.include?(op)
- puts "warning unsupported op #{op}"
- end
+ puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
+
  # map input tensor
  inputs = node['input'].map do |input|
  input[0] = '' if input.start_with?('^')
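The start_with?('^') check handles TensorFlow's GraphDef convention in which a leading caret marks a control-dependency input; assigning '' to index 0 strips the marker before the name is looked up:

input = '^global_step'
input[0] = '' if input.start_with?('^')
input # => "global_step"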
data/lib/tensor_stream/graph_deserializers/protobuf.rb
@@ -23,15 +23,15 @@ module TensorStream
  end

  def parse_value(value_node)
- if value_node['tensor']
- evaluate_tensor_node(value_node['tensor'])
- end
+ return unless value_node['tensor']
+
+ evaluate_tensor_node(value_node['tensor'])
  end

  def evaluate_tensor_node(node)
  if !node['shape'].empty? && node['tensor_content']
  content = node['tensor_content']
- unpacked = eval(%Q{"#{content}"})
+ unpacked = eval(%Q("#{content}"))

  if node['dtype'] == 'DT_FLOAT'
  TensorShape.reshape(unpacked.unpack('f*'), node['shape'])
@@ -45,14 +45,14 @@ module TensorStream
  else

  val = if node['dtype'] == 'DT_FLOAT'
- node['float_val'] ? node['float_val'].to_f : []
- elsif node['dtype'] == 'DT_INT32'
- node['int_val'] ? node['int_val'].to_i : []
- elsif node['dtype'] == 'DT_STRING'
- node['string_val']
- else
- raise "unknown dtype #{node['dtype']}"
- end
+ node['float_val'] ? node['float_val'].to_f : []
+ elsif node['dtype'] == 'DT_INT32'
+ node['int_val'] ? node['int_val'].to_i : []
+ elsif node['dtype'] == 'DT_STRING'
+ node['string_val']
+ else
+ raise "unknown dtype #{node['dtype']}"
+ end

  if node['shape'] == [1]
  [val]
@@ -83,7 +83,7 @@ module TensorStream
  return {} if node['attributes'].nil?

  node['attributes'].map do |attribute|
- attr_type, attr_value = attribute['value'].collect { |k, v| [k, v] }.flatten(1)
+ attr_type, attr_value = attribute['value'].flat_map { |k, v| [k, v] }

  if attr_type == 'tensor'
  attr_value = evaluate_tensor_node(attr_value)
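The flat_map rewrite is behavior-preserving here because each attribute['value'] is a single-pair hash (it is destructured into exactly two variables); a quick equivalence check:

attr_value = { 'tensor' => { 'dtype' => 'DT_FLOAT' } }
attr_value.collect { |k, v| [k, v] }.flatten(1) # => ["tensor", {"dtype"=>"DT_FLOAT"}]
attr_value.flat_map { |k, v| [k, v] }           # => ["tensor", {"dtype"=>"DT_FLOAT"}]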
@@ -103,11 +103,10 @@ module TensorStream
  block = []
  node = {}
  node_attr = {}
- dim = []
  state = :top

  lines.each do |str|
- case(state)
+ case state
  when :top
  node['type'] = parse_node_name(str)
  state = :node_context
@@ -177,7 +176,7 @@ module TensorStream
  next
  else
  key, value = str.split(':', 2)
- node_attr['value'] << { key => value}
+ node_attr['value'] << { key => value }
  end
  when :tensor_context
  if str == 'tensor_shape {'
@@ -219,7 +218,7 @@ module TensorStream
  state = :shape_context
  next
  else
- key, value = str.split(':', 2)
+ _key, value = str.split(':', 2)
  node_attr['value']['shape'] << value.strip.to_i
  end
  when :tensor_shape_dim_context
@@ -227,7 +226,7 @@ module TensorStream
  state = :tensor_shape_context
  next
  else
- key, value = str.split(':', 2)
+ _key, value = str.split(':', 2)
  node_attr['value']['tensor']['shape'] << value.strip.to_i
  end
  end
@@ -237,7 +236,7 @@ module TensorStream
  end

  def parse_node_name(str)
- name = str.split(' ')[0]
+ str.split(' ')[0]
  end

  def process_value(value)
@@ -253,19 +252,19 @@ module TensorStream
  'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
  'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
  "\"" => "\x22", "'" => "\x27"
- }
+ }.freeze

  def unescape(str)
  # Escape all the things
- str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) {
+ str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) do
  if $1
  $1 == '\\' ? '\\' : UNESCAPES[$1]
  elsif $2 # escape \u0000 unicode
- ["#$2".hex].pack('U*')
+ ["#{$2}".hex].pack('U*')
  elsif $3 # escape \0xff or \xff
  [$3].pack('H2')
  end
- }
+ end
  end
  end
  end
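For reference, the three regex alternations cover named escapes, \uXXXX unicode escapes, and \xFF-style hex bytes. A behavior sketch of unescape (inputs written with single quotes so the backslashes stay literal):

unescape('line1\nline2') # => "line1\nline2", via the UNESCAPES table ($1 branch)
unescape('\u00e9')       # => "é", i.e. ["00e9".hex].pack('U*') ($2 branch)
unescape('\xff')         # => "\xFF", a single raw byte ($3 branch)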
data/lib/tensor_stream/graph_serializers/graphml.rb
@@ -83,7 +83,7 @@ module TensorStream
  arr_buf << '<y:GroupNode>'
  arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
  arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
- arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+ arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + '</y:NodeLabel>'
  arr_buf << '<y:Shape type="roundrectangle"/>'
  arr_buf << '</y:GroupNode>'
  arr_buf << '</y:Realizers>'
@@ -146,9 +146,9 @@ module TensorStream
  input_buf << "<node id=\"#{_gml_string(input.name)}\">"
  input_buf << "<data key=\"d0\">#{input.name}</data>"
  input_buf << "<data key=\"d2\">green</data>"
- if @last_session_context[input.name]
- input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
- end
+
+ input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
  input_buf << "<data key=\"d9\">"
  input_buf << "<y:ShapeNode>"
  input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
@@ -164,9 +164,9 @@ module TensorStream
  input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
  input_buf << "</y:ShapeNode>"
  input_buf << "</data>"
- if @last_session_context[input.name]
- input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
- end
+
+ input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
  input_buf << "</node>"
  elsif input.is_a?(Tensor)
  input_buf << "<node id=\"#{_gml_string(input.name)}\">"
@@ -175,12 +175,11 @@ module TensorStream
  input_buf << "<data key=\"d9\">"
  input_buf << "<y:ShapeNode>"

- if input.internal?
- input_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
- else
- input_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
- end
-
+ input_buf << if input.internal?
+ " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+ else
+ " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+ end

  input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"

@@ -189,7 +188,7 @@ module TensorStream
  input_buf << "</node>"
  end

- if !add_to_group(groups, input.name, input_buf)
+ unless add_to_group(groups, input.name, input_buf)
  if input.is_a?(Variable)
  add_to_group(groups, "variable/#{input.name}", input_buf)
  else
@@ -205,7 +204,7 @@ module TensorStream
  end

  def _gml_string(str)
- str.gsub('/','-')
+ str.tr('/', '-')
  end

  def output_edge(input, tensor, arr_buf, index = 0)
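tr is the idiomatic (and faster) choice when a single character maps to another, since it avoids the regex engine entirely; the output is unchanged:

'variable/weights/w1'.gsub('/', '-') # => "variable-weights-w1"
'variable/weights/w1'.tr('/', '-')   # => "variable-weights-w1"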
@@ -215,22 +214,20 @@ module TensorStream

  arr_buf << "<y:PolyLineEdge>"
  arr_buf << "<y:EdgeLabel >"
- if !@last_session_context.empty?
- arr_buf << "<![CDATA[ #{_val(input)} ]]>"
- else
- if input.shape.shape.nil?
- arr_buf << "<![CDATA[ #{input.data_type.to_s} ? ]]>"
- else
- arr_buf << "<![CDATA[ #{input.data_type.to_s} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
- end
- end
+ arr_buf << if !@last_session_context.empty?
+ "<![CDATA[ #{_val(input)} ]]>"
+ elsif input.shape.shape.nil?
+ "<![CDATA[ #{input.data_type} ? ]]>"
+ else
+ "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
+ end
  arr_buf << "</y:EdgeLabel >"
  arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
- if index == 0
- arr_buf << "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
- else
- arr_buf << "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
- end
+ arr_buf << if index.zero?
+ "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+ else
+ "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+ end
  arr_buf << "</y:PolyLineEdge>"
  arr_buf << "</data>"
  arr_buf << "</edge>"
data/lib/tensor_stream/graph_serializers/pbtext.rb
@@ -1,4 +1,5 @@
  module TensorStream
+ # Parses pbtext files and loads it as a graph
  class Pbtext < TensorStream::Serializer
  include TensorStream::StringHelper
  include TensorStream::OpHelper
@@ -47,11 +48,11 @@ module TensorStream
  @lines << " attr {"
  @lines << " key: \"#{k}\""
  @lines << " value {"
- if (v.is_a?(TrueClass) || v.is_a?(FalseClass))
- @lines << " b: #{v.to_s}"
- elsif (v.is_a?(Integer))
+ if v.is_a?(TrueClass) || v.is_a?(FalseClass)
+ @lines << " b: #{v}"
+ elsif v.is_a?(Integer)
  @lines << " int_val: #{v}"
- elsif (v.is_a?(Float))
+ elsif v.is_a?(Float)
  @lines << " float_val: #{v}"
  end
  @lines << " }"
@@ -60,21 +61,23 @@ module TensorStream
  end

  def pack_arr_float(float_arr)
- float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
  end

  def pack_arr_int(int_arr)
- int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
  end

  def shape_buf(tensor, shape_type = 'tensor_shape')
  arr = []
  arr << " #{shape_type} {"
- tensor.shape.shape.each do |dim|
- arr << " dim {"
- arr << " size: #{dim}"
- arr << " }"
- end if tensor.shape.shape
+ if tensor.shape.shape
+ tensor.shape.shape.each do |dim|
+ arr << " dim {"
+ arr << " size: #{dim}"
+ arr << " }"
+ end
+ end
  arr << " }"
  arr
  end
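pack_arr_float above packs floats little-endian ('f*') and renders every non-printable byte as a three-digit octal escape, which is how tensor_content strings appear in pbtext output. Worked example for a single float:

[1.0].pack('f*').bytes # => [0, 0, 128, 63], i.e. 0x3f800000 little-endian
# 0x00 and 0x80 are non-printable and become "\000" and "\200";
# 0x3f is the printable '?', so the packed string reads: \000\000\200?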
@@ -102,14 +105,14 @@ module TensorStream
  end
  else
  val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
- "int_val"
- elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
- "float_val"
- elsif tensor.data_type == :string
- "string_val"
- else
- "val"
- end
+ "int_val"
+ elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+ "float_val"
+ elsif tensor.data_type == :string
+ "string_val"
+ else
+ "val"
+ end
  arr << " #{val_type}: #{tensor.value.to_json}"
  end
  arr << "}"
data/lib/tensor_stream/helpers/string_helper.rb
@@ -12,11 +12,10 @@ module TensorStream
  end

  def underscore(string)
- string.gsub(/::/, '/').
- gsub(/([A-Z]+)([A-Z][a-z])/,'\1_\2').
- gsub(/([a-z\d])([A-Z])/,'\1_\2').
- tr("-", "_").
- downcase
+ string.gsub(/::/, '/')
+ .gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
+ .gsub(/([a-z\d])([A-Z])/, '\1_\2')
+ .tr("-", "_").downcase
  end

  def symbolize_keys(hash)
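The leading-dot chained version behaves identically; worked examples:

underscore('TensorStream::MatMul')            # => "tensor_stream/mat_mul"
underscore('SoftmaxCrossEntropyWithLogitsV2') # => "softmax_cross_entropy_with_logits_v2"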
data/lib/tensor_stream/math_gradients.rb
@@ -1,9 +1,10 @@
  module TensorStream
  # Class that provides auto-differentiation
+ # Most gradients are ported over from tensorflow's math_grad.py
  class MathGradients
  extend TensorStream::OpHelper

- def self.tf
+ def self.ts
  TensorStream
  end

@@ -16,7 +17,7 @@ module TensorStream
  node.consumers.include?(tensor.name) || node.equal?(tensor)
  end.compact + [wrt_dx.name]

- grad = i_op(:fill, tf.shape(tensor), tf.constant(1, dtype: wrt_dx.data_type))
+ grad = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))

  _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || []) || i_op(:zeros_like, wrt_dx)
  end
@@ -41,6 +42,7 @@ module TensorStream
  end
  end

+ #TODO: refactor and implement registerGradient
  def self._compute_derivative(node, grad)
  node.graph.name_scope("#{node.name}_grad") do
  x = node.inputs[0] if node.inputs[0]
@@ -51,116 +53,161 @@ module TensorStream
  return [grad] * node.inputs.size
  when :add
  return [grad, grad] if shapes_fully_specified_and_equal(x, y)
- sx = tf.shape(x, name: 'add/shape_x')
- sy = tf.shape(y, name: 'add/shape_y')
+ sx = ts.shape(x, name: 'add/shape_x')
+ sy = ts.shape(y, name: 'add/shape_y')
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
- tf.reshape(tf.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
+ [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
+ ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
  when :asin
- tf.control_dependencies([grad]) do
- x2 = tf.square(x)
- one = tf.constant(1, dtype: grad.data_type)
- den = tf.sqrt(tf.subtract(one, x2))
- inv = tf.reciprocal(den)
+ ts.control_dependencies([grad]) do
+ x2 = ts.square(x)
+ one = ts.constant(1, dtype: grad.data_type)
+ den = ts.sqrt(ts.subtract(one, x2))
+ inv = ts.reciprocal(den)
  grad * inv
  end
  when :acos
- tf.control_dependencies([grad]) do
- x2 = tf.square(x)
- one = tf.constant(1, dtype: grad.data_type)
- den = tf.sqrt(tf.subtract(one, x2))
- inv = tf.reciprocal(den)
+ ts.control_dependencies([grad]) do
+ x2 = ts.square(x)
+ one = ts.constant(1, dtype: grad.data_type)
+ den = ts.sqrt(ts.subtract(one, x2))
+ inv = ts.reciprocal(den)
  -grad * inv
  end
+ when :atan
+ ts.control_dependencies([grad]) do
+ x2 = ts.square(x)
+ one = ts.constant(1, dtype: grad.data_type)
+ inv = ts.reciprocal(ts.add(one, x2))
+ grad * inv
+ end
+ when :fill
+ [nil, ts.reduce_sum(grad)]
  when :sub
  return [grad, -grad] if shapes_fully_specified_and_equal(x, y)

- sx = tf.shape(x, name: 'sub/shape_x')
- sy = tf.shape(y, name: 'sub/shape_y')
+ sx = ts.shape(x, name: 'sub/shape_x')
+ sy = ts.shape(y, name: 'sub/shape_y')
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
- -tf.reshape(tf.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
+ [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
+ -ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
  when :mul
- sx = tf.shape(x)
- sy = tf.shape(y)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(tf.mul(grad, y), rx), sx),
- tf.reshape(tf.reduce_sum(tf.mul(x, grad), ry), sy)]
+ [ts.reshape(ts.reduce_sum(ts.mul(grad, y), rx), sx),
+ ts.reshape(ts.reduce_sum(ts.mul(x, grad), ry), sy)]
  when :div
  sx = i_op(:shape, x)
  sy = i_op(:shape, y)
  rx, ry = _broadcast_gradient_args(sx, sy)

- [tf.reshape(tf.reduce_sum(tf.div(grad, y), rx), sx),
- tf.reshape(tf.reduce_sum(grad * tf.div(tf.div(-x, y), y), ry), sy)]
+ [ts.reshape(ts.reduce_sum(ts.div(grad, y), rx), sx),
+ ts.reshape(ts.reduce_sum(grad * ts.div(ts.div(-x, y), y), ry), sy)]
  when :mod
- sx = tf.shape(x)
- sy = tf.shape(y)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
  rx, ry = _broadcast_gradient_args(sx, sy)
- floor_xy = tf.floor_div(x, y)
- gx = tf.reshape(tf.reduce_sum(grad, rx), sx)
- gy = tf.reshape(tf.reduce_sum(grad * tf.negative(floor_xy), ry), sy)
+ floor_xy = ts.floor_div(x, y)
+ gx = ts.reshape(ts.reduce_sum(grad, rx), sx)
+ gy = ts.reshape(ts.reduce_sum(grad * ts.negative(floor_xy), ry), sy)

  [gx, gy]
+ when :prod
+ input_shape = ts.shape(x)
+ y = ts.range(0, ts.rank(x)) if y.nil?
+ reduction_indices = ts.reshape(y, [-1])
+
+ output_shape_kept_dims = ts.reduced_shape(input_shape, y)
+ tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
+ grad = ts.reshape(grad, output_shape_kept_dims)
+ grad = ts.tile(grad, tile_scaling)
+
+ perm, reduced_num, other_num = ts.device("/cpu:0") do
+ rank = ts.rank(x)
+ reduction_indices = (reduction_indices + rank) % rank
+ reduced = ts.cast(reduction_indices, :int32)
+ idx = ts.range(0, rank)
+ other, = ts.setdiff1d(idx, reduced)
+
+ [ts.concat([reduced, other], 0),
+ ts.reduce_prod(ts.gather(input_shape, reduced)),
+ ts.reduce_prod(ts.gather(input_shape, other))]
+ end
+
+ permuted = ts.transpose(x, perm)
+ permuted_shape = ts.shape(permuted)
+
+ reshaped = ts.reshape(permuted, [reduced_num, other_num])
+
+ # Calculate product, leaving out the current entry
+ left = ts.cumprod(reshaped, axis: 0, exclusive: true)
+ right = ts.cumprod(reshaped, axis: 0, exclusive: true, reverse: true)
+ y = ts.reshape(left * right, permuted_shape)
+
+ # Invert the transpose and reshape operations.
+ # Make sure to set the statically known shape information through a reshape.
+ out = grad * ts.transpose(y, ts.invert_permutation(perm))
+ [ts.reshape(out, input_shape, name: 'prod'), nil]
  when :squared_difference
  sx = i_op(:shape, x)
  sy = i_op(:shape, y)
  rx, ry = _broadcast_gradient_args(sx, sy)

- x_grad = tf.mul(2.0, grad) * (x - y)
+ x_grad = ts.mul(2.0, grad) * (x - y)

- [tf.reshape(tf.reduce_sum(x_grad, rx), sx),
- tf.reshape(-tf.reduce_sum(x_grad, ry), sy)]
+ [ts.reshape(ts.reduce_sum(x_grad, rx), sx),
+ ts.reshape(-ts.reduce_sum(x_grad, ry), sy)]
  when :mat_mul
  t_a = node.options[:transpose_a]
  t_b = node.options[:transpose_b]

  if !t_a && !t_b
- grad_a = tf.matmul(grad, y, transpose_b: true)
- grad_b = tf.matmul(x, grad, transpose_a: true)
+ grad_a = ts.matmul(grad, y, transpose_b: true)
+ grad_b = ts.matmul(x, grad, transpose_a: true)
  elsif !ta && tb
- grad_a = tf.matmul(grad, y)
- grad_b = tf.matmul(grad, x, transpose_a: true)
+ grad_a = ts.matmul(grad, y)
+ grad_b = ts.matmul(grad, x, transpose_a: true)
  elsif t_a && !t_b
- grad_a = tf.matmul(y, grad, transpose_b: true)
- grad_b = tf.matmul(x, grad)
+ grad_a = ts.matmul(y, grad, transpose_b: true)
+ grad_b = ts.matmul(x, grad)
  elsif t_a && t_b
- grad_a = tf.matmul(y, grad, transpose_a: true, transpose_b: true)
- grad_b = tf.matmul(grad, x, transpose_a: true, transpose_b: true)
+ grad_a = ts.matmul(y, grad, transpose_a: true, transpose_b: true)
+ grad_b = ts.matmul(grad, x, transpose_a: true, transpose_b: true)
  end

  [grad_a, grad_b]
  when :sin
- grad * tf.cos(x)
+ grad * ts.cos(x)
  when :tanh
  grad * i_op(:tanh_grad, x)
  when :pow
  z = node
- sx = tf.shape(x)
- sy = tf.shape(y)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
  rx, ry = _broadcast_gradient_args(sx, sy)
- gx = tf.reduce_sum(grad * y * tf.pow(x, y - 1), rx)
+ gx = ts.reduce_sum(grad * y * ts.pow(x, y - 1), rx)

- log_x = tf.where(x > 0, tf.log(x), tf.zeros_like(x))
- gy = tf.reduce_sum(grad * z * log_x, ry)
+ log_x = ts.where(x > 0, ts.log(x), ts.zeros_like(x))
+ gy = ts.reduce_sum(grad * z * log_x, ry)

  [gx, gy]
  when :abs
- grad * tf.sign(x)
+ grad * ts.sign(x)
  when :log
- grad * tf.reciprocal(x)
+ grad * ts.reciprocal(x)
  when :cos
- -grad * tf.sin(x)
+ -grad * ts.sin(x)
  when :max
- _min_or_max_grad(node.inputs, grad, ->(x, y) { tf.greater_equal(x, y) } )
+ _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.greater_equal(a, b) })
  when :min
- _min_or_max_grad(node.inputs, grad, ->(x, y) { tf.less_equal(x, y) } )
+ _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.less_equal(a, b) })
  when :tan
- secx = tf.reciprocal(tf.cos(x))
- secx2 = tf.square(secx)
+ secx = ts.reciprocal(ts.cos(x))
+ secx2 = ts.square(secx)
  grad * secx2
  when :negate
  -grad
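The new :atan branch completes the inverse-trig trio, and :prod uses the standard exclusive-cumprod construction. The math being implemented, in LaTeX:

\frac{d}{dx}\arcsin x = \frac{1}{\sqrt{1-x^2}}, \qquad
\frac{d}{dx}\arccos x = -\frac{1}{\sqrt{1-x^2}}, \qquad
\frac{d}{dx}\arctan x = \frac{1}{1+x^2}

\frac{\partial}{\partial x_i}\prod_j x_j = \prod_{j \neq i} x_j
  = \Bigl(\prod_{j<i} x_j\Bigr)\Bigl(\prod_{j>i} x_j\Bigr)

The two factors in the last identity are exactly `left` and `right` in the :prod branch (an exclusive cumprod and an exclusive reverse cumprod), multiplied elementwise so the gradient never has to divide by x_i, which may be zero.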
@@ -169,18 +216,25 @@ module TensorStream
  when :identity, :print
  grad
  when :sign
- tf.zeros(tf.shape(x), dtype: x.data_type)
+ ts.zeros(ts.shape(x), dtype: x.data_type)
+ when :tile
+ input_shape = ts.shape(x)
+ split_shape = ts.reshape(ts.transpose(ts.stack([y, input_shape])), [-1])
+ axes = ts.range(0, ts.size(split_shape), 2)
+ input_grad = ts.reduce_sum(ts.reshape(grad, split_shape), axes)
+
+ [input_grad, nil]
  when :sum
  _sum_grad(x, y, grad)
  when :reciprocal
- -grad * (tf.constant(1, dtype: x.dtype) / x**2)
+ -grad * (ts.constant(1, dtype: x.dtype) / x**2)
  when :sqrt
- tf.constant(1, dtype: x.dtype) / (tf.constant(2, dtype: x.dtype) * tf.sqrt(x)) * grad
+ ts.constant(1, dtype: x.dtype) / (ts.constant(2, dtype: x.dtype) * ts.sqrt(x)) * grad
  when :stop_gradient
- tf.zeros_like(grad)
+ ts.zeros_like(grad)
  when :square
- y = tf.constant(2.0, dtype: x.dtype)
- tf.multiply(grad, tf.multiply(x, y))
+ y = ts.constant(2.0, dtype: x.dtype)
+ ts.multiply(grad, ts.multiply(x, y))
  when :where
  x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
  y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
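The new :tile gradient folds the tiled copies back onto the source elements: grad is reshaped so multiples and input dimensions interleave, then summed over the even axes. Plain-Ruby numeric check for x of shape [2] tiled 3 times:

g = [1, 2, 3, 4, 5, 6] # incoming grad, shape [6]; split_shape would be [3, 2]
# summing over the tile axis collapses the three copies onto the two elements:
grad_x = g.each_slice(2).reduce { |a, b| a.zip(b).map(&:sum) }
grad_x # => [9, 12], i.e. 1+3+5 and 2+4+6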
@@ -191,12 +245,12 @@ module TensorStream
  [x_cond * grad, y_cond * grad]
  when :mean
  sum_grad = _sum_grad(x, y, grad)[0]
- input_shape = tf.shape(x)
- output_shape = tf.shape(node)
- factor = _safe_shape_div(tf.reduce_prod(input_shape), tf.reduce_prod(output_shape))
- tf.div(sum_grad, tf.cast(factor, sum_grad.data_type))
+ input_shape = ts.shape(x)
+ output_shape = ts.shape(node)
+ factor = _safe_shape_div(ts.reduce_prod(input_shape), ts.reduce_prod(output_shape))
+ [ts.div(sum_grad, ts.cast(factor, sum_grad.data_type)), nil]
  when :log1p
- grad * tf.reciprocal(i_cons(1, dtype: grad.data_type) + x)
+ grad * ts.reciprocal(i_cons(1, dtype: grad.data_type) + x)
  when :sigmoid
  i_op(:sigmoid_grad, x, grad)
  when :sigmoid_grad
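:mean now also returns the [grad, nil] pair convention used by the other two-input reducers. The arithmetic is unchanged: reduce_mean is reduce_sum scaled by the element count, so each input element receives grad / factor with factor = prod(input_shape) / prod(output_shape). For example, reducing a [2, 3] tensor to a scalar gives factor = 6, and every one of the six elements gets grad / 6.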
@@ -205,7 +259,8 @@ module TensorStream
  when :softmax
  i_op(:softmax_grad, x, grad)
  when :softmax_cross_entropy_with_logits_v2
- [i_op(:softmax_cross_entropy_with_logits_v2_grad, x, y, grad), nil]
+ output = node
+ [_broadcast_mul(grad, output[1]), nil]
  when :floor, :ceil
  # non differentiable
  nil
@@ -215,6 +270,10 @@ module TensorStream
  when :argmin, :argmax, :floor_div
  # non differentiable
  [nil, nil]
+ when :transpose
+ return [ts.transpose(grad, ts.invert_permutation(y)), nil]
+ when :index
+ grad
  else
  raise "no derivative op for #{node.operation}"
  end
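The new :transpose gradient just transposes the incoming gradient by the inverse permutation. A plain-Ruby sketch of what invert_permutation computes:

perm = [1, 2, 0]
inv = Array.new(perm.size)
perm.each_with_index { |p, i| inv[p] = i }
inv # => [2, 0, 1]; transposing by perm and then by inv restores the original order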
@@ -231,12 +290,12 @@ module TensorStream
  end

  def self._safe_shape_div(arg_x, arg_y)
- _op(:floor_div, arg_x, tf.maximum(arg_y, 1))
+ _op(:floor_div, arg_x, ts.maximum(arg_y, 1))
  end

  def self._sum_grad(arg_x, arg_y, grad)
  input_shape = _op(:shape, arg_x)
- output_shape_kept_dims = tf.reduced_shape(input_shape, arg_y)
+ output_shape_kept_dims = ts.reduced_shape(input_shape, arg_y)
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  new_grad = _op(:reshape, grad, output_shape_kept_dims)

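_sum_grad broadcasts the incoming gradient back across the reduced axes: reshape to the kept-dims shape, then tile by input_shape / kept_dims (the _safe_shape_div above guards against zero-sized dimensions). Numeric check for x of shape [2, 3] summed over axis 1:

grad = [1, 2]                      # incoming grad, shape [2]
kept = grad.map { |g| [g] }        # reshape to kept-dims shape [2, 1]
tiled = kept.map { |row| row * 3 } # tile by [1, 3]
tiled # => [[1, 1, 1], [2, 2, 2]]; every x[i][j] contributes 1 * grad[i]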
@@ -254,19 +313,24 @@ module TensorStream
  x = inputs[0]
  y = inputs[1]
  gdtype = grad.data_type
- sx = tf.shape(x)
- sy = tf.shape(y)
- gradshape = tf.shape(grad)
- zeros = tf.zeros(gradshape, dtype: gdtype)
+ sx = ts.shape(x)
+ sy = ts.shape(y)
+ gradshape = ts.shape(grad)
+ zeros = ts.zeros(gradshape, dtype: gdtype)
  xmask = selector_op.call(x, y)
  rx, ry = _broadcast_gradient_args(sx, sy)
- xgrad = tf.where(xmask, grad, zeros, name: 'x')
- ygrad = tf.where(xmask, zeros, grad, name: 'y')
- gx = tf.reshape(tf.reduce_sum(xgrad, rx), sx)
- gy = tf.reshape(tf.reduce_sum(ygrad, ry), sy)
+ xgrad = ts.where(xmask, grad, zeros, name: 'x')
+ ygrad = ts.where(xmask, zeros, grad, name: 'y')
+ gx = ts.reshape(ts.reduce_sum(xgrad, rx), sx)
+ gy = ts.reshape(ts.reduce_sum(ygrad, ry), sy)
  [gx, gy]
  end

+ def self._broadcast_mul(vec, mat)
+ vec = ts.expand_dims(vec, -1)
+ vec * mat
+ end
+
  def self._include?(arr, obj)
  arr.each { |a| return true if a.equal?(obj) }
  false
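Finally, _min_or_max_grad routes each gradient element to whichever input won the elementwise comparison, and the new _broadcast_mul aligns a vector against a matrix by appending a trailing axis before multiplying. An elementwise sketch of the :max case in plain Ruby:

x = [1.0, 5.0]
y = [3.0, 2.0]
grad = [10.0, 10.0]
xmask = x.zip(y).map { |a, b| a >= b }          # the greater_equal selector for :max
gx = xmask.zip(grad).map { |m, g| m ? g : 0.0 } # => [0.0, 10.0]
gy = xmask.zip(grad).map { |m, g| m ? 0.0 : g } # => [10.0, 0.0]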