tensor_stream 0.7.0 → 0.8.0
- checksums.yaml +5 -5
- data/.rubocop.yml +6 -1
- data/CHANGELOG.md +10 -0
- data/README.md +35 -0
- data/lib/tensor_stream.rb +2 -2
- data/lib/tensor_stream/debugging/debugging.rb +2 -1
- data/lib/tensor_stream/dynamic_stitch.rb +23 -24
- data/lib/tensor_stream/evaluator/base_evaluator.rb +27 -18
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_momentum.cl +16 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/pack.cl +24 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +6 -1
- data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +6 -6
- data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +237 -107
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +97 -7
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +230 -123
- data/lib/tensor_stream/exceptions.rb +1 -0
- data/lib/tensor_stream/graph_builder.rb +2 -3
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +22 -23
- data/lib/tensor_stream/graph_serializers/graphml.rb +26 -29
- data/lib/tensor_stream/graph_serializers/pbtext.rb +22 -19
- data/lib/tensor_stream/helpers/string_helper.rb +4 -5
- data/lib/tensor_stream/math_gradients.rb +141 -77
- data/lib/tensor_stream/nn/nn_ops.rb +4 -6
- data/lib/tensor_stream/operation.rb +139 -120
- data/lib/tensor_stream/ops.rb +36 -3
- data/lib/tensor_stream/session.rb +7 -11
- data/lib/tensor_stream/tensor.rb +3 -3
- data/lib/tensor_stream/tensor_shape.rb +5 -0
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +4 -37
- data/lib/tensor_stream/train/momentum_optimizer.rb +48 -0
- data/lib/tensor_stream/train/optimizer.rb +129 -0
- data/lib/tensor_stream/train/saver.rb +0 -1
- data/lib/tensor_stream/train/slot_creator.rb +62 -0
- data/lib/tensor_stream/train/utils.rb +11 -12
- data/lib/tensor_stream/trainer.rb +3 -0
- data/lib/tensor_stream/utils.rb +18 -11
- data/lib/tensor_stream/variable.rb +19 -12
- data/lib/tensor_stream/variable_scope.rb +1 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +2 -1
- data/samples/linear_regression.rb +3 -1
- data/samples/nearest_neighbor.rb +2 -0
- data/test_samples/neural_network_raw.py +101 -0
- data/test_samples/raw_neural_net_sample.rb +6 -4
- data/test_samples/test2.py +73 -27
- metadata +9 -3
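The headline additions in 0.8.0 are the trainer refactor visible in the list above: a generic train/optimizer.rb base class, slot variables via train/slot_creator.rb, and a new train/momentum_optimizer.rb backed by an apply_momentum.cl OpenCL kernel. A minimal usage sketch, assuming the constructor mirrors tf.train.MomentumOptimizer (the positional learning-rate and momentum arguments are assumptions, not confirmed by this diff):

    require 'tensor_stream'

    ts = TensorStream
    w = ts.variable(0.0, name: 'w')
    loss = (w - 4.0) ** 2

    # assumed signature: MomentumOptimizer.new(learning_rate, momentum)
    optimizer = TensorStream::Train::MomentumOptimizer.new(0.01, 0.9)
    train_op = optimizer.minimize(loss)

    ts.session do |sess|
      sess.run(ts.global_variables_initializer)
      100.times { sess.run(train_op) }
      puts sess.run(w) # approaches 4.0
    end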
data/lib/tensor_stream/graph_builder.rb

@@ -36,9 +36,8 @@ module TensorStream
           TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
         else
           op = underscore(node['op']).to_sym
-          unless TensorStream::Evaluator::RubyEvaluator.ops.
-          end
+          puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
+
           # map input tensor
           inputs = node['input'].map do |input|
             input[0] = '' if input.start_with?('^')
data/lib/tensor_stream/graph_deserializers/protobuf.rb

@@ -23,15 +23,15 @@ module TensorStream
     end

     def parse_value(value_node)
+      return unless value_node['tensor']
+
+      evaluate_tensor_node(value_node['tensor'])
     end

     def evaluate_tensor_node(node)
       if !node['shape'].empty? && node['tensor_content']
         content = node['tensor_content']
-        unpacked = eval(%Q
+        unpacked = eval(%Q("#{content}"))

         if node['dtype'] == 'DT_FLOAT'
           TensorShape.reshape(unpacked.unpack('f*'), node['shape'])
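parse_value now guards against nodes without a tensor value, and evaluate_tensor_node completes the interpolation the old truncated line started: tensor_content arrives as an octal-escaped string, eval(%Q("#{content}")) turns it back into raw bytes, and unpack('f*') decodes the floats. A minimal round trip (plain Ruby, no tensor_stream required; the sample bytes encode the float 1.0):

    content = "\\000\\000\\200?"   # octal-escaped tensor_content
    raw = eval(%Q("#{content}"))   # 4 raw bytes: 0x00 0x00 0x80 0x3f
    raw.unpack('f*')               # => [1.0]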
@@ -45,14 +45,14 @@ module TensorStream
       else

         val = if node['dtype'] == 'DT_FLOAT'
+                node['float_val'] ? node['float_val'].to_f : []
+              elsif node['dtype'] == 'DT_INT32'
+                node['int_val'] ? node['int_val'].to_i : []
+              elsif node['dtype'] == 'DT_STRING'
+                node['string_val']
+              else
+                raise "unknown dtype #{node['dtype']}"
+              end

         if node['shape'] == [1]
           [val]
@@ -83,7 +83,7 @@ module TensorStream
       return {} if node['attributes'].nil?

       node['attributes'].map do |attribute|
-        attr_type, attr_value = attribute['value'].
+        attr_type, attr_value = attribute['value'].flat_map { |k, v| [k, v] }

         if attr_type == 'tensor'
           attr_value = evaluate_tensor_node(attr_value)
@@ -103,11 +103,10 @@ module TensorStream
       block = []
       node = {}
       node_attr = {}
-      dim = []
       state = :top

       lines.each do |str|
-        case
+        case state
         when :top
           node['type'] = parse_node_name(str)
           state = :node_context
@@ -177,7 +176,7 @@ module TensorStream
             next
           else
             key, value = str.split(':', 2)
-            node_attr['value'] << { key => value}
+            node_attr['value'] << { key => value }
           end
         when :tensor_context
           if str == 'tensor_shape {'
@@ -219,7 +218,7 @@ module TensorStream
             state = :shape_context
             next
           else
-            key, value = str.split(':', 2)
+            _key, value = str.split(':', 2)
             node_attr['value']['shape'] << value.strip.to_i
           end
         when :tensor_shape_dim_context
@@ -227,7 +226,7 @@ module TensorStream
             state = :tensor_shape_context
             next
           else
-            key, value = str.split(':', 2)
+            _key, value = str.split(':', 2)
             node_attr['value']['tensor']['shape'] << value.strip.to_i
           end
         end
@@ -237,7 +236,7 @@ module TensorStream
     end

     def parse_node_name(str)
+      str.split(' ')[0]
     end

     def process_value(value)
@@ -253,19 +252,19 @@ module TensorStream
       'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
       'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
       "\"" => "\x22", "'" => "\x27"
-    }
+    }.freeze

     def unescape(str)
       # Escape all the things
-      str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/)
+      str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) do
         if $1
           $1 == '\\' ? '\\' : UNESCAPES[$1]
         elsif $2 # escape \u0000 unicode
-          ["
+          ["#{$2}".hex].pack('U*')
         elsif $3 # escape \0xff or \xff
           [$3].pack('H2')
         end
+      end
     end
   end
 end
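The unescape fix closes the gsub block (do ... end) that the old truncated line left dangling. Roughly how each regex branch resolves, sketched with calls on a deserializer instance (unescape is an instance method, and only the UNESCAPES entries shown above are assumed):

    unescape('line1\nline2')   # $1 branch, via UNESCAPES => "line1\nline2"
    unescape('\u0041')          # $2 branch, unicode       => "A"
    unescape('\x41')            # $3 branch, hex byte      => "A"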
data/lib/tensor_stream/graph_serializers/graphml.rb

@@ -83,7 +83,7 @@ module TensorStream
       arr_buf << '<y:GroupNode>'
       arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
       arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
-      arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+      arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + '</y:NodeLabel>'
       arr_buf << '<y:Shape type="roundrectangle"/>'
       arr_buf << '</y:GroupNode>'
       arr_buf << '</y:Realizers>'
@@ -146,9 +146,9 @@ module TensorStream
       input_buf << "<node id=\"#{_gml_string(input.name)}\">"
       input_buf << "<data key=\"d0\">#{input.name}</data>"
       input_buf << "<data key=\"d2\">green</data>"
+
+      input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
       input_buf << "<data key=\"d9\">"
       input_buf << "<y:ShapeNode>"
       input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
@@ -164,9 +164,9 @@ module TensorStream
       input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
       input_buf << "</y:ShapeNode>"
       input_buf << "</data>"
+
+      input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
       input_buf << "</node>"
     elsif input.is_a?(Tensor)
       input_buf << "<node id=\"#{_gml_string(input.name)}\">"
@@ -175,12 +175,11 @@ module TensorStream
       input_buf << "<data key=\"d9\">"
       input_buf << "<y:ShapeNode>"

-      if input.internal?
+      input_buf << if input.internal?
+                     " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+                   else
+                     " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+                   end

       input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"

@@ -189,7 +188,7 @@ module TensorStream
       input_buf << "</node>"
     end

-    if !add_to_group(groups, input.name, input_buf)
+    unless add_to_group(groups, input.name, input_buf)
       if input.is_a?(Variable)
         add_to_group(groups, "variable/#{input.name}", input_buf)
       else
@@ -205,7 +204,7 @@ module TensorStream
     end

     def _gml_string(str)
-      str.
+      str.tr('/', '-')
     end

     def output_edge(input, tensor, arr_buf, index = 0)
@@ -215,22 +214,20 @@ module TensorStream

       arr_buf << "<y:PolyLineEdge>"
       arr_buf << "<y:EdgeLabel >"
-      if !@last_session_context.empty?
-      end
-      end
+      arr_buf << if !@last_session_context.empty?
+                   "<![CDATA[ #{_val(input)} ]]>"
+                 elsif input.shape.shape.nil?
+                   "<![CDATA[ #{input.data_type} ? ]]>"
+                 else
+                   "<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
+                 end
       arr_buf << "</y:EdgeLabel >"
       arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
-      if index
+      arr_buf << if index.zero?
+                   "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+                 else
+                   "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+                 end
       arr_buf << "</y:PolyLineEdge>"
       arr_buf << "</data>"
       arr_buf << "</edge>"
data/lib/tensor_stream/graph_serializers/pbtext.rb

@@ -1,4 +1,5 @@
 module TensorStream
+  # Parses pbtext files and loads it as a graph
   class Pbtext < TensorStream::Serializer
     include TensorStream::StringHelper
     include TensorStream::OpHelper
@@ -47,11 +48,11 @@ module TensorStream
       @lines << "  attr {"
       @lines << "    key: \"#{k}\""
       @lines << "    value {"
-      if
-        @lines << "      b: #{v
-      elsif
+      if v.is_a?(TrueClass) || v.is_a?(FalseClass)
+        @lines << "      b: #{v}"
+      elsif v.is_a?(Integer)
         @lines << "      int_val: #{v}"
-      elsif
+      elsif v.is_a?(Float)
         @lines << "      float_val: #{v}"
       end
       @lines << "    }"
@@ -60,21 +61,23 @@ module TensorStream
     end

     def pack_arr_float(float_arr)
-      float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr
+      float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
     end

     def pack_arr_int(int_arr)
-      int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr
+      int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
     end

     def shape_buf(tensor, shape_type = 'tensor_shape')
       arr = []
       arr << "  #{shape_type} {"
-      tensor.shape.shape
+      if tensor.shape.shape
+        tensor.shape.shape.each do |dim|
+          arr << "    dim {"
+          arr << "      size: #{dim}"
+          arr << "    }"
+        end
+      end
       arr << "  }"
       arr
     end
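The `.join` fixes on pack_arr_float and pack_arr_int close the block the old truncated lines left open and collapse the per-byte fragments into a single octal-escaped string, which is the same encoding evaluate_tensor_node in protobuf.rb reverses. For example, on a serializer instance:

    pack_arr_float([1.0])  # => "\\000\\000\\200?"  (bytes 00 00 80 3f, escaped)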
@@ -102,14 +105,14 @@ module TensorStream
       end
     else
       val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+                   "int_val"
+                 elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+                   "float_val"
+                 elsif tensor.data_type == :string
+                   "string_val"
+                 else
+                   "val"
+                 end
       arr << "  #{val_type}: #{tensor.value.to_json}"
     end
     arr << "}"
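Taken together, the attr, shape, and val_type branches emit pbtext blocks shaped roughly like the following (the attr name and exact indentation are illustrative, not taken from this diff):

    attr {
      key: "transpose_a"
      value {
        b: false
      }
    }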
data/lib/tensor_stream/helpers/string_helper.rb

@@ -12,11 +12,10 @@ module TensorStream
     end

     def underscore(string)
-      string.gsub(/::/, '/')
-        downcase
+      string.gsub(/::/, '/')
+            .gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
+            .gsub(/([a-z\d])([A-Z])/, '\1_\2')
+            .tr("-", "_").downcase
     end

     def symbolize_keys(hash)
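underscore is what graph_builder.rb uses above to map serialized op names onto evaluator keys; it is now chained into a single expression. For example:

    underscore('TensorStream::MatMul')  # => "tensor_stream/mat_mul"
    underscore('MatMul').to_sym         # => :mat_mul, as looked up by graph_builder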
data/lib/tensor_stream/math_gradients.rb

@@ -1,9 +1,10 @@
 module TensorStream
   # Class that provides auto-differentiation
+  # Most gradients are ported over from tensorflow's math_grad.py
   class MathGradients
     extend TensorStream::OpHelper

-    def self.
+    def self.ts
       TensorStream
     end

@@ -16,7 +17,7 @@ module TensorStream
         node.consumers.include?(tensor.name) || node.equal?(tensor)
       end.compact + [wrt_dx.name]

-      grad = i_op(:fill,
+      grad = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))

       _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || []) || i_op(:zeros_like, wrt_dx)
     end
@@ -41,6 +42,7 @@ module TensorStream
       end
     end

+    #TODO: refactor and implement registerGradient
     def self._compute_derivative(node, grad)
       node.graph.name_scope("#{node.name}_grad") do
         x = node.inputs[0] if node.inputs[0]
@@ -51,116 +53,161 @@ module TensorStream
          return [grad] * node.inputs.size
        when :add
          return [grad, grad] if shapes_fully_specified_and_equal(x, y)
-          sx =
-          sy =
+          sx = ts.shape(x, name: 'add/shape_x')
+          sy = ts.shape(y, name: 'add/shape_y')
          rx, ry = _broadcast_gradient_args(sx, sy)

-          [
+          [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
+           ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
        when :asin
-          x2 =
-          one =
-          den =
-          inv =
+          ts.control_dependencies([grad]) do
+            x2 = ts.square(x)
+            one = ts.constant(1, dtype: grad.data_type)
+            den = ts.sqrt(ts.subtract(one, x2))
+            inv = ts.reciprocal(den)
            grad * inv
          end
        when :acos
-          x2 =
-          one =
-          den =
-          inv =
+          ts.control_dependencies([grad]) do
+            x2 = ts.square(x)
+            one = ts.constant(1, dtype: grad.data_type)
+            den = ts.sqrt(ts.subtract(one, x2))
+            inv = ts.reciprocal(den)
            -grad * inv
          end
+        when :atan
+          ts.control_dependencies([grad]) do
+            x2 = ts.square(x)
+            one = ts.constant(1, dtype: grad.data_type)
+            inv = ts.reciprocal(ts.add(one, x2))
+            grad * inv
+          end
+        when :fill
+          [nil, ts.reduce_sum(grad)]
        when :sub
          return [grad, -grad] if shapes_fully_specified_and_equal(x, y)

-          sx =
-          sy =
+          sx = ts.shape(x, name: 'sub/shape_x')
+          sy = ts.shape(y, name: 'sub/shape_y')
          rx, ry = _broadcast_gradient_args(sx, sy)

-          [
-           -
+          [ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
+           -ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
        when :mul
-          sx =
-          sy =
+          sx = ts.shape(x)
+          sy = ts.shape(y)
          rx, ry = _broadcast_gradient_args(sx, sy)

-          [
+          [ts.reshape(ts.reduce_sum(ts.mul(grad, y), rx), sx),
+           ts.reshape(ts.reduce_sum(ts.mul(x, grad), ry), sy)]
        when :div
          sx = i_op(:shape, x)
          sy = i_op(:shape, y)
          rx, ry = _broadcast_gradient_args(sx, sy)

-          [
+          [ts.reshape(ts.reduce_sum(ts.div(grad, y), rx), sx),
+           ts.reshape(ts.reduce_sum(grad * ts.div(ts.div(-x, y), y), ry), sy)]
        when :mod
-          sx =
-          sy =
+          sx = ts.shape(x)
+          sy = ts.shape(y)
          rx, ry = _broadcast_gradient_args(sx, sy)
-          floor_xy =
-          gx =
-          gy =
+          floor_xy = ts.floor_div(x, y)
+          gx = ts.reshape(ts.reduce_sum(grad, rx), sx)
+          gy = ts.reshape(ts.reduce_sum(grad * ts.negative(floor_xy), ry), sy)

          [gx, gy]
+        when :prod
+          input_shape = ts.shape(x)
+          y = ts.range(0, ts.rank(x)) if y.nil?
+          reduction_indices = ts.reshape(y, [-1])
+
+          output_shape_kept_dims = ts.reduced_shape(input_shape, y)
+          tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
+          grad = ts.reshape(grad, output_shape_kept_dims)
+          grad = ts.tile(grad, tile_scaling)
+
+          perm, reduced_num, other_num = ts.device("/cpu:0") do
+            rank = ts.rank(x)
+            reduction_indices = (reduction_indices + rank) % rank
+            reduced = ts.cast(reduction_indices, :int32)
+            idx = ts.range(0, rank)
+            other, = ts.setdiff1d(idx, reduced)
+
+            [ts.concat([reduced, other], 0),
+             ts.reduce_prod(ts.gather(input_shape, reduced)),
+             ts.reduce_prod(ts.gather(input_shape, other))]
+          end
+
+          permuted = ts.transpose(x, perm)
+          permuted_shape = ts.shape(permuted)
+
+          reshaped = ts.reshape(permuted, [reduced_num, other_num])
+
+          # Calculate product, leaving out the current entry
+          left = ts.cumprod(reshaped, axis: 0, exclusive: true)
+          right = ts.cumprod(reshaped, axis: 0, exclusive: true, reverse: true)
+          y = ts.reshape(left * right, permuted_shape)
+
+          # Invert the transpose and reshape operations.
+          # Make sure to set the statically known shape information through a reshape.
+          out = grad * ts.transpose(y, ts.invert_permutation(perm))
+          [ts.reshape(out, input_shape, name: 'prod'), nil]
        when :squared_difference
          sx = i_op(:shape, x)
          sy = i_op(:shape, y)
          rx, ry = _broadcast_gradient_args(sx, sy)

-          x_grad =
+          x_grad = ts.mul(2.0, grad) * (x - y)

-          [
+          [ts.reshape(ts.reduce_sum(x_grad, rx), sx),
+           ts.reshape(-ts.reduce_sum(x_grad, ry), sy)]
        when :mat_mul
          t_a = node.options[:transpose_a]
          t_b = node.options[:transpose_b]

          if !t_a && !t_b
-            grad_a =
-            grad_b =
+            grad_a = ts.matmul(grad, y, transpose_b: true)
+            grad_b = ts.matmul(x, grad, transpose_a: true)
          elsif !ta && tb
-            grad_a =
-            grad_b =
+            grad_a = ts.matmul(grad, y)
+            grad_b = ts.matmul(grad, x, transpose_a: true)
          elsif t_a && !t_b
-            grad_a =
-            grad_b =
+            grad_a = ts.matmul(y, grad, transpose_b: true)
+            grad_b = ts.matmul(x, grad)
          elsif t_a && t_b
-            grad_a =
-            grad_b =
+            grad_a = ts.matmul(y, grad, transpose_a: true, transpose_b: true)
+            grad_b = ts.matmul(grad, x, transpose_a: true, transpose_b: true)
          end

          [grad_a, grad_b]
        when :sin
-          grad *
+          grad * ts.cos(x)
        when :tanh
          grad * i_op(:tanh_grad, x)
        when :pow
          z = node
-          sx =
-          sy =
+          sx = ts.shape(x)
+          sy = ts.shape(y)
          rx, ry = _broadcast_gradient_args(sx, sy)
-          gx =
+          gx = ts.reduce_sum(grad * y * ts.pow(x, y - 1), rx)

-          log_x =
-          gy =
+          log_x = ts.where(x > 0, ts.log(x), ts.zeros_like(x))
+          gy = ts.reduce_sum(grad * z * log_x, ry)

          [gx, gy]
        when :abs
-          grad *
+          grad * ts.sign(x)
        when :log
-          grad *
+          grad * ts.reciprocal(x)
        when :cos
-          -grad *
+          -grad * ts.sin(x)
        when :max
-          _min_or_max_grad(node.inputs, grad, ->(
+          _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.greater_equal(a, b) })
        when :min
-          _min_or_max_grad(node.inputs, grad, ->(
+          _min_or_max_grad(node.inputs, grad, ->(a, b) { ts.less_equal(a, b) })
        when :tan
-          secx =
-          secx2 =
+          secx = ts.reciprocal(ts.cos(x))
+          secx2 = ts.square(secx)
          grad * secx2
        when :negate
          -grad
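The :pow branch above implements the textbook derivative pair d(x^y)/dx = y·x^(y-1) and d(x^y)/dy = x^y·log x (with the where guard for non-positive x). A quick finite-difference check of the x term in plain Ruby, no tensor_stream required:

    f = ->(x, y) { x**y }
    x, y, eps = 3.0, 2.0, 1e-6
    numeric  = (f.call(x + eps, y) - f.call(x - eps, y)) / (2 * eps)
    analytic = y * x**(y - 1)
    # numeric ≈ analytic ≈ 6.0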
@@ -169,18 +216,25 @@ module TensorStream
        when :identity, :print
          grad
        when :sign
+          ts.zeros(ts.shape(x), dtype: x.data_type)
+        when :tile
+          input_shape = ts.shape(x)
+          split_shape = ts.reshape(ts.transpose(ts.stack([y, input_shape])), [-1])
+          axes = ts.range(0, ts.size(split_shape), 2)
+          input_grad = ts.reduce_sum(ts.reshape(grad, split_shape), axes)
+
+          [input_grad, nil]
        when :sum
          _sum_grad(x, y, grad)
        when :reciprocal
-          -grad * (
+          -grad * (ts.constant(1, dtype: x.dtype) / x**2)
        when :sqrt
+          ts.constant(1, dtype: x.dtype) / (ts.constant(2, dtype: x.dtype) * ts.sqrt(x)) * grad
        when :stop_gradient
+          ts.zeros_like(grad)
        when :square
-          y =
+          y = ts.constant(2.0, dtype: x.dtype)
+          ts.multiply(grad, ts.multiply(x, y))
        when :where
          x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
          y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
@@ -191,12 +245,12 @@ module TensorStream
          [x_cond * grad, y_cond * grad]
        when :mean
          sum_grad = _sum_grad(x, y, grad)[0]
-          input_shape =
-          output_shape =
-          factor = _safe_shape_div(
+          input_shape = ts.shape(x)
+          output_shape = ts.shape(node)
+          factor = _safe_shape_div(ts.reduce_prod(input_shape), ts.reduce_prod(output_shape))
+          [ts.div(sum_grad, ts.cast(factor, sum_grad.data_type)), nil]
        when :log1p
-          grad *
+          grad * ts.reciprocal(i_cons(1, dtype: grad.data_type) + x)
        when :sigmoid
          i_op(:sigmoid_grad, x, grad)
        when :sigmoid_grad
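The :mean branch divides the summed gradient by the number of elements averaged over (factor = prod(input_shape) / prod(output_shape)). Sketched numerically:

    # mean over a length-4 vector: input_shape = [4], output is a scalar,
    # so factor = 4 and each input element receives grad / 4
    upstream_grad = 1.0
    factor = 4
    per_element = upstream_grad / factor  # => 0.25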
@@ -205,7 +259,8 @@ module TensorStream
        when :softmax
          i_op(:softmax_grad, x, grad)
        when :softmax_cross_entropy_with_logits_v2
+          output = node
+          [_broadcast_mul(grad, output[1]), nil]
        when :floor, :ceil
          # non differentiable
          nil
@@ -215,6 +270,10 @@ module TensorStream
        when :argmin, :argmax, :floor_div
          # non differentiable
          [nil, nil]
+        when :transpose
+          return [ts.transpose(grad, ts.invert_permutation(y)), nil]
+        when :index
+          grad
        else
          raise "no derivative op for #{node.operation}"
        end
@@ -231,12 +290,12 @@ module TensorStream
     end

     def self._safe_shape_div(arg_x, arg_y)
-      _op(:floor_div, arg_x,
+      _op(:floor_div, arg_x, ts.maximum(arg_y, 1))
     end

     def self._sum_grad(arg_x, arg_y, grad)
       input_shape = _op(:shape, arg_x)
-      output_shape_kept_dims =
+      output_shape_kept_dims = ts.reduced_shape(input_shape, arg_y)
       tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
       new_grad = _op(:reshape, grad, output_shape_kept_dims)

@@ -254,19 +313,24 @@ module TensorStream
       x = inputs[0]
       y = inputs[1]
       gdtype = grad.data_type
-      sx =
-      sy =
-      gradshape =
-      zeros =
+      sx = ts.shape(x)
+      sy = ts.shape(y)
+      gradshape = ts.shape(grad)
+      zeros = ts.zeros(gradshape, dtype: gdtype)
       xmask = selector_op.call(x, y)
       rx, ry = _broadcast_gradient_args(sx, sy)
-      xgrad =
-      ygrad =
-      gx =
-      gy =
+      xgrad = ts.where(xmask, grad, zeros, name: 'x')
+      ygrad = ts.where(xmask, zeros, grad, name: 'y')
+      gx = ts.reshape(ts.reduce_sum(xgrad, rx), sx)
+      gy = ts.reshape(ts.reduce_sum(ygrad, ry), sy)
       [gx, gy]
     end

+    def self._broadcast_mul(vec, mat)
+      vec = ts.expand_dims(vec, -1)
+      vec * mat
+    end
+
     def self._include?(arr, obj)
       arr.each { |a| return true if a.equal?(obj) }
       false