tensor_stream 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
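The headline change in this release is visible in the file list: the monolithic op definitions in data/lib/tensor_stream/ops.rb (+95 -384) move into one file per operation under data/lib/tensor_stream/ops/, driven by the new data/lib/tensor_stream/op_maker.rb (+180), with documentation stubs generated into data/lib/tensor_stream/generated_stub/ops.rb (+691) by the new bin/stubgen script. As a sketch of what one of these per-op files plausibly looks like (the DSL method names below are illustrative assumptions; only OpMaker.gradient_op appears verbatim in this diff, in math_gradients.rb further down):

    # Illustrative sketch only -- every DSL name here except
    # TensorStream::OpMaker itself is an assumption, not confirmed by this diff.
    TensorStream::OpMaker.define_operation :sin do |op|
      op.what_it_does "Computes sin of input element-wise."

      # The gradient moves here from the removed case arm in math_gradients.rb:
      # d(sin x)/dx = cos x, scaled by the incoming gradient.
      op.define_gradient do |grad, node, params|
        grad * ts.cos(params[0])
      end

      # Element-wise ops keep the shape of their input.
      op.define_shape do |tensor|
        tensor.inputs[0].shape.shape
      end
    end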
data/lib/tensor_stream/helpers/op_helper.rb +8 -9

@@ -1,4 +1,3 @@
-
 module TensorStream
   # module that contains helper functions useful for ops
   module OpHelper
@@ -7,7 +6,7 @@ module TensorStream
 
       op = default_graph.add_op!(code.to_sym, *args)
       if !default_graph.get_dependency_scope.nil?
-        i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
+        i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, "tuple", "control_dependency"].join("/"))
       else
         op
       end
@@ -16,10 +15,10 @@ module TensorStream
     # same as op but with a marker that it was internal generated
     def i_op(code, *args)
       options = if args.last.is_a?(Hash)
-                  args.pop
-                else
-                  {}
-                end
+        args.pop
+      else
+        {}
+      end
 
       args << options.merge(internal: true)
       Graph.get_default_graph.add_op!(code.to_sym, *args)
@@ -65,8 +64,8 @@ module TensorStream
     end
 
     def format_source(trace)
-      grad_source = trace.detect { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }
-      source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
+      grad_source = trace.detect { |c| c.to_s.include?(File.join("lib", "tensor_stream", "math_gradients")) }
+      source = trace.reject { |c| c.to_s.include?(File.join("lib", "tensor_stream")) }.first
       [grad_source, trace].compact.join("\n")
     end
 
@@ -94,7 +93,7 @@ module TensorStream
       axes_shape = i_op(:shape, axes)
 
       TensorStream.dynamic_stitch([TensorStream.range(0, input_rank), axes],
-                                  [input_shape, i_op(:fill, axes_shape, 1)])
+        [input_shape, i_op(:fill, axes_shape, 1)])
     end
   end
 end
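The last hunk above is the tail of the reduced_shape helper: dynamic_stitch merges the full input shape with a tensor of 1s, and since later indices win, every axis listed in axes ends up with dimension 1. A plain-Ruby rendering of the same computation, eager rather than graph-based, for intuition only:

    # reduced_shape of [2, 3, 4] with axes [1] is [2, 1, 4]:
    input_shape = [2, 3, 4]
    axes = [1]
    reduced = input_shape.each_with_index.map { |dim, i| axes.include?(i) ? 1 : dim }
    reduced # => [2, 1, 4]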
data/lib/tensor_stream/helpers/string_helper.rb +15 -15

@@ -4,28 +4,28 @@ module TensorStream
   module StringHelper
     def camelize(string, uppercase_first_letter = true)
       string = if uppercase_first_letter
-                 string.sub(/^[a-z\d]*/) { $&.capitalize }
-               else
-                 string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
-               end
-      string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
+        string.sub(/^[a-z\d]*/) { $&.capitalize }
+      else
+        string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
+      end
+      string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub("/", "::")
     end
 
     def underscore(string)
-      string.gsub(/::/, '/')
-            .gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
-            .gsub(/([a-z\d])([A-Z])/, '\1_\2')
-            .tr("-", "_").downcase
+      string.gsub(/::/, "/").
+        gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2').
+        gsub(/([a-z\d])([A-Z])/, '\1_\2').
+        tr("-", "_").downcase
     end
 
     def symbolize_keys(hash)
-      hash.map do |k, v|
+      hash.map { |k, v|
         [k.to_sym, v]
-      end.to_h
+      }.to_h
     end
 
     def constantize(camel_cased_word)
-      names = camel_cased_word.split('::')
+      names = camel_cased_word.split("::")
 
       # Trigger a built-in NameError exception including the ill-formed constant in the message.
       Object.const_get(camel_cased_word) if names.empty?
@@ -43,11 +43,11 @@ module TensorStream
 
       # Go down the ancestors to check if it is owned directly. The check
       # stops when we reach Object or the end of ancestors tree.
-      constant = constant.ancestors.inject do |const, ancestor|
+      constant = constant.ancestors.inject { |const, ancestor|
         break const if ancestor == Object
         break ancestor if ancestor.const_defined?(name, false)
         const
-      end
+      }
 
       # owner is in Object, so raise
       constant.const_get(name, false)
@@ -55,4 +55,4 @@ module TensorStream
       end
     end
   end
-end
\ No newline at end of file
+end
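The inflection helpers above only change quoting and block style here; their behavior is untouched. For reference, camelize and underscore are inverses on typical constant paths:

    include TensorStream::StringHelper

    camelize("tensor_stream/op_maker")  # => "TensorStream::OpMaker"
    underscore("TensorStream::OpMaker") # => "tensor_stream/op_maker"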
data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17

@@ -1,7 +1,7 @@
 module TensorStream
   module TensorMixins
     def +(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:add, self, other)
     end
 
@@ -10,22 +10,22 @@ module TensorStream
     end
 
     def *(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
     def **(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
     def /(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
     def -(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
@@ -58,51 +58,51 @@ module TensorStream
     end
 
     def zero?
-      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
+      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: "equal/is_zero?"))
     end
 
     def ==(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:equal, self, other)
     end
 
     def <(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:less, self, other)
     end
 
     def !=(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:not_equal, self, other)
     end
 
     def >(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:greater, self, other)
     end
 
     def >=(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:greater_equal, self, other)
     end
 
     def <=(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:less_equal, self, other)
     end
 
     def and(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:logical_and, self, other)
     end
 
     def matmul(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:mat_mul, self, other)
     end
 
     def dot(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:mat_mul, self, other)
     end
 
@@ -126,10 +126,10 @@ module TensorStream
         :mean
       else
         raise "unsupported reduce op type #{op_type} valid values are :+, :*, :prod, :mean"
-        end
+      end
       raise "blocks are not supported for tensors" if block_given?
 
       TensorStream.reduce(reduce_op, self, axis, keepdims: keepdims, name: name)
     end
   end
-end
\ No newline at end of file
+end
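The one substantive pattern in this file: every operator used to destructure the return of check_data_types (_a, other = ...), and now calls it purely for validation, leaving its return value unused. The operators still only record graph nodes; nothing is computed until a session runs. A minimal usage sketch, assuming the gem's documented top-level API (ts.constant, ts.session):

    require "tensor_stream"

    ts = TensorStream
    a = ts.constant([1.0, 2.0])
    b = ts.constant([3.0, 4.0])
    c = a + b          # TensorMixins#+ emits _op(:add, a, b)
    ts.session.run(c)  # => [4.0, 6.0]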
data/lib/tensor_stream/images.rb +1 -1

@@ -13,4 +13,4 @@ module TensorStream
       _op(:encode_png, contents, compression: compression, name: name, new_shape: new_shape, resample_method: resample_method)
     end
   end
-end
\ No newline at end of file
+end
data/lib/tensor_stream/initializer.rb +1 -1

@@ -13,4 +13,4 @@ module TensorStream
       nil
     end
   end
-end
\ No newline at end of file
+end
data/lib/tensor_stream/math_gradients.rb +28 -187

@@ -12,10 +12,10 @@ module TensorStream
       return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
       return i_op(:zeros_like, wrt_dx) unless wrt_dx.consumers.include?(tensor.name)
 
-      nodes_to_compute = wrt_dx.consumers.select do |t|
+      nodes_to_compute = wrt_dx.consumers.select { |t|
         node = tensor.graph.nodes[t]
         node.consumers.include?(tensor.name) || node.equal?(tensor)
-      end.compact + [wrt_dx.name]
+      }.compact + [wrt_dx.name]
 
       grad = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))
 
@@ -30,12 +30,12 @@ module TensorStream
       computed_op = _compute_derivative(tensor, grad)
 
       if computed_op.is_a?(Array)
-        grads = computed_op.each_with_index.collect do |op_grad, index|
+        grads = computed_op.each_with_index.collect { |op_grad, index|
           next if op_grad.nil?
           next unless nodes_to_compute.include?(tensor.inputs[index].name)
 
           _propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
-        end.compact
+        }.compact
 
        return nil if grads.empty?
        grads.size > 1 ? ts.add_n(grads) : grads[0]
@@ -48,7 +48,7 @@ module TensorStream
       end
     end
 
-    #TODO: refactor and implement registerGradient
+    # TODO: refactor and implement registerGradient
     def self._compute_derivative(node, grad)
       node.graph.name_scope("#{node.name}_grad") do
         x = node.inputs[0] if node.inputs[0]
@@ -58,14 +58,6 @@ module TensorStream
         case node.operation
         when :add_n
           return [grad] * node.inputs.size
-        when :add
-          return [grad, grad] if shapes_fully_specified_and_equal(x, y)
-          sx = ts.shape(x, name: 'add/shape_x')
-          sy = ts.shape(y, name: 'add/shape_y')
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
-           ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
         when :asin
           ts.control_dependencies([grad]) do
             x2 = ts.square(x)
@@ -89,75 +81,6 @@ module TensorStream
             inv = ts.reciprocal(ts.add(one, x2))
             grad * inv
           end
-        when :fill
-          [nil, ts.reduce_sum(grad)]
-        when :sub
-          return [grad, -grad] if shapes_fully_specified_and_equal(x, y)
-
-          sx = ts.shape(x, name: 'sub/shape_x')
-          sy = ts.shape(y, name: 'sub/shape_y')
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
-           -ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
-        when :mul
-          sx = ts.shape(x)
-          sy = ts.shape(y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(ts.mul(grad, y), rx), sx),
-           ts.reshape(ts.reduce_sum(ts.mul(x, grad), ry), sy)]
-        when :div
-          sx = i_op(:shape, x)
-          sy = i_op(:shape, y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(ts.div(grad, y), rx), sx),
-           ts.reshape(ts.reduce_sum(grad * ts.div(ts.div(-x, y), y), ry), sy)]
-        when :mod
-          sx = ts.shape(x)
-          sy = ts.shape(y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-          floor_xy = ts.floor_div(x, y)
-          gx = ts.reshape(ts.reduce_sum(grad, rx), sx)
-          gy = ts.reshape(ts.reduce_sum(grad * ts.negative(floor_xy), ry), sy)
-
-          [gx, gy]
-        when :prod
-          input_shape = ts.shape(x)
-          y = ts.range(0, ts.rank(x)) if y.nil?
-          reduction_indices = ts.reshape(y, [-1])
-
-          output_shape_kept_dims = ts.reduced_shape(input_shape, y)
-          tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
-          grad = ts.reshape(grad, output_shape_kept_dims)
-          grad = ts.tile(grad, tile_scaling)
-
-          perm, reduced_num, other_num = ts.device("/cpu:0") do
-            rank = ts.rank(x)
-            reduction_indices = (reduction_indices + rank) % rank
-            reduced = ts.cast(reduction_indices, :int32)
-            idx = ts.range(0, rank)
-            other, = ts.setdiff1d(idx, reduced)
-            [ts.concat([reduced, other], 0),
-             ts.reduce_prod(ts.gather(input_shape, reduced)),
-             ts.reduce_prod(ts.gather(input_shape, other))]
-          end
-
-          permuted = ts.transpose(x, perm)
-          permuted_shape = ts.shape(permuted)
-
-          reshaped = ts.reshape(permuted, [reduced_num, other_num])
-
-          # Calculate product, leaving out the current entry
-          left = ts.cumprod(reshaped, axis: 0, exclusive: true)
-          right = ts.cumprod(reshaped, axis: 0, exclusive: true, reverse: true)
-          y = ts.reshape(left * right, permuted_shape)
-
-          # Invert the transpose and reshape operations.
-          # Make sure to set the statically known shape information through a reshape.
-          out = grad * ts.transpose(y, ts.invert_permutation(perm))
-          [ts.reshape(out, input_shape, name: 'prod'), nil]
         when :squared_difference
           sx = i_op(:shape, x)
           sy = i_op(:shape, y)
@@ -166,63 +89,13 @@ module TensorStream
           x_grad = ts.mul(2.0, grad) * (x - y)
 
           [ts.reshape(ts.reduce_sum(x_grad, rx), sx),
-           ts.reshape(-ts.reduce_sum(x_grad, ry), sy)]
-        when :mat_mul
-          t_a = node.options[:transpose_a]
-          t_b = node.options[:transpose_b]
-
-          if !t_a && !t_b
-            grad_a = ts.matmul(grad, y, transpose_b: true)
-            grad_b = ts.matmul(x, grad, transpose_a: true)
-          elsif !ta && tb
-            grad_a = ts.matmul(grad, y)
-            grad_b = ts.matmul(grad, x, transpose_a: true)
-          elsif t_a && !t_b
-            grad_a = ts.matmul(y, grad, transpose_b: true)
-            grad_b = ts.matmul(x, grad)
-          elsif t_a && t_b
-            grad_a = ts.matmul(y, grad, transpose_a: true, transpose_b: true)
-            grad_b = ts.matmul(grad, x, transpose_a: true, transpose_b: true)
-          end
-
-          [grad_a, grad_b]
-        when :sin
-          grad * ts.cos(x)
-        when :tanh
-          grad * i_op(:tanh_grad, x)
-        when :pow
-          z = node
-          sx = ts.shape(x)
-          sy = ts.shape(y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-          gx = ts.reduce_sum(grad * y * ts.pow(x, y - 1), rx)
-
-          log_x = ts.where(x > 0, ts.log(x), ts.zeros_like(x))
-          gy = ts.reduce_sum(grad * z * log_x, ry)
-
-          [gx, gy]
+           ts.reshape(-ts.reduce_sum(x_grad, ry), sy),]
         when :abs
           grad * ts.sign(x)
-        when :log
-          grad * ts.reciprocal(x)
-        when :cos
-          -grad * ts.sin(x)
-        when :max
-          _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.greater_equal(a, b) })
-        when :min
-          _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.less_equal(a, b) })
-        when :tan
-          secx = ts.reciprocal(ts.cos(x))
-          secx2 = ts.square(secx)
-          grad * secx2
-        when :negate
-          -grad
         when :exp
           grad * node
         when :identity, :print
           grad
-        when :sign
-          ts.zeros(ts.shape(x), dtype: x.data_type)
         when :tile
           input_shape = ts.shape(x)
           split_shape = ts.reshape(ts.transpose(ts.stack([y, input_shape])), [-1])
@@ -230,8 +103,6 @@ module TensorStream
           input_grad = ts.reduce_sum(ts.reshape(grad, split_shape), axes)
 
           [input_grad, nil]
-        when :sum
-          _sum_grad(x, y, grad)
         when :reciprocal
           -grad * (ts.constant(1, dtype: x.dtype) / x**2)
         when :sqrt
@@ -245,14 +116,6 @@ module TensorStream
           x_mask = i_op(:where, x, i_op(:ones_like, y), i_op(:zeros_like, z))
           y_mask = i_op(:where, x, i_op(:zeros_like, y), i_op(:ones_like, z))
           [nil, x_mask * grad, y_mask * grad]
-        when :case
-          n_preds = node.inputs.size - 2
-
-          case_grads = Array.new(n_preds) do |index|
-            i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
-          end
-
-          [nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
         when :mean
           sum_grad = _sum_grad(x, y, grad)[0]
           input_shape = ts.shape(x)
@@ -261,8 +124,6 @@ module TensorStream
           [ts.div(sum_grad, ts.cast(factor, sum_grad.data_type)), nil]
         when :log1p
           grad * ts.reciprocal(i_cons(1, dtype: grad.data_type) + x)
-        when :sigmoid
-          i_op(:sigmoid_grad, x, grad)
         when :sigmoid_grad
           gb = grad * y
           [gb - 2.0 * gb * x, i_op(:sigmoid_grad, x, grad)]
@@ -275,15 +136,9 @@ module TensorStream
         when :sparse_softmax_cross_entropy_with_logits
           output = node
           [_broadcast_mul(grad, output[1]), nil]
-        when :floor, :ceil, :round
-          # non differentiable
-          nil
-        when :zeros_like
+        when :zeros_like
           # non differentiable
           nil
-        when :argmin, :argmax, :floor_div
-          # non differentiable
-          [nil, nil]
         when :transpose
           return [ts.transpose(grad, ts.invert_permutation(y)), nil]
         when :index
@@ -294,19 +149,15 @@ module TensorStream
             multiplier = node.inputs[0].shape.shape[0]
             filler = ts.zeros_like(grad)
 
-            res = Array.new(multiplier) do |index|
+            res = Array.new(multiplier) { |index|
               index == node.inputs[1].const_value ? grad : filler
-            end
+            }
             [res]
           end
         when :squeeze
           _reshape_to_input(node, grad)
-        when :expand_dims
-          [_reshape_to_input(node, grad), nil]
         when :concat
           _concat_grad_helper(node, grad, 1, node.inputs.size, 0)
-        when :reshape
-          [ts.reshape(grad, ts.shape(node.inputs[0])), nil]
         when :stack
           res = ts.unstack(grad, num: node.inputs.size, axis: node.options[:axis])
           Array.new(node.inputs.size) { |i| res[i] }
@@ -314,18 +165,8 @@ module TensorStream
           ts.stack(grad, axis: node.options[:axis])
         when :conv2d
           _Conv2DGrad(node, grad)
-        when :cast
-          t = %i[float16 float32 float64]
-          src_type = node.inputs[0].data_type
-          dst_type = grad.data_type
-
-          if t.key?(src_type) && t.key?(dst_type)
-            ts.cast(grad, src_type)
-          end
-
-          nil
         else
-          raise "no derivative op for #{node.operation}"
+          TensorStream::OpMaker.gradient_op(self, node, grad)
         end
       end
     end
@@ -373,8 +214,8 @@ module TensorStream
       zeros = ts.zeros(gradshape, dtype: gdtype)
       xmask = selector_op.call(x, y)
       rx, ry = _broadcast_gradient_args(sx, sy)
-      xgrad = ts.where(xmask, grad, zeros, name: 'x')
-      ygrad = ts.where(xmask, zeros, grad, name: 'y')
+      xgrad = ts.where(xmask, grad, zeros, name: "x")
+      ygrad = ts.where(xmask, zeros, grad, name: "y")
       gx = ts.reshape(ts.reduce_sum(xgrad, rx), sx)
       gy = ts.reshape(ts.reduce_sum(ygrad, ry), sy)
       [gx, gy]
@@ -435,22 +276,22 @@ module TensorStream
 
       shape_0, shape_1 = ts.shape_n([op.inputs[0], op.inputs[1]])
       [
-        _op(:conv2d_backprop_input,
-            shape_0,
-            op.inputs[1],
-            grad,
-            strides: strides,
-            padding: padding,
-            use_cudnn_on_gpu: use_cudnn_on_gpu,
-            data_format: data_format),
-        _op(:conv2d_backprop_filter,
-            op.inputs[0],
-            shape_1,
-            grad,
-            strides: strides,
-            padding: padding,
-            use_cudnn_on_gpu: use_cudnn_on_gpu,
-            data_format: data_format)
+        _op(:conv2d_backprop_input,
+          shape_0,
+          op.inputs[1],
+          grad,
+          strides: strides,
+          padding: padding,
+          use_cudnn_on_gpu: use_cudnn_on_gpu,
+          data_format: data_format),
+        _op(:conv2d_backprop_filter,
+          op.inputs[0],
+          shape_1,
+          grad,
+          strides: strides,
+          padding: padding,
+          use_cudnn_on_gpu: use_cudnn_on_gpu,
+          data_format: data_format),
       ]
     end
   end
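This file accounts for nearly all of the -187: the per-op case arms (add, sub, mul, div, mod, prod, mat_mul, sin, cos, tan, tanh, pow, log, sigmoid, sum, sign, cast, reshape, expand_dims, case, and the non-differentiable ops) are deleted, and the else branch delegates to TensorStream::OpMaker.gradient_op instead of raising, so those gradients presumably now live alongside the new per-op files listed above. The removed mat_mul arm also carried a typo (elsif !ta && tb for !t_a && t_b), which disappears with it. The surviving _min_or_max_grad helper routes the incoming gradient through a selector mask; a scalar-level rendering in plain Ruby, for intuition only:

    # For z = max(x, y): dz/dx gets grad where x >= y, dz/dy gets the rest.
    x = [3.0, 1.0]
    y = [2.0, 5.0]
    grad = [1.0, 1.0]
    xmask = x.zip(y).map { |a, b| a >= b }           # selector_op for :max
    gx = xmask.zip(grad).map { |m, g| m ? g : 0.0 }  # => [1.0, 0.0]
    gy = xmask.zip(grad).map { |m, g| m ? 0.0 : g }  # => [0.0, 1.0]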