tensor_stream 0.7.0 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -5
- data/.rubocop.yml +6 -1
- data/CHANGELOG.md +10 -0
- data/README.md +35 -0
- data/lib/tensor_stream.rb +2 -2
- data/lib/tensor_stream/debugging/debugging.rb +2 -1
- data/lib/tensor_stream/dynamic_stitch.rb +23 -24
- data/lib/tensor_stream/evaluator/base_evaluator.rb +27 -18
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_momentum.cl +16 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/pack.cl +24 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +6 -1
- data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +6 -6
- data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +237 -107
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +97 -7
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +230 -123
- data/lib/tensor_stream/exceptions.rb +1 -0
- data/lib/tensor_stream/graph_builder.rb +2 -3
- data/lib/tensor_stream/graph_deserializers/protobuf.rb +22 -23
- data/lib/tensor_stream/graph_serializers/graphml.rb +26 -29
- data/lib/tensor_stream/graph_serializers/pbtext.rb +22 -19
- data/lib/tensor_stream/helpers/string_helper.rb +4 -5
- data/lib/tensor_stream/math_gradients.rb +141 -77
- data/lib/tensor_stream/nn/nn_ops.rb +4 -6
- data/lib/tensor_stream/operation.rb +139 -120
- data/lib/tensor_stream/ops.rb +36 -3
- data/lib/tensor_stream/session.rb +7 -11
- data/lib/tensor_stream/tensor.rb +3 -3
- data/lib/tensor_stream/tensor_shape.rb +5 -0
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +4 -37
- data/lib/tensor_stream/train/momentum_optimizer.rb +48 -0
- data/lib/tensor_stream/train/optimizer.rb +129 -0
- data/lib/tensor_stream/train/saver.rb +0 -1
- data/lib/tensor_stream/train/slot_creator.rb +62 -0
- data/lib/tensor_stream/train/utils.rb +11 -12
- data/lib/tensor_stream/trainer.rb +3 -0
- data/lib/tensor_stream/utils.rb +18 -11
- data/lib/tensor_stream/variable.rb +19 -12
- data/lib/tensor_stream/variable_scope.rb +1 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +2 -1
- data/samples/linear_regression.rb +3 -1
- data/samples/nearest_neighbor.rb +2 -0
- data/test_samples/neural_network_raw.py +101 -0
- data/test_samples/raw_neural_net_sample.rb +6 -4
- data/test_samples/test2.py +73 -27
- metadata +9 -3
data/lib/tensor_stream/graph_builder.rb
@@ -36,9 +36,8 @@ module TensorStream
 TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
 else
 op = underscore(node['op']).to_sym
-unless TensorStream::Evaluator::RubyEvaluator.ops.
-
-end
+puts "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)
+
 # map input tensor
 inputs = node['input'].map do |input|
 input[0] = '' if input.start_with?('^')
data/lib/tensor_stream/graph_deserializers/protobuf.rb
@@ -23,15 +23,15 @@ module TensorStream
 end

 def parse_value(value_node)
-
-
-
+return unless value_node['tensor']
+
+evaluate_tensor_node(value_node['tensor'])
 end

 def evaluate_tensor_node(node)
 if !node['shape'].empty? && node['tensor_content']
 content = node['tensor_content']
-unpacked = eval(%Q
+unpacked = eval(%Q("#{content}"))

 if node['dtype'] == 'DT_FLOAT'
 TensorShape.reshape(unpacked.unpack('f*'), node['shape'])
@@ -45,14 +45,14 @@ module TensorStream
 else

 val = if node['dtype'] == 'DT_FLOAT'
-
-
-
-
-
-
-
-
+node['float_val'] ? node['float_val'].to_f : []
+elsif node['dtype'] == 'DT_INT32'
+node['int_val'] ? node['int_val'].to_i : []
+elsif node['dtype'] == 'DT_STRING'
+node['string_val']
+else
+raise "unknown dtype #{node['dtype']}"
+end

 if node['shape'] == [1]
 [val]
@@ -83,7 +83,7 @@ module TensorStream
 return {} if node['attributes'].nil?

 node['attributes'].map do |attribute|
-attr_type, attr_value = attribute['value'].
+attr_type, attr_value = attribute['value'].flat_map { |k, v| [k, v] }

 if attr_type == 'tensor'
 attr_value = evaluate_tensor_node(attr_value)
@@ -103,11 +103,10 @@ module TensorStream
 block = []
 node = {}
 node_attr = {}
-dim = []
 state = :top

 lines.each do |str|
-case
+case state
 when :top
 node['type'] = parse_node_name(str)
 state = :node_context
@@ -177,7 +176,7 @@ module TensorStream
 next
 else
 key, value = str.split(':', 2)
-node_attr['value'] << { key => value}
+node_attr['value'] << { key => value }
 end
 when :tensor_context
 if str == 'tensor_shape {'
@@ -219,7 +218,7 @@ module TensorStream
 state = :shape_context
 next
 else
-
+_key, value = str.split(':', 2)
 node_attr['value']['shape'] << value.strip.to_i
 end
 when :tensor_shape_dim_context
@@ -227,7 +226,7 @@ module TensorStream
 state = :tensor_shape_context
 next
 else
-
+_key, value = str.split(':', 2)
 node_attr['value']['tensor']['shape'] << value.strip.to_i
 end
 end
@@ -237,7 +236,7 @@ module TensorStream
 end

 def parse_node_name(str)
-
+str.split(' ')[0]
 end

 def process_value(value)
@@ -253,19 +252,19 @@ module TensorStream
 'n' => "\x0a", 'v' => "\x0b", 'f' => "\x0c",
 'r' => "\x0d", 'e' => "\x1b", "\\\\" => "\x5c",
 "\"" => "\x22", "'" => "\x27"
-}
+}.freeze

 def unescape(str)
 # Escape all the things
-str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/)
+str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) do
 if $1
 $1 == '\\' ? '\\' : UNESCAPES[$1]
 elsif $2 # escape \u0000 unicode
-["
+["#{$2}".hex].pack('U*')
 elsif $3 # escape \0xff or \xff
 [$3].pack('H2')
 end
-
+end
 end
 end
 end
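Note: the DT_FLOAT branch of evaluate_tensor_node above unescapes tensor_content and unpacks it as little-endian floats. A minimal sketch of that decode in plain Ruby (the payload is a made-up example, not taken from the gem):

    # hypothetical 3-float payload, packed the same way the protobuf stores it ('f*')
    content = [1.0, 2.0, 3.0].pack('f*')
    unpacked = content.unpack('f*')  # => [1.0, 2.0, 3.0]
    # TensorShape.reshape then folds this flat list back into node['shape']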
data/lib/tensor_stream/graph_serializers/graphml.rb
@@ -83,7 +83,7 @@ module TensorStream
 arr_buf << '<y:GroupNode>'
 arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
 arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
-arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">' + title + '</y:NodeLabel>'
 arr_buf << '<y:Shape type="roundrectangle"/>'
 arr_buf << '</y:GroupNode>'
 arr_buf << '</y:Realizers>'
@@ -146,9 +146,9 @@ module TensorStream
 input_buf << "<node id=\"#{_gml_string(input.name)}\">"
 input_buf << "<data key=\"d0\">#{input.name}</data>"
 input_buf << "<data key=\"d2\">green</data>"
-
-
-
+
+input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
 input_buf << "<data key=\"d9\">"
 input_buf << "<y:ShapeNode>"
 input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
@@ -164,9 +164,9 @@ module TensorStream
 input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
 input_buf << "</y:ShapeNode>"
 input_buf << "</data>"
-
-
-
+
+input_buf << "<data key=\"d3\">#{_val(tensor)}</data>" if @last_session_context[input.name]
+
 input_buf << "</node>"
 elsif input.is_a?(Tensor)
 input_buf << "<node id=\"#{_gml_string(input.name)}\">"
@@ -175,12 +175,11 @@ module TensorStream
 input_buf << "<data key=\"d9\">"
 input_buf << "<y:ShapeNode>"

-if input.internal?
-
-
-
-
-
+input_buf << if input.internal?
+" <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+else
+" <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+end

 input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"

@@ -189,7 +188,7 @@ module TensorStream
 input_buf << "</node>"
 end

-
+unless add_to_group(groups, input.name, input_buf)
 if input.is_a?(Variable)
 add_to_group(groups, "variable/#{input.name}", input_buf)
 else
@@ -205,7 +204,7 @@ module TensorStream
 end

 def _gml_string(str)
-str.
+str.tr('/', '-')
 end

 def output_edge(input, tensor, arr_buf, index = 0)
@@ -215,22 +214,20 @@ module TensorStream

 arr_buf << "<y:PolyLineEdge>"
 arr_buf << "<y:EdgeLabel >"
-if !@last_session_context.empty?
-
-
-
-
-
-
-end
-end
+arr_buf << if !@last_session_context.empty?
+"<![CDATA[ #{_val(input)} ]]>"
+elsif input.shape.shape.nil?
+"<![CDATA[ #{input.data_type} ? ]]>"
+else
+"<![CDATA[ #{input.data_type} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
+end
 arr_buf << "</y:EdgeLabel >"
 arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
-if index
-
-
-
-
+arr_buf << if index.zero?
+"<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+else
+"<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+end
 arr_buf << "</y:PolyLineEdge>"
 arr_buf << "</data>"
 arr_buf << "</edge>"
data/lib/tensor_stream/graph_serializers/pbtext.rb
@@ -1,4 +1,5 @@
 module TensorStream
+# Parses pbtext files and loads it as a graph
 class Pbtext < TensorStream::Serializer
 include TensorStream::StringHelper
 include TensorStream::OpHelper
@@ -47,11 +48,11 @@ module TensorStream
 @lines << " attr {"
 @lines << " key: \"#{k}\""
 @lines << " value {"
-if
-@lines << " b: #{v
-elsif
+if v.is_a?(TrueClass) || v.is_a?(FalseClass)
+@lines << " b: #{v}"
+elsif v.is_a?(Integer)
 @lines << " int_val: #{v}"
-elsif
+elsif v.is_a?(Float)
 @lines << " float_val: #{v}"
 end
 @lines << " }"
@@ -60,21 +61,23 @@ module TensorStream
 end

 def pack_arr_float(float_arr)
-float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr
+float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
 end

 def pack_arr_int(int_arr)
-int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr
+int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
 end

 def shape_buf(tensor, shape_type = 'tensor_shape')
 arr = []
 arr << " #{shape_type} {"
-tensor.shape.shape
-
-
-
-
+if tensor.shape.shape
+tensor.shape.shape.each do |dim|
+arr << " dim {"
+arr << " size: #{dim}"
+arr << " }"
+end
+end
 arr << " }"
 arr
 end
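The trailing `}.join` added to pack_arr_float/pack_arr_int above is the substance of the fix: without it the block returned an array of per-byte strings rather than one escaped string. A sketch of what the octal escaping produces, using the same printable/non-printable rule (illustrative value, not gem output):

    bytes = [1.0].pack('f*').bytes  # => [0, 0, 128, 63]
    encoded = bytes.map do |b|
      # non-printable bytes become \NNN octal escapes; printable bytes pass through
      b.chr =~ /[^[:print:]]/ ? "\\#{format('%o', b).rjust(3, '0')}" : b.chr
    end.join
    # encoded == "\\000\\000\\200?"  (byte 63 is '?', which is printable)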
@@ -102,14 +105,14 @@ module TensorStream
 end
 else
 val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
-
-
-
-
-
-
-
-
+"int_val"
+elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+"float_val"
+elsif tensor.data_type == :string
+"string_val"
+else
+"val"
+end
 arr << " #{val_type}: #{tensor.value.to_json}"
 end
 arr << "}"
data/lib/tensor_stream/helpers/string_helper.rb
@@ -12,11 +12,10 @@ module TensorStream
 end

 def underscore(string)
-string.gsub(/::/, '/')
-
-
-
-downcase
+string.gsub(/::/, '/')
+.gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
+.gsub(/([a-z\d])([A-Z])/, '\1_\2')
+.tr("-", "_").downcase
 end

 def symbolize_keys(hash)
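A usage sketch of the rewritten underscore helper (inputs chosen for illustration):

    underscore('TensorStream::MatMul')
    # => "tensor_stream/mat_mul"
    underscore('SoftmaxCrossEntropyWithLogitsV2')
    # => "softmax_cross_entropy_with_logits_v2"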
data/lib/tensor_stream/math_gradients.rb
@@ -1,9 +1,10 @@
 module TensorStream
 # Class that provides auto-differentiation
+# Most gradients are ported over from tensorflow's math_grad.py
 class MathGradients
 extend TensorStream::OpHelper

-def self.
+def self.ts
 TensorStream
 end

@@ -16,7 +17,7 @@ module TensorStream
 node.consumers.include?(tensor.name) || node.equal?(tensor)
 end.compact + [wrt_dx.name]

-grad = i_op(:fill,
+grad = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))

 _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || []) || i_op(:zeros_like, wrt_dx)
 end
@@ -41,6 +42,7 @@ module TensorStream
 end
 end

+#TODO: refactor and implement registerGradient
 def self._compute_derivative(node, grad)
 node.graph.name_scope("#{node.name}_grad") do
 x = node.inputs[0] if node.inputs[0]
@@ -51,116 +53,161 @@ module TensorStream
 return [grad] * node.inputs.size
 when :add
 return [grad, grad] if shapes_fully_specified_and_equal(x, y)
-sx =
-sy =
+sx = ts.shape(x, name: 'add/shape_x')
+sy = ts.shape(y, name: 'add/shape_y')
 rx, ry = _broadcast_gradient_args(sx, sy)

-[
-
+[ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sum_x'), sx),
+ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sum_y'), sy)]
 when :asin
-
-x2 =
-one =
-den =
-inv =
+ts.control_dependencies([grad]) do
+x2 = ts.square(x)
+one = ts.constant(1, dtype: grad.data_type)
+den = ts.sqrt(ts.subtract(one, x2))
+inv = ts.reciprocal(den)
 grad * inv
 end
 when :acos
-
-x2 =
-one =
-den =
-inv =
+ts.control_dependencies([grad]) do
+x2 = ts.square(x)
+one = ts.constant(1, dtype: grad.data_type)
+den = ts.sqrt(ts.subtract(one, x2))
+inv = ts.reciprocal(den)
 -grad * inv
 end
+when :atan
+ts.control_dependencies([grad]) do
+x2 = ts.square(x)
+one = ts.constant(1, dtype: grad.data_type)
+inv = ts.reciprocal(ts.add(one, x2))
+grad * inv
+end
+when :fill
+[nil, ts.reduce_sum(grad)]
 when :sub
 return [grad, -grad] if shapes_fully_specified_and_equal(x, y)

-sx =
-sy =
+sx = ts.shape(x, name: 'sub/shape_x')
+sy = ts.shape(y, name: 'sub/shape_y')
 rx, ry = _broadcast_gradient_args(sx, sy)

-[
--
+[ts.reshape(ts.reduce_sum(grad, rx, name: 'add/reduce_sub_x'), sx),
+-ts.reshape(ts.reduce_sum(grad, ry, name: 'add/reduce_sub_y'), sy)]
 when :mul
-sx =
-sy =
+sx = ts.shape(x)
+sy = ts.shape(y)
 rx, ry = _broadcast_gradient_args(sx, sy)

-[
-
+[ts.reshape(ts.reduce_sum(ts.mul(grad, y), rx), sx),
+ts.reshape(ts.reduce_sum(ts.mul(x, grad), ry), sy)]
 when :div
 sx = i_op(:shape, x)
 sy = i_op(:shape, y)
 rx, ry = _broadcast_gradient_args(sx, sy)

-[
-
+[ts.reshape(ts.reduce_sum(ts.div(grad, y), rx), sx),
+ts.reshape(ts.reduce_sum(grad * ts.div(ts.div(-x, y), y), ry), sy)]
 when :mod
-sx =
-sy =
+sx = ts.shape(x)
+sy = ts.shape(y)
 rx, ry = _broadcast_gradient_args(sx, sy)
-floor_xy =
-gx =
-gy =
+floor_xy = ts.floor_div(x, y)
+gx = ts.reshape(ts.reduce_sum(grad, rx), sx)
+gy = ts.reshape(ts.reduce_sum(grad * ts.negative(floor_xy), ry), sy)

 [gx, gy]
+when :prod
+input_shape = ts.shape(x)
+y = ts.range(0, ts.rank(x)) if y.nil?
+reduction_indices = ts.reshape(y, [-1])
+
+output_shape_kept_dims = ts.reduced_shape(input_shape, y)
+tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
+grad = ts.reshape(grad, output_shape_kept_dims)
+grad = ts.tile(grad, tile_scaling)
+
+perm, reduced_num, other_num = ts.device("/cpu:0") do
+rank = ts.rank(x)
+reduction_indices = (reduction_indices + rank) % rank
+reduced = ts.cast(reduction_indices, :int32)
+idx = ts.range(0, rank)
+other, = ts.setdiff1d(idx, reduced)
+
+[ts.concat([reduced, other], 0),
+ts.reduce_prod(ts.gather(input_shape, reduced)),
+ts.reduce_prod(ts.gather(input_shape, other))]
+end
+
+permuted = ts.transpose(x, perm)
+permuted_shape = ts.shape(permuted)
+
+reshaped = ts.reshape(permuted, [reduced_num, other_num])
+
+# Calculate product, leaving out the current entry
+left = ts.cumprod(reshaped, axis: 0, exclusive: true)
+right = ts.cumprod(reshaped, axis: 0, exclusive: true, reverse: true)
+y = ts.reshape(left * right, permuted_shape)
+
+# Invert the transpose and reshape operations.
+# Make sure to set the statically known shape information through a reshape.
+out = grad * ts.transpose(y, ts.invert_permutation(perm))
+[ts.reshape(out, input_shape, name: 'prod'), nil]
 when :squared_difference
 sx = i_op(:shape, x)
 sy = i_op(:shape, y)
 rx, ry = _broadcast_gradient_args(sx, sy)

-x_grad =
+x_grad = ts.mul(2.0, grad) * (x - y)

-[
-
+[ts.reshape(ts.reduce_sum(x_grad, rx), sx),
+ts.reshape(-ts.reduce_sum(x_grad, ry), sy)]
 when :mat_mul
 t_a = node.options[:transpose_a]
 t_b = node.options[:transpose_b]

 if !t_a && !t_b
-grad_a =
-grad_b =
+grad_a = ts.matmul(grad, y, transpose_b: true)
+grad_b = ts.matmul(x, grad, transpose_a: true)
 elsif !ta && tb
-grad_a =
-grad_b =
+grad_a = ts.matmul(grad, y)
+grad_b = ts.matmul(grad, x, transpose_a: true)
 elsif t_a && !t_b
-grad_a =
-grad_b =
+grad_a = ts.matmul(y, grad, transpose_b: true)
+grad_b = ts.matmul(x, grad)
 elsif t_a && t_b
-grad_a =
-grad_b =
+grad_a = ts.matmul(y, grad, transpose_a: true, transpose_b: true)
+grad_b = ts.matmul(grad, x, transpose_a: true, transpose_b: true)
 end

 [grad_a, grad_b]
 when :sin
-grad *
+grad * ts.cos(x)
 when :tanh
 grad * i_op(:tanh_grad, x)
 when :pow
 z = node
-sx =
-sy =
+sx = ts.shape(x)
+sy = ts.shape(y)
 rx, ry = _broadcast_gradient_args(sx, sy)
-gx =
+gx = ts.reduce_sum(grad * y * ts.pow(x, y - 1), rx)

-log_x =
-gy =
+log_x = ts.where(x > 0, ts.log(x), ts.zeros_like(x))
+gy = ts.reduce_sum(grad * z * log_x, ry)

 [gx, gy]
 when :abs
-grad *
+grad * ts.sign(x)
 when :log
-grad *
+grad * ts.reciprocal(x)
 when :cos
--grad *
+-grad * ts.sin(x)
 when :max
-_min_or_max_grad(node.inputs, grad, ->(
+_min_or_max_grad(node.inputs, grad, ->(a, b) { ts.greater_equal(a, b) })
 when :min
-_min_or_max_grad(node.inputs, grad, ->(
+_min_or_max_grad(node.inputs, grad, ->(a, b) { ts.less_equal(a, b) })
 when :tan
-secx =
-secx2 =
+secx = ts.reciprocal(ts.cos(x))
+secx2 = ts.square(secx)
 grad * secx2
 when :negate
 -grad
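The new :prod gradient above relies on the usual left/right exclusive-cumprod identity: the derivative of a product with respect to each entry is the product of all the other entries. A plain-Ruby illustration on a 1-D array (not gem code):

    x = [2.0, 3.0, 4.0]
    # exclusive cumprod from the left:  [1.0, 2.0, 6.0]
    left  = (0...x.size).map { |i| x[0...i].reduce(1.0, :*) }
    # reversed exclusive cumprod:       [12.0, 4.0, 1.0]
    right = (0...x.size).map { |i| x[(i + 1)..-1].reduce(1.0, :*) }
    grads = left.zip(right).map { |l, r| l * r }
    # => [12.0, 8.0, 6.0], i.e. d(2*3*4)/dx_i = product of the other entries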
@@ -169,18 +216,25 @@ module TensorStream
 when :identity, :print
 grad
 when :sign
-
+ts.zeros(ts.shape(x), dtype: x.data_type)
+when :tile
+input_shape = ts.shape(x)
+split_shape = ts.reshape(ts.transpose(ts.stack([y, input_shape])), [-1])
+axes = ts.range(0, ts.size(split_shape), 2)
+input_grad = ts.reduce_sum(ts.reshape(grad, split_shape), axes)
+
+[input_grad, nil]
 when :sum
 _sum_grad(x, y, grad)
 when :reciprocal
--grad * (
+-grad * (ts.constant(1, dtype: x.dtype) / x**2)
 when :sqrt
-
+ts.constant(1, dtype: x.dtype) / (ts.constant(2, dtype: x.dtype) * ts.sqrt(x)) * grad
 when :stop_gradient
-
+ts.zeros_like(grad)
 when :square
-y =
-
+y = ts.constant(2.0, dtype: x.dtype)
+ts.multiply(grad, ts.multiply(x, y))
 when :where
 x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
 y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
@@ -191,12 +245,12 @@ module TensorStream
 [x_cond * grad, y_cond * grad]
 when :mean
 sum_grad = _sum_grad(x, y, grad)[0]
-input_shape =
-output_shape =
-factor = _safe_shape_div(
-
+input_shape = ts.shape(x)
+output_shape = ts.shape(node)
+factor = _safe_shape_div(ts.reduce_prod(input_shape), ts.reduce_prod(output_shape))
+[ts.div(sum_grad, ts.cast(factor, sum_grad.data_type)), nil]
 when :log1p
-grad *
+grad * ts.reciprocal(i_cons(1, dtype: grad.data_type) + x)
 when :sigmoid
 i_op(:sigmoid_grad, x, grad)
 when :sigmoid_grad
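The :mean gradient above is the :sum gradient scaled by the number of input elements folded into each output element, with _safe_shape_div guarding against division by zero. A quick check with illustrative shapes:

    input_shape  = [2, 3]
    output_shape = [2]  # after a mean over axis 1
    factor = input_shape.reduce(:*) / [output_shape.reduce(:*), 1].max  # => 3
    # each incoming gradient element is divided by 3 before being broadcast back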
@@ -205,7 +259,8 @@ module TensorStream
 when :softmax
 i_op(:softmax_grad, x, grad)
 when :softmax_cross_entropy_with_logits_v2
-
+output = node
+[_broadcast_mul(grad, output[1]), nil]
 when :floor, :ceil
 # non differentiable
 nil
@@ -215,6 +270,10 @@ module TensorStream
 when :argmin, :argmax, :floor_div
 # non differentiable
 [nil, nil]
+when :transpose
+return [ts.transpose(grad, ts.invert_permutation(y)), nil]
+when :index
+grad
 else
 raise "no derivative op for #{node.operation}"
 end
@@ -231,12 +290,12 @@ module TensorStream
 end

 def self._safe_shape_div(arg_x, arg_y)
-_op(:floor_div, arg_x,
+_op(:floor_div, arg_x, ts.maximum(arg_y, 1))
 end

 def self._sum_grad(arg_x, arg_y, grad)
 input_shape = _op(:shape, arg_x)
-output_shape_kept_dims =
+output_shape_kept_dims = ts.reduced_shape(input_shape, arg_y)
 tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
 new_grad = _op(:reshape, grad, output_shape_kept_dims)

@@ -254,19 +313,24 @@ module TensorStream
 x = inputs[0]
 y = inputs[1]
 gdtype = grad.data_type
-sx =
-sy =
-gradshape =
-zeros =
+sx = ts.shape(x)
+sy = ts.shape(y)
+gradshape = ts.shape(grad)
+zeros = ts.zeros(gradshape, dtype: gdtype)
 xmask = selector_op.call(x, y)
 rx, ry = _broadcast_gradient_args(sx, sy)
-xgrad =
-ygrad =
-gx =
-gy =
+xgrad = ts.where(xmask, grad, zeros, name: 'x')
+ygrad = ts.where(xmask, zeros, grad, name: 'y')
+gx = ts.reshape(ts.reduce_sum(xgrad, rx), sx)
+gy = ts.reshape(ts.reduce_sum(ygrad, ry), sy)
 [gx, gy]
 end

+def self._broadcast_mul(vec, mat)
+vec = ts.expand_dims(vec, -1)
+vec * mat
+end
+
 def self._include?(arr, obj)
 arr.each { |a| return true if a.equal?(obj) }
 false