tensor_stream 1.0.0 → 1.0.1

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
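The headline change in 1.0.1 is structural: the monolithic ops.rb (+95 -384) and math_gradients.rb (+28 -187) shrink, and each operation gains its own definition file under data/lib/tensor_stream/ops/, with user-facing stubs generated into data/lib/tensor_stream/generated_stub/ops.rb by the new bin/stubgen script. The public API should be unaffected; a minimal smoke-test sketch against ops whose definitions moved, assuming the session API shown in the gem's README (ts.session, sess.run) and that the stubbed module functions keep their pre-1.0.1 names:

require "tensor_stream"

ts = TensorStream
a = ts.constant([[1.0, 2.0], [3.0, 4.0]])
b = ts.constant([[5.0, 6.0], [7.0, 8.0]])

sum = ts.add(a, b)         # definition now lives in ops/add.rb
mm = ts.matmul(a, b)       # ops/mat_mul.rb
total = ts.reduce_sum(mm)  # ops/sum.rb

sess = ts.session
sess.run(sum)   # => [[6.0, 8.0], [10.0, 12.0]]
sess.run(total) # => 134.0

Excerpts from the changed files follow.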
data/lib/tensor_stream/helpers/op_helper.rb

@@ -1,4 +1,3 @@
-
 module TensorStream
   # module that contains helper functions useful for ops
   module OpHelper
@@ -7,7 +6,7 @@ module TensorStream
 
       op = default_graph.add_op!(code.to_sym, *args)
       if !default_graph.get_dependency_scope.nil?
-        i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
+        i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, "tuple", "control_dependency"].join("/"))
       else
         op
       end
@@ -16,10 +15,10 @@ module TensorStream
     # same as op but with a marker that it was internal generated
     def i_op(code, *args)
       options = if args.last.is_a?(Hash)
-                  args.pop
-                else
-                  {}
-                end
+        args.pop
+      else
+        {}
+      end
 
       args << options.merge(internal: true)
       Graph.get_default_graph.add_op!(code.to_sym, *args)
@@ -65,8 +64,8 @@ module TensorStream
     end
 
     def format_source(trace)
-      grad_source = trace.detect { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }
-      source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
+      grad_source = trace.detect { |c| c.to_s.include?(File.join("lib", "tensor_stream", "math_gradients")) }
+      source = trace.reject { |c| c.to_s.include?(File.join("lib", "tensor_stream")) }.first
       [grad_source, trace].compact.join("\n")
     end
 
@@ -94,7 +93,7 @@ module TensorStream
       axes_shape = i_op(:shape, axes)
 
       TensorStream.dynamic_stitch([TensorStream.range(0, input_rank), axes],
-                                  [input_shape, i_op(:fill, axes_shape, 1)])
+        [input_shape, i_op(:fill, axes_shape, 1)])
     end
   end
 end
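The _op change above keeps the control-dependency behavior (only the string literals switch to double quotes): any op created while a dependency scope is active is wrapped in an :identity node named "<op>/tuple/control_dependency". A sketch of how that scope is entered from the public API; that ts.control_dependencies returns its block's value is inferred from its use in math_gradients.rb below:

require "tensor_stream"

ts = TensorStream
a = ts.constant(1.0, name: "a")
b = ts.constant(2.0, name: "b")

c = ts.control_dependencies([a]) do
  b * 2.0 # built via _op, so it is wrapped in an :identity node
end       # named like "mul/tuple/control_dependency"

ts.session.run(c) # => 4.0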
data/lib/tensor_stream/helpers/string_helper.rb

@@ -4,28 +4,28 @@ module TensorStream
   module StringHelper
     def camelize(string, uppercase_first_letter = true)
       string = if uppercase_first_letter
-                 string.sub(/^[a-z\d]*/) { $&.capitalize }
-               else
-                 string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
-               end
-      string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
+        string.sub(/^[a-z\d]*/) { $&.capitalize }
+      else
+        string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
+      end
+      string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub("/", "::")
     end
 
     def underscore(string)
-      string.gsub(/::/, '/')
-            .gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
-            .gsub(/([a-z\d])([A-Z])/, '\1_\2')
-            .tr("-", "_").downcase
+      string.gsub(/::/, "/").
+        gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2').
+        gsub(/([a-z\d])([A-Z])/, '\1_\2').
+        tr("-", "_").downcase
     end
 
     def symbolize_keys(hash)
-      hash.map do |k, v|
+      hash.map { |k, v|
         [k.to_sym, v]
-      end.to_h
+      }.to_h
     end
 
     def constantize(camel_cased_word)
-      names = camel_cased_word.split('::')
+      names = camel_cased_word.split("::")
 
       # Trigger a built-in NameError exception including the ill-formed constant in the message.
       Object.const_get(camel_cased_word) if names.empty?
@@ -43,11 +43,11 @@ module TensorStream
 
       # Go down the ancestors to check if it is owned directly. The check
       # stops when we reach Object or the end of ancestors tree.
-      constant = constant.ancestors.inject do |const, ancestor|
+      constant = constant.ancestors.inject { |const, ancestor|
         break const if ancestor == Object
         break ancestor if ancestor.const_defined?(name, false)
         const
-      end
+      }
 
       # owner is in Object, so raise
       constant.const_get(name, false)
@@ -55,4 +55,4 @@ module TensorStream
     end
   end
 end
-end
+end
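StringHelper only changes style here: braces instead of do...end on value-returning blocks, double quotes, and method chains flipped to trailing-dot form, matching the standard formatting adopted across this release. Behavior is unchanged; for reference, a round trip through the inflection helpers (hypothetical top-level include, since the module is normally mixed into library internals):

require "tensor_stream"

include TensorStream::StringHelper

camelize("tensor_stream/op_maker")  # => "TensorStream::OpMaker"
underscore("TensorStream::OpMaker") # => "tensor_stream/op_maker"
symbolize_keys("axis" => 0)         # => {axis: 0}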
data/lib/tensor_stream/helpers/tensor_mixins.rb

@@ -1,7 +1,7 @@
 module TensorStream
   module TensorMixins
     def +(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:add, self, other)
     end
 
@@ -10,22 +10,22 @@ module TensorStream
     end
 
     def *(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
     def **(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
     def /(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
     def -(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
     end
 
@@ -58,51 +58,51 @@ module TensorStream
     end
 
     def zero?
-      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
+      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: "equal/is_zero?"))
     end
 
     def ==(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:equal, self, other)
     end
 
     def <(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:less, self, other)
     end
 
     def !=(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:not_equal, self, other)
     end
 
     def >(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:greater, self, other)
     end
 
     def >=(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:greater_equal, self, other)
     end
 
     def <=(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:less_equal, self, other)
     end
 
     def and(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:logical_and, self, other)
     end
 
     def matmul(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:mat_mul, self, other)
     end
 
     def dot(other)
-      _a, other = TensorStream.check_data_types(self, other)
+      TensorStream.check_data_types(self, other)
       _op(:mat_mul, self, other)
     end
 
@@ -126,10 +126,10 @@ module TensorStream
         :mean
       else
         raise "unsupported reduce op type #{op_type} valid values are :+, :*, :prod, :mean"
-                  end
+      end
       raise "blocks are not supported for tensors" if block_given?
 
       TensorStream.reduce(reduce_op, self, axis, keepdims: keepdims, name: name)
     end
   end
-end
+end
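Each operator in TensorMixins now calls TensorStream.check_data_types for validation only, dropping the unused `_a, other =` destructuring; the arithmetic operators still coerce the right-hand side through convert_to_tensor. Since these mixins are the gem's main user surface, a short sketch of what they build (graph-mode semantics, as in the README):

require "tensor_stream"

ts = TensorStream
a = ts.constant([1.0, 2.0])
b = ts.constant([3.0, 4.0])

c = a + b   # TensorMixins#+ -> :add op
d = a * 2.0 # TensorMixins#* -> :mul, scalar coerced via convert_to_tensor
e = a > b   # TensorMixins#> -> :greater

sess = ts.session
sess.run(c) # => [4.0, 6.0]
sess.run(d) # => [2.0, 4.0]
sess.run(e) # => [false, false]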
data/lib/tensor_stream/images.rb

@@ -13,4 +13,4 @@ module TensorStream
       _op(:encode_png, contents, compression: compression, name: name, new_shape: new_shape, resample_method: resample_method)
     end
   end
-end
+end
data/lib/tensor_stream/initializer.rb

@@ -13,4 +13,4 @@ module TensorStream
       nil
     end
   end
-end
+end
data/lib/tensor_stream/math_gradients.rb

@@ -12,10 +12,10 @@ module TensorStream
     return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
     return i_op(:zeros_like, wrt_dx) unless wrt_dx.consumers.include?(tensor.name)
 
-    nodes_to_compute = wrt_dx.consumers.select do |t|
+    nodes_to_compute = wrt_dx.consumers.select { |t|
       node = tensor.graph.nodes[t]
       node.consumers.include?(tensor.name) || node.equal?(tensor)
-    end.compact + [wrt_dx.name]
+    }.compact + [wrt_dx.name]
 
     grad = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))
 
@@ -30,12 +30,12 @@ module TensorStream
      computed_op = _compute_derivative(tensor, grad)
 
      if computed_op.is_a?(Array)
-       grads = computed_op.each_with_index.collect do |op_grad, index|
+       grads = computed_op.each_with_index.collect { |op_grad, index|
          next if op_grad.nil?
          next unless nodes_to_compute.include?(tensor.inputs[index].name)
 
          _propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
-       end.compact
+       }.compact
 
        return nil if grads.empty?
        grads.size > 1 ? ts.add_n(grads) : grads[0]
@@ -48,7 +48,7 @@ module TensorStream
      end
    end
 
-    #TODO: refactor and implement registerGradient
+    # TODO: refactor and implement registerGradient
    def self._compute_derivative(node, grad)
      node.graph.name_scope("#{node.name}_grad") do
        x = node.inputs[0] if node.inputs[0]
@@ -58,14 +58,6 @@ module TensorStream
        case node.operation
        when :add_n
          return [grad] * node.inputs.size
-        when :add
-          return [grad, grad] if shapes_fully_specified_and_equal(x, y)
-          sx = ts.shape(x, name: 'add/shape_x')
-          sy = ts.shape(y, name: 'add/shape_y')
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
-           ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
        when :asin
          ts.control_dependencies([grad]) do
            x2 = ts.square(x)
@@ -89,75 +81,6 @@ module TensorStream
            inv = ts.reciprocal(ts.add(one, x2))
            grad * inv
          end
-        when :fill
-          [nil, ts.reduce_sum(grad)]
-        when :sub
-          return [grad, -grad] if shapes_fully_specified_and_equal(x, y)
-
-          sx = ts.shape(x, name: 'sub/shape_x')
-          sy = ts.shape(y, name: 'sub/shape_y')
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
-           -ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
-        when :mul
-          sx = ts.shape(x)
-          sy = ts.shape(y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(ts.mul(grad, y), rx), sx),
-           ts.reshape(ts.reduce_sum(ts.mul(x, grad), ry), sy)]
-        when :div
-          sx = i_op(:shape, x)
-          sy = i_op(:shape, y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-
-          [ts.reshape(ts.reduce_sum(ts.div(grad, y), rx), sx),
-           ts.reshape(ts.reduce_sum(grad * ts.div(ts.div(-x, y), y), ry), sy)]
-        when :mod
-          sx = ts.shape(x)
-          sy = ts.shape(y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-          floor_xy = ts.floor_div(x, y)
-          gx = ts.reshape(ts.reduce_sum(grad, rx), sx)
-          gy = ts.reshape(ts.reduce_sum(grad * ts.negative(floor_xy), ry), sy)
-
-          [gx, gy]
-        when :prod
-          input_shape = ts.shape(x)
-          y = ts.range(0, ts.rank(x)) if y.nil?
-          reduction_indices = ts.reshape(y, [-1])
-
-          output_shape_kept_dims = ts.reduced_shape(input_shape, y)
-          tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
-          grad = ts.reshape(grad, output_shape_kept_dims)
-          grad = ts.tile(grad, tile_scaling)
-
-          perm, reduced_num, other_num = ts.device("/cpu:0") do
-            rank = ts.rank(x)
-            reduction_indices = (reduction_indices + rank) % rank
-            reduced = ts.cast(reduction_indices, :int32)
-            idx = ts.range(0, rank)
-            other, = ts.setdiff1d(idx, reduced)
-            [ts.concat([reduced, other], 0),
-             ts.reduce_prod(ts.gather(input_shape, reduced)),
-             ts.reduce_prod(ts.gather(input_shape, other))]
-          end
-
-          permuted = ts.transpose(x, perm)
-          permuted_shape = ts.shape(permuted)
-
-          reshaped = ts.reshape(permuted, [reduced_num, other_num])
-
-          # Calculate product, leaving out the current entry
-          left = ts.cumprod(reshaped, axis: 0, exclusive: true)
-          right = ts.cumprod(reshaped, axis: 0, exclusive: true, reverse: true)
-          y = ts.reshape(left * right, permuted_shape)
-
-          # Invert the transpose and reshape operations.
-          # Make sure to set the statically known shape information through a reshape.
-          out = grad * ts.transpose(y, ts.invert_permutation(perm))
-          [ts.reshape(out, input_shape, name: 'prod'), nil]
        when :squared_difference
          sx = i_op(:shape, x)
          sy = i_op(:shape, y)
@@ -166,63 +89,13 @@ module TensorStream
          x_grad = ts.mul(2.0, grad) * (x - y)
 
          [ts.reshape(ts.reduce_sum(x_grad, rx), sx),
-           ts.reshape(-ts.reduce_sum(x_grad, ry), sy)]
-        when :mat_mul
-          t_a = node.options[:transpose_a]
-          t_b = node.options[:transpose_b]
-
-          if !t_a && !t_b
-            grad_a = ts.matmul(grad, y, transpose_b: true)
-            grad_b = ts.matmul(x, grad, transpose_a: true)
-          elsif !ta && tb
-            grad_a = ts.matmul(grad, y)
-            grad_b = ts.matmul(grad, x, transpose_a: true)
-          elsif t_a && !t_b
-            grad_a = ts.matmul(y, grad, transpose_b: true)
-            grad_b = ts.matmul(x, grad)
-          elsif t_a && t_b
-            grad_a = ts.matmul(y, grad, transpose_a: true, transpose_b: true)
-            grad_b = ts.matmul(grad, x, transpose_a: true, transpose_b: true)
-          end
-
-          [grad_a, grad_b]
-        when :sin
-          grad * ts.cos(x)
-        when :tanh
-          grad * i_op(:tanh_grad, x)
-        when :pow
-          z = node
-          sx = ts.shape(x)
-          sy = ts.shape(y)
-          rx, ry = _broadcast_gradient_args(sx, sy)
-          gx = ts.reduce_sum(grad * y * ts.pow(x, y - 1), rx)
-
-          log_x = ts.where(x > 0, ts.log(x), ts.zeros_like(x))
-          gy = ts.reduce_sum(grad * z * log_x, ry)
-
-          [gx, gy]
+           ts.reshape(-ts.reduce_sum(x_grad, ry), sy),]
        when :abs
          grad * ts.sign(x)
-        when :log
-          grad * ts.reciprocal(x)
-        when :cos
-          -grad * ts.sin(x)
-        when :max
-          _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.greater_equal(a, b) })
-        when :min
-          _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.less_equal(a, b) })
-        when :tan
-          secx = ts.reciprocal(ts.cos(x))
-          secx2 = ts.square(secx)
-          grad * secx2
-        when :negate
-          -grad
        when :exp
          grad * node
        when :identity, :print
          grad
-        when :sign
-          ts.zeros(ts.shape(x), dtype: x.data_type)
        when :tile
          input_shape = ts.shape(x)
          split_shape = ts.reshape(ts.transpose(ts.stack([y, input_shape])), [-1])
@@ -230,8 +103,6 @@ module TensorStream
          input_grad = ts.reduce_sum(ts.reshape(grad, split_shape), axes)
 
          [input_grad, nil]
-        when :sum
-          _sum_grad(x, y, grad)
        when :reciprocal
          -grad * (ts.constant(1, dtype: x.dtype) / x**2)
        when :sqrt
@@ -245,14 +116,6 @@ module TensorStream
          x_mask = i_op(:where, x, i_op(:ones_like, y), i_op(:zeros_like, z))
          y_mask = i_op(:where, x, i_op(:zeros_like, y), i_op(:ones_like, z))
          [nil, x_mask * grad, y_mask * grad]
-        when :case
-          n_preds = node.inputs.size - 2
-
-          case_grads = Array.new(n_preds) do |index|
-            i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
-          end
-
-          [nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
        when :mean
          sum_grad = _sum_grad(x, y, grad)[0]
          input_shape = ts.shape(x)
@@ -261,8 +124,6 @@ module TensorStream
          [ts.div(sum_grad, ts.cast(factor, sum_grad.data_type)), nil]
        when :log1p
          grad * ts.reciprocal(i_cons(1, dtype: grad.data_type) + x)
-        when :sigmoid
-          i_op(:sigmoid_grad, x, grad)
        when :sigmoid_grad
          gb = grad * y
          [gb - 2.0 * gb * x, i_op(:sigmoid_grad, x, grad)]
@@ -275,15 +136,9 @@ module TensorStream
        when :sparse_softmax_cross_entropy_with_logits
          output = node
          [_broadcast_mul(grad, output[1]), nil]
-        when :floor, :ceil, :round
-          # non differentiable
-          nil
-        when :zeros_like
+        when :zeros_like
          # non differentiable
          nil
-        when :argmin, :argmax, :floor_div
-          # non differentiable
-          [nil, nil]
        when :transpose
          return [ts.transpose(grad, ts.invert_permutation(y)), nil]
        when :index
@@ -294,19 +149,15 @@ module TensorStream
            multiplier = node.inputs[0].shape.shape[0]
            filler = ts.zeros_like(grad)
 
-            res = Array.new(multiplier) do |index|
+            res = Array.new(multiplier) { |index|
              index == node.inputs[1].const_value ? grad : filler
-            end
+            }
            [res]
          end
        when :squeeze
          _reshape_to_input(node, grad)
-        when :expand_dims
-          [_reshape_to_input(node, grad), nil]
        when :concat
          _concat_grad_helper(node, grad, 1, node.inputs.size, 0)
-        when :reshape
-          [ts.reshape(grad, ts.shape(node.inputs[0])), nil]
        when :stack
          res = ts.unstack(grad, num: node.inputs.size, axis: node.options[:axis])
          Array.new(node.inputs.size) { |i| res[i] }
@@ -314,18 +165,8 @@ module TensorStream
          ts.stack(grad, axis: node.options[:axis])
        when :conv2d
          _Conv2DGrad(node, grad)
-        when :cast
-          t = %i[float16 float32 float64]
-          src_type = node.inputs[0].data_type
-          dst_type = grad.data_type
-
-          if t.key?(src_type) && t.key?(dst_type)
-            ts.cast(grad, src_type)
-          end
-
-          nil
        else
-          raise "no derivative op for #{node.operation}"
+          TensorStream::OpMaker.gradient_op(self, node, grad)
        end
      end
    end
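With the inline branches gone, the else arm becomes the single dispatch point: instead of raising "no derivative op", unhandled operations fall through to TensorStream::OpMaker.gradient_op, presumably resolving the gradient registered in the op's own definition file. An end-to-end check that the relocated :sin and :mul gradients still compose through the new path, again assuming the TensorStream.gradients API:

require "tensor_stream"

ts = TensorStream
x = ts.constant(2.0)
f = ts.sin(x) * x # d/dx (x sin x) = sin(x) + x cos(x)

g = ts.gradients(f, [x]).first
ts.session.run(g) # => ~0.077 (sin(2) + 2 cos(2))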
@@ -373,8 +214,8 @@ module TensorStream
        zeros = ts.zeros(gradshape, dtype: gdtype)
        xmask = selector_op.call(x, y)
        rx, ry = _broadcast_gradient_args(sx, sy)
-        xgrad = ts.where(xmask, grad, zeros, name: 'x')
-        ygrad = ts.where(xmask, zeros, grad, name: 'y')
+        xgrad = ts.where(xmask, grad, zeros, name: "x")
+        ygrad = ts.where(xmask, zeros, grad, name: "y")
        gx = ts.reshape(ts.reduce_sum(xgrad, rx), sx)
        gy = ts.reshape(ts.reduce_sum(ygrad, ry), sy)
        [gx, gy]
@@ -435,22 +276,22 @@ module TensorStream
 
      shape_0, shape_1 = ts.shape_n([op.inputs[0], op.inputs[1]])
      [
-        _op(:conv2d_backprop_input,
-            shape_0,
-            op.inputs[1],
-            grad,
-            strides: strides,
-            padding: padding,
-            use_cudnn_on_gpu: use_cudnn_on_gpu,
-            data_format: data_format),
-        _op(:conv2d_backprop_filter,
-            op.inputs[0],
-            shape_1,
-            grad,
-            strides: strides,
-            padding: padding,
-            use_cudnn_on_gpu: use_cudnn_on_gpu,
-            data_format: data_format)
+        _op(:conv2d_backprop_input,
+          shape_0,
+          op.inputs[1],
+          grad,
+          strides: strides,
+          padding: padding,
+          use_cudnn_on_gpu: use_cudnn_on_gpu,
+          data_format: data_format),
+        _op(:conv2d_backprop_filter,
+          op.inputs[0],
+          shape_1,
+          grad,
+          strides: strides,
+          padding: padding,
+          use_cudnn_on_gpu: use_cudnn_on_gpu,
+          data_format: data_format),
      ]
    end
  end