tensor_stream 0.1.5 → 0.2.0
- checksums.yaml +5 -5
- data/CHANGELOG.md +13 -0
- data/README.md +34 -0
- data/lib/tensor_stream.rb +7 -3
- data/lib/tensor_stream/control_flow.rb +1 -2
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +44 -3
- data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +9 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +70 -36
- data/lib/tensor_stream/graph.rb +15 -7
- data/lib/tensor_stream/graph_serializers/graphml.rb +183 -35
- data/lib/tensor_stream/graph_serializers/pbtext.rb +81 -14
- data/lib/tensor_stream/graph_serializers/serializer.rb +13 -0
- data/lib/tensor_stream/helpers/string_helper.rb +12 -0
- data/lib/tensor_stream/math_gradients.rb +203 -161
- data/lib/tensor_stream/operation.rb +30 -16
- data/lib/tensor_stream/ops.rb +29 -19
- data/lib/tensor_stream/placeholder.rb +2 -3
- data/lib/tensor_stream/session.rb +7 -13
- data/lib/tensor_stream/tensor.rb +22 -5
- data/lib/tensor_stream/tensor_shape.rb +2 -0
- data/lib/tensor_stream/trainer.rb +6 -1
- data/lib/tensor_stream/variable.rb +4 -3
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/gradient_sample.graphml +1255 -0
- data/samples/linear_regression.rb +1 -1
- data/samples/logistic_regression.rb +9 -2
- data/tensor_stream.gemspec +1 -1
- data/test_samples/error.graphml +120 -0
- data/test_samples/gradient_sample.graphml +1255 -0
- data/{samples → test_samples}/iris.rb +0 -0
- data/{samples → test_samples}/raw_neural_net_sample.rb +0 -0
- data/{samples → test_samples}/test.py +2 -0
- data/test_samples/test2.py +41 -0
- metadata +41 -47
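
To pick up these changes in an existing project, bumping the dependency is enough; a minimal Gemfile line:

```ruby
# Gemfile
gem 'tensor_stream', '~> 0.2.0'
```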
data/lib/tensor_stream/graph_serializers/graphml.rb

@@ -1,91 +1,239 @@
 module TensorStream
-  class Graphml
+  class Graphml < Serializer
     def initialize
     end
 
-    def
+    def get_string(tensor, session = nil)
+      tensor = TensorStream.convert_to_tensor(tensor) unless tensor.is_a?(Tensor)
       @session = session
-      @
+      @name = tensor.name
+      @last_session_context = session ? session.last_session_context : {}
+      groups = {}
 
       arr_buf = []
       arr_buf << '<?xml version="1.0" encoding="UTF-8"?>'
-      arr_buf << '<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+      arr_buf << '<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:y="http://www.yworks.com/xml/graphml"
       xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">'
       arr_buf << '<key id="d0" for="node" attr.name="label" attr.type="string"/>'
       arr_buf << '<key id="d1" for="node" attr.name="formula" attr.type="string"/>'
       arr_buf << '<key id="d2" for="node" attr.name="color" attr.type="string"/>'
       arr_buf << '<key id="d3" for="node" attr.name="value" attr.type="string"/>'
+      arr_buf << '<key attr.name="description" attr.type="string" for="edge" id="d12"/>'
+      arr_buf << '<key for="edge" id="d13" yfiles.type="edgegraphics"/>'
+      arr_buf << '<key for="node" id="d9" yfiles.type="nodegraphics"/>'
       arr_buf << "<graph id=\"g_#{_gml_string(tensor.name)}\" edgedefault=\"directed\">"
       arr_buf << "<node id=\"out\">"
       arr_buf << "<data key=\"d0\">out</data>"
       arr_buf << "<data key=\"d2\">red</data>"
+      arr_buf << "<data key=\"d9\">"
+      arr_buf << "<y:ShapeNode>"
+      arr_buf << " <y:Fill color=\"#FF0000\" transparent=\"false\"/>"
+      arr_buf << " <y:NodeLabel alignment=\"center\">out</y:NodeLabel>"
+      arr_buf << "</y:ShapeNode>"
+      arr_buf << "</data>"
       arr_buf << "</node>"
-
-      arr_buf
+
+      to_graph_ml(tensor, arr_buf, {}, groups)
+      #dump groups
+      groups.each do |k, g|
+        arr_buf << create_group(k, k, g)
+      end
+
+      output_edge(tensor, "out", arr_buf)
       arr_buf << "</graph>"
       arr_buf << "</graphml>"
-
+      arr_buf.flatten.join("\n")
     end
 
     private
 
+    def add_to_group(groups, name, arr_buf)
+      name_parts = name.split('/')
+      return false if name_parts.size < 2
+
+      prefix = name_parts.shift
+
+      ptr = find_or_create_group(prefix, groups)
+
+      Kernel.loop do
+        next_group = ptr[:group]
+        ptr = find_or_create_group(prefix, next_group)
+        break if name_parts.size < 2
+        prefix = name_parts.shift
+      end
+
+      ptr[:buf] << arr_buf
+      true
+    end
+
+    def find_or_create_group(prefix, groups)
+      if !groups[prefix]
+        groups[prefix] = { buf: [], group: {} }
+      end
+
+      return groups[prefix]
+    end
+
+    def create_group(id, title, group)
+      arr_buf = []
+      arr_buf << "<node id=\"#{id}\" yfiles.foldertype=\"group\">"
+      arr_buf << '<data key="d9">'
+      arr_buf << '<y:ProxyAutoBoundsNode>'
+      arr_buf << '<y:Realizers active="0">'
+      arr_buf << '<y:GroupNode>'
+      arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
+      arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
+      arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+      arr_buf << '<y:Shape type="roundrectangle"/>'
+      arr_buf << '</y:GroupNode>'
+      arr_buf << '</y:Realizers>'
+      arr_buf << '</y:ProxyAutoBoundsNode>'
+      arr_buf << '</data>'
+      arr_buf << '<graph edgedefault="directed" id="n105:">'
+      arr_buf << group[:buf]
+      group[:group].each do |k, g|
+        arr_buf << create_group(k, k, g)
+      end
+      arr_buf << '</graph>'
+      arr_buf << '</node>'
+      arr_buf
+    end
+
     def _val(tensor)
-      JSON.pretty_generate(@last_session_context[tensor.name])
+      # JSON.pretty_generate(@last_session_context[tensor.name])
+      @last_session_context[tensor.name]
     end
 
-    def to_graph_ml(tensor, arr_buf = [], added = {}, _id = 0)
+    def to_graph_ml(tensor, arr_buf = [], added = {}, groups = {}, _id = 0)
       puts tensor.name
+      return unless tensor.is_a?(Operation)
+
       added[tensor.name] = true
-
-
-
-
+      node_buf = []
+      node_buf << "<node id=\"#{_gml_string(tensor.name)}\">"
+      node_buf << "<data key=\"d0\">#{tensor.operation}</data>"
+      node_buf << "<data key=\"d1\">#{tensor.to_math(true, 1)}</data>"
+      node_buf << "<data key=\"d2\">blue</data>"
+
       if @last_session_context[tensor.name]
         arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
       end
-
+      node_buf << "<data key=\"d9\">"
+      node_buf << "<y:ShapeNode>"
+      if tensor.internal?
+        node_buf << " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
+      else
+        node_buf << " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
+      end
+      node_buf << " <y:NodeLabel alignment=\"center\">#{tensor.operation}</y:NodeLabel>"
+      node_buf << "</y:ShapeNode>"
+      node_buf << "</data>"
+      node_buf << "</node>"
+
+      if !add_to_group(groups, tensor.name, node_buf)
+        add_to_group(groups, "program/#{tensor.name}", node_buf)
+      end
 
       tensor.items.each do |item|
         next unless item
-        next if
+        next if added[item.name]
+
+        next to_graph_ml(item, arr_buf, added, groups) if item.is_a?(Operation)
 
-        next to_graph_ml(item, arr_buf, added) if item.is_a?(Operation)
         added[item.name] = true
+        item_buf = []
         if item.is_a?(Variable)
-
-
-
+          item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+          item_buf << "<data key=\"d0\">#{item.name}</data>"
+          item_buf << "<data key=\"d2\">green</data>"
           if @last_session_context[item.name]
-
+            item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
           end
-
+          item_buf << "<data key=\"d9\">"
+          item_buf << "<y:ShapeNode>"
+          item_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
+          item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+          item_buf << "</y:ShapeNode>"
+          item_buf << "</data>"
+          item_buf << "</node>"
         elsif item.is_a?(Placeholder)
-
-
-
+          item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+          item_buf << "<data key=\"d9\">"
+          item_buf << "<y:ShapeNode>"
+          item_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
+          item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+          item_buf << "</y:ShapeNode>"
+          item_buf << "</data>"
           if @last_session_context[item.name]
-
+            item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
           end
-
-
-
-
-
-
-
+          item_buf << "</node>"
+        elsif item.is_a?(Tensor)
+          item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+          item_buf << "<data key=\"d0\">#{item.name}</data>"
+          item_buf << "<data key=\"d2\">black</data>"
+          item_buf << "<data key=\"d9\">"
+          item_buf << "<y:ShapeNode>"
+
+          if item.internal?
+            item_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+          else
+            item_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+          end
+
+
+          item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+
+          item_buf << "</y:ShapeNode>"
+          item_buf << "</data>"
+          item_buf << "</node>"
+        end
+
+        if !add_to_group(groups, item.name, item_buf)
+          if item.is_a?(Variable)
+            add_to_group(groups, "variable/#{item.name}", item_buf)
+          else
+            add_to_group(groups, "program/#{item.name}", item_buf)
          end
-        arr_buf << "</node>"
        end
      end
 
-      tensor.items.
+      tensor.items.each_with_index do |item, index|
        next unless item
-
+        output_edge(item, tensor, arr_buf, index)
      end
    end
 
    def _gml_string(str)
      str.gsub('/','-')
    end
+
+    def output_edge(item, tensor, arr_buf, index = 0)
+      target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
+      arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(target_name)}\">"
+      arr_buf << "<data key=\"d13\">"
+
+      arr_buf << "<y:PolyLineEdge>"
+      arr_buf << "<y:EdgeLabel >"
+      if !@last_session_context.empty?
+        arr_buf << "<![CDATA[ #{_val(item)} ]]>"
+      else
+        if item.shape.shape.nil?
+          arr_buf << "<![CDATA[ #{item.data_type.to_s} ? ]]>"
+        else
+          arr_buf << "<![CDATA[ #{item.data_type.to_s} #{item.shape.shape.empty? ? 'scalar' : item.shape.shape.to_json} ]]>"
+        end
+      end
+      arr_buf << "</y:EdgeLabel >"
+      arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
+      if index == 0
+        arr_buf << "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+      else
+        arr_buf << "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+      end
+      arr_buf << "</y:PolyLineEdge>"
+      arr_buf << "</data>"
+      arr_buf << "</edge>"
+    end
   end
 end
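`Graphml#get_string` now renders yFiles/yEd markup (`<y:ShapeNode>`, `<y:PolyLineEdge>`), groups nodes by name-scope prefix, and labels edges with dtype and shape, or with evaluated values when a session context is available. A minimal usage sketch, assuming the constant/matmul API used in the gem's samples; the output file name is illustrative:

```ruby
require 'tensor_stream'

tf = TensorStream
a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.constant([[5.0, 6.0], [7.0, 8.0]])
c = tf.matmul(a, b)

sess = tf.session
sess.run(c) # optional: evaluated values then appear on edge labels

# Passing the session lets the serializer read last_session_context.
File.write('c.graphml', TensorStream::Graphml.new.get_string(c, sess))
```

The resulting file can be opened directly in yEd, which understands the yWorks extensions emitted above.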
data/lib/tensor_stream/graph_serializers/pbtext.rb

@@ -1,54 +1,121 @@
 module TensorStream
-  class Pbtext
-
-
-
-    def serialize(session, filename, tensor)
-    end
+  class Pbtext < TensorStream::Serializer
+    include TensorStream::StringHelper
+    include TensorStream::OpHelper
 
-    def get_string(
+    def get_string(tensor_or_graph, session = nil)
+      graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
       @lines = []
       graph.nodes.each do |k, node|
         @lines << "node {"
         @lines << " name: #{node.name.to_json}"
         if node.is_a?(TensorStream::Operation)
-          @lines << " op: #{node.operation.to_json}"
+          @lines << " op: #{camelize(node.operation.to_s).to_json}"
           node.items.each do |input|
             next unless input
             @lines << " input: #{input.name.to_json}"
           end
           # type
-          pb_attr('T', sym_to_protobuf_type(node.data_type))
+          pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
+          process_options(node)
         elsif node.is_a?(TensorStream::Tensor) && node.is_const
           @lines << " op: \"Const\""
           # type
-          pb_attr('T', sym_to_protobuf_type(node.data_type))
+          pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
           pb_attr('value', tensor_value(node))
+        elsif node.is_a?(TensorStream::Variable)
+          @lines << " op: \"VariableV2\""
+          pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
+          pb_attr('shape', shape_buf(node, 'shape'))
+          process_options(node)
         end
         @lines << "}"
       end
-      @lines
+      @lines << "versions {"
+      @lines << " producer: 26"
+      @lines << "}"
+      @lines.flatten.join("\n")
     end
 
     private
 
+    def process_options(node)
+      node.options.each do |k, v|
+        next if %w[name].include?(k.to_s)
+        @lines << " attr {"
+        @lines << "   key: \"#{k}\""
+        @lines << "   value {"
+        @lines << "   }"
+        @lines << " }"
+      end
+    end
+
+    def pack_arr_float(float_arr)
+      float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+    end
+
+    def pack_arr_int(int_arr)
+      int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+    end
+
+    def shape_buf(tensor, shape_type = 'tensor_shape')
+      arr = []
+      arr << " #{shape_type} {"
+      tensor.shape.shape.each do |dim|
+        arr << "  dim {"
+        arr << "   size: #{dim}"
+        arr << "  }"
+      end if tensor.shape.shape
+      arr << " }"
+      arr
+    end
     def tensor_value(tensor)
       arr = []
       arr << "tensor {"
       arr << " dtype: #{sym_to_protobuf_type(tensor.data_type)}"
-
+
+      arr += shape_buf(tensor)
+
+      if tensor.rank > 0
+        if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+          packed = pack_arr_float(tensor.value)
+          arr << " tensor_content: \"#{packed}\""
+        elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+          packed = pack_arr_int(tensor.value)
+          arr << " tensor_content: \"#{packed}\""
+        elsif tensor.data_type == :string
+          tensor.value.each do |v|
+            arr << " string_val: #{v.to_json}"
+          end
+        else
+          arr << " tensor_content: #{tensor.value.flatten}"
+        end
+      else
+        val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+          "int_val"
+        elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+          "float_val"
+        elsif tensor.data_type == :string
+          "string_val"
+        else
+          "val"
+        end
+        arr << " #{val_type}: #{tensor.value.to_json}"
+      end
       arr << "}"
       arr
     end
 
     def sym_to_protobuf_type(type)
       case type
-      when :int32
+      when :int32, :int
         "DT_INT32"
       when :float, :float32
         "DT_FLOAT"
+      when :string
+        "DT_STRING"
       else
-        "
+        "UKNOWN"
       end
     end
 
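With this rewrite, `Pbtext#get_string` accepts either a tensor or a whole graph and returns TensorFlow GraphDef-style text: op names are camelized, constants carry packed `tensor_content`, and a trailing `versions` block is appended. A small sketch, assuming the constant and operator API shown in the gem's samples:

```ruby
require 'tensor_stream'

tf = TensorStream
x = tf.constant([1.0, 2.0, 3.0])
y = x * 2

# Serialize the whole graph in GraphDef text (pbtxt) form.
puts TensorStream::Pbtext.new.get_string(y.graph)
```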
data/lib/tensor_stream/helpers/string_helper.rb (new file)

@@ -0,0 +1,12 @@
+module TensorStream
+  module StringHelper
+    def camelize(string, uppercase_first_letter = true)
+      string = if uppercase_first_letter
+        string.sub(/^[a-z\d]*/) { $&.capitalize }
+      else
+        string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
+      end
+      string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
+    end
+  end
+end
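The new helper mirrors ActiveSupport's `camelize` and is what `Pbtext` uses above to turn op symbols like `:reduce_sum` into TensorFlow-style op names. For example:

```ruby
require 'tensor_stream'
include TensorStream::StringHelper

camelize('reduce_sum')        # => "ReduceSum"
camelize('reduce_sum', false) # => "reduceSum"
camelize('nn/softmax')        # => "Nn::Softmax"
```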
data/lib/tensor_stream/math_gradients.rb

@@ -3,201 +3,243 @@ module TensorStream
   class MathGradients
     extend TensorStream::OpHelper
 
+    def self.tf
+      TensorStream
+    end
+
     def self.derivative(tensor, wrt_dx, options = {})
-
-
-      return options[:graph].get_node(gradient_program_name) if options[:graph] && options[:graph].node_added?(gradient_program_name)
+      return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
+      return i_op(:zeros_like, tensor) unless wrt_dx.consumers.include?(tensor.name)
 
-
-
+      nodes_to_compute = wrt_dx.consumers.select do |t|
+        node = tensor.graph.nodes[t]
+        node.consumers.include?(tensor.name) || node.equal?(tensor)
+      end.compact + [wrt_dx.name]
 
-
-      return i_cons(0, constant_options) if options[:stop_gradients] && _include?(options[:stop_gradients], tensor)
+      grad = i_op(:ones_like, wrt_dx)
 
-
-
-
+      result = _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || [])
+      i_op(:truncate, result, tf.shape(wrt_dx))
+    end
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-      grad * i_op(:sign, _ds(tensor.items[0]))
-    when :square
-      i_cons(2, constant_options_1) * _ds(tensor.items[0]) * grad
-    when :exp
-      i_op(:exp, tensor.items[0]) * grad
-    when :log
-      (i_cons(1, constant_options_1) / _ds(tensor.items[0])) * grad
-    when :tanh
-      i_op(:mul, (i_cons(1, constant_options_1) - (i_op(:tanh, _ds(tensor.items[0]))**2)), grad, name: 'grad_tanh')
-    when :tan
-      (i_cons(1, constant_options_1) / (i_op(:cos, _ds(tensor.items[0]))**2)) * grad
-    when :sin
-      i_op(:mul, i_op(:cos, tensor.items[0]), grad, name: 'grad_sin')
-    when :sqrt
-      i_cons(1, constant_options_1) / (i_cons(2, constant_options_1) * i_op(:sqrt, _ds(tensor.items[0]))) * grad
-    when :cos
-      -i_op(:sin, tensor.items[0]) * grad
-    when :add
-      # rx = _op(:shape, tensor.items[0])
-      # ry = _op(:shape, tensor.items[1])
+    def self._propagate(grad, tensor, stop_tensor, nodes_to_compute, stop_gradients = [])
+      return grad * i_op(:ones_like, stop_tensor) if stop_tensor.equal?(tensor)
+      return i_op(:zeros_like, stop_tensor) if stop_gradients && _include?(stop_gradients, tensor)
+      return i_op(:zeros_like, stop_tensor) unless tensor.is_a?(Operation)
+
+      computed_op = if _op_supports_broadcast?(tensor)
+        _compute_derivative(tensor, _broadcast_transform(tensor, grad)[1])
+      else
+        _compute_derivative(tensor, grad)
+      end
+
+      if computed_op.is_a?(Array)
+        partials = []
+        computed_op.each_with_index do |op_grad, index|
+          next if op_grad.nil?
+
+          if nodes_to_compute.include?(tensor.items[index].name)
+            partials << _propagate(op_grad, tensor.items[index], stop_tensor, nodes_to_compute, stop_gradients)
+          end
+        end
 
-
-
-
-
+        partials.reduce(:+)
+      else
+        return tf.zeros_like(stop_tensor) if computed_op.nil?
+        _propagate(computed_op, tensor.items[0], stop_tensor, nodes_to_compute, stop_gradients)
+      end
+    end
 
-
-
+    def self._compute_derivative(node, grad)
+      node.graph.name_scope("#{node.name}_grad") do
+        x = node.items[0] if node.items[0]
+        y = node.items[1] if node.items[1]
 
-
-
-
-      _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:add, a, b, name: 'grad_add') }, options)
-    when :sub
-      _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:sub, a, b, name: 'grad_sub') }, options)
-    when :pow
-      gx = _ds(tensor.items[1]) * (_ds(tensor.items[0])**(_ds(tensor.items[1]) - 1)) * grad
+        case node.operation
+        when :add
+          return [grad, grad] if _shapes_fully_specified_and_equal(x, y)
 
-
-
+          sx = tf.shape(x, name: 'add/shape_x')
+          sy = tf.shape(y, name: 'add/shape_y')
+          rx, ry = _broadcast_gradient_args(sx, sy)
+          keep_dims_x = tf.rank(x) == tf.rank(grad)
+          keep_dims_y = tf.rank(y) == tf.rank(grad)
 
-
-
-
-
-      gy = grad2 * i_op(:div, i_op(:div, -_ds(tensor.items[0]), _ds(tensor.items[1])), _ds(tensor.items[1]))
+          [tf.reduce_sum(grad, rx, name: 'add/reduce_sum_x', keepdims: keep_dims_x),
+           tf.reduce_sum(grad, ry, name: 'add/reduce_sum_y', keepdims: keep_dims_y)]
+        when :sub
+          return [grad, -grad] if _shapes_fully_specified_and_equal(x, y)
 
-
+          sx = tf.shape(x, name: 'sub/shape_x')
+          sy = tf.shape(y, name: 'sub/shape_y')
+          rx, ry = _broadcast_gradient_args(sx, sy)
+          [tf.reduce_sum(grad, rx), -tf.reduce_sum(grad, ry)]
         when :mul
-
-
-          ry =
-          sx, sy = _broadcast_gradient_args(rx, ry)
-          inputs = _broadcast_transform(tensor.items[0], tensor.items[1])
-          keep_dims_x = _op(:rank, inputs[0]) == _op(:rank, tensor.items[0])
-          keep_dims_y = _op(:rank, inputs[1]) == _op(:rank, tensor.items[1])
-
-          _filtered_sum(_op(:reduce_sum, grad * _ds(inputs[1]), nil, axis: sy, keepdims: keep_dims_x),
-                        _op(:reduce_sum, _ds(inputs[0]) * grad2, nil, axis: sx, keepdims: keep_dims_y), wrt_dx)
-        when :reduce_mean
-          input_size = i_op(:reduce_prod, i_op(:shape, tensor.items[0]))
-          output_size = i_op(:reduce_prod, i_op(:shape, tensor))
-          factor = input_size / output_size
-
-          (grad / i_op(:cast, factor, data_type: grad.dtype))
-        when :reduce_sum
-          grad
-        when :reciprocal
-          -grad * (i_cons(1, constant_options_1) / _ds(tensor.items[0])**2)
-        when :stop_gradient
-          return i_cons(0, constant_options)
-        when :matmul
-          derivative_a = derivative(tensor.items[0], wrt_dx)
-          derivative_b = derivative(tensor.items[1], wrt_dx)
-
-          s0 = i_op(:shape, tensor.items[0])
-          s1 = i_op(:shape, tensor.items[1])
+          sx = tf.shape(x)
+          sy = tf.shape(y)
+          rx, ry = _broadcast_gradient_args(sx, sy)
 
-
-
-
-
-
-
-          matmul_db = i_op(:matmul, tensor.items[0], identity_1, transpose_a: true,
-                           pad_zeros: true,
-                           name: 'matrix_dy')
-          # matmul_db = _op(:transpose, matmul_db, nil).first
-
-          # begin_a = _op(:zeros, _op(:rank, matmul_db), nil, data_type: :int32, name: 'begin_a')
-          # matmul_b_shape = _op(:shape, matmul_db)
-          # end_a = [matmul_b_shape[0], 1]
+          [ tf.reduce_sum(tf.mul(grad, y), rx),
+            tf.reduce_sum(tf.mul(x, grad), ry)]
+        when :div
+          sx = i_op(:shape, x)
+          sy = i_op(:shape, y)
+          rx, ry = _broadcast_gradient_args(sx, sy)
 
-
+          [tf.reduce_sum(tf.div(grad, y), rx),
+           tf.reduce_sum(grad * tf.div(tf.div(-x, y), y),
+                         ry)]
+        when :matmul
+          t_a = node.options[:transpose_a]
+          t_b = node.options[:transpose_b]
+
+          s0 = tf.shape(x)
+          s1 = tf.shape(y)
+
+          identity_0 = tf.ones([ s0[0], s1[1] ], dtype: x.data_type, name: 'matmul/identity0')
+          identity_1 = tf.ones([ s0[0], s1[1] ], dtype: y.data_type, name: 'matmul/identity1')
+
+          grad_a, grad_b = nil
+          if !t_a && !t_b
+            grad_a = tf.matmul(identity_0, y, transpose_b: true)
+            grad_b = tf.matmul(x, identity_1, transpose_a: true)
+          elsif !ta && tb
+            grad_a = tf.matmul(identity_0, y)
+            grad_b = tf.matmul(identity_1, x, transpose_a: true)
+          elsif t_a && !t_b
+            grad_a = tf.matmul(y, identity_0, transpose_b: true)
+            grad_b = tf.matmul(x, identity_1)
+          elsif t_a && t_b
+            grad_a = tf.matmul(y, identity_0, transpose_a: true, transpose_b: true)
+            grad_b = tf.matmul(identity_1, x, transpose_a: true, transpose_b: true)
+          end
+
+          grad_a = i_op(:mul, grad, grad_a, name: 'matmul/grad_a_norm_mul_da')
+          grad_b = i_op(:mul, grad, grad_b, name: 'matmul/grad_b_norm_mul_db')
+
+          [grad_a, grad_b]
+        when :sin
+          grad * tf.cos(x)
+        when :tanh
+          grad * i_op(:tanh_grad, x)
+        when :pow
+          z = node
+          sx = tf.shape(x)
+          sy = tf.shape(y)
+          rx, ry = _broadcast_gradient_args(sx, sy)
+          gx = tf.reshape(
+            tf.reduce_sum(grad * y * tf.pow(x, y - 1), rx), sx)
 
-
-
-          norm_b = i_op(:mul, derivative_b, matmul_db, name: 'grad_b_norm_mul_db')
+          log_x = tf.where(x > 0, tf.log(x), tf.zeros_like(x))
+          gy = tf.reshape(tf.reduce_sum(grad * z * log_x, ry), sy)
 
-
-
-
+          [gx, gy]
+        when :abs
+          grad * tf.sign(x)
+        when :log
+          grad * tf.reciprocal(x)
+        when :tanh
+          i_op(:tanh_grad, x) * grad
+        when :cos
+          -grad * tf.sin(x)
+        when :max
+          x_mask = tf.where(x > y, tf.ones_like(x), tf.zeros_like(y))
+          y_mask = tf.where(x < y, tf.zeros_like(x), tf.ones_like(y))
+          [x_mask * grad, y_mask * grad]
+        when :tan
+          secx = tf.reciprocal(tf.cos(x))
+          secx2 = tf.square(secx)
+          grad * secx2
+        when :negate
+          -grad
+        when :exp
+          grad * node
+        when :identity
+          grad
+        when :sum
+          _sum_grad(x, y, grad)
+        when :reciprocal
+          -grad * (tf.constant(1, dtype: x.dtype) / x**2)
+        when :sqrt
+          tf.constant(1, dtype: x.dtype) / (tf.constant(2, dtype: x.dtype) * tf.sqrt(x)) * grad
+        when :stop_gradient
+          tf.zeros_like(grad)
+        when :square
+          y = tf.constant(2.0, dtype: x.dtype)
+          tf.multiply(grad, tf.multiply(x, y))
+        when :where
+          x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
+          y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
+          [x_mask * grad, y_mask * grad]
+        when :cond
+          x_cond = i_op(:cond, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
+          y_cond = i_op(:cond, i_op(:zeros_like, x), i_op(:ones_like, x), pred: node.options[:pred])
+          [x_cond * grad, y_cond * grad]
+        when :mean
+          sum_grad = _sum_grad(x, y, grad)
+          input_shape = tf.shape(x)
+          output_shape = tf.shape(node)
+          factor = _safe_shape_div(tf.reduce_prod(input_shape), tf.reduce_prod(output_shape))
+          tf.div(sum_grad, tf.cast(factor, sum_grad.data_type))
+        when :log1p
+          grad * tf.reciprocal(i_cons(1, data_type: grad.data_type) + x)
+        when :sigmoid
+          i_op(:sigmoid_grad, x, grad)
+        when :zeros_like
+          # non differentiable
+          nil
         else
-          raise "no derivative
+          raise "no derivative op for #{node.operation}"
         end
-    elsif tensor.is_a?(TensorStream::Variable)
-      i_cons(0, constant_options)
-    elsif tensor.is_a?(TensorStream::Placeholder)
-      i_cons(0, constant_options)
-    else
-      i_cons(0, constant_options)
-    end.tap do |ops|
-      options[:graph].add_node!(gradient_program_name, ops) if options[:graph]
       end
     end
 
-    def self.
-
+    def self._broadcast_gradient_args(input_a, input_b)
+      [_op(:broadcast_gradient_args, input_b, input_a), _op(:broadcast_gradient_args, input_a, input_b)]
+    end
 
-
-
-      tensor.items[0]
-    else
-      tensor
-    end
+    def self._broadcast_transform(input_a, input_b)
+      _op(:broadcast_transform, input_a, input_b)
     end
 
-    def self.
-
-      grad2 = derivative(tensor.items[1], wrt_dx, options)
-      elements1 = i_op(:reduce_prod, i_op(:shape, tensor.items[0]), data_type: :float32)
-      elements2 = i_op(:reduce_prod, i_op(:shape, tensor.items[1]), data_type: :float32)
-      multiplier = elements1 / elements2
-      _reduce_when_necessary(func.call(grad, grad2 * multiplier), wrt_dx)
+    def self._safe_shape_div(x, y)
+      x / tf.maximum(y, 1)
     end
 
-    def self.
-
+    def self._sum_grad(x, y, grad)
+      tf.ones_like(x) * grad
+    end
+
+    def self._op_supports_broadcast?(node)
+      return true if %i[add sub div mul pow].include?(node.operation)
       false
     end
 
-    def self.
-
-
-
-
+    def self._min_or_max_grad(op, grad)
+      y = op
+      indicators = tf.cast(tf.equal(y, op.items[0]), grad.data_type)
+      num_selected = tf.reduce_sum(indicators, op.items[1])
+      _safe_shape_div(indicators, num_selected) * grad
     end
 
-    def self.
-
+    def self._include?(arr, obj)
+      arr.each { |a| return true if a.equal?(obj) }
+      false
     end
 
-    def self.
-
+    def self._shapes_fully_specified_and_equal(x, y)
+      return false if !_shape_full_specified(x) || !_shape_full_specified(y)
+      return false if x.shape.shape != y.shape.shape
+
+      true
     end
 
-
-
-
-
+    def self._shape_full_specified(tensor)
+      return false if tensor.shape.nil?
+      return false if tensor.shape.shape.nil?
+
+      tensor.shape.shape.each { |s| return false if s.nil? }
+      true
     end
   end
-end
+end