tensor_stream 0.1.5 → 0.2.0

Files changed (34)
  1. checksums.yaml +5 -5
  2. data/CHANGELOG.md +13 -0
  3. data/README.md +34 -0
  4. data/lib/tensor_stream.rb +7 -3
  5. data/lib/tensor_stream/control_flow.rb +1 -2
  6. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +44 -3
  7. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +9 -0
  8. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +70 -36
  9. data/lib/tensor_stream/graph.rb +15 -7
  10. data/lib/tensor_stream/graph_serializers/graphml.rb +183 -35
  11. data/lib/tensor_stream/graph_serializers/pbtext.rb +81 -14
  12. data/lib/tensor_stream/graph_serializers/serializer.rb +13 -0
  13. data/lib/tensor_stream/helpers/string_helper.rb +12 -0
  14. data/lib/tensor_stream/math_gradients.rb +203 -161
  15. data/lib/tensor_stream/operation.rb +30 -16
  16. data/lib/tensor_stream/ops.rb +29 -19
  17. data/lib/tensor_stream/placeholder.rb +2 -3
  18. data/lib/tensor_stream/session.rb +7 -13
  19. data/lib/tensor_stream/tensor.rb +22 -5
  20. data/lib/tensor_stream/tensor_shape.rb +2 -0
  21. data/lib/tensor_stream/trainer.rb +6 -1
  22. data/lib/tensor_stream/variable.rb +4 -3
  23. data/lib/tensor_stream/version.rb +1 -1
  24. data/samples/gradient_sample.graphml +1255 -0
  25. data/samples/linear_regression.rb +1 -1
  26. data/samples/logistic_regression.rb +9 -2
  27. data/tensor_stream.gemspec +1 -1
  28. data/test_samples/error.graphml +120 -0
  29. data/test_samples/gradient_sample.graphml +1255 -0
  30. data/{samples → test_samples}/iris.rb +0 -0
  31. data/{samples → test_samples}/raw_neural_net_sample.rb +0 -0
  32. data/{samples → test_samples}/test.py +2 -0
  33. data/test_samples/test2.py +41 -0
  34. metadata +41 -47
data/lib/tensor_stream/graph_serializers/graphml.rb
@@ -1,91 +1,239 @@
  module TensorStream
- class Graphml
+ class Graphml < Serializer
  def initialize
  end

- def serialize(session, tensor, filename)
+ def get_string(tensor, session = nil)
+ tensor = TensorStream.convert_to_tensor(tensor) unless tensor.is_a?(Tensor)
  @session = session
- @last_session_context = session.last_session_context
+ @name = tensor.name
+ @last_session_context = session ? session.last_session_context : {}
+ groups = {}

  arr_buf = []
  arr_buf << '<?xml version="1.0" encoding="UTF-8"?>'
- arr_buf << '<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ arr_buf << '<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:y="http://www.yworks.com/xml/graphml"
  xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">'
  arr_buf << '<key id="d0" for="node" attr.name="label" attr.type="string"/>'
  arr_buf << '<key id="d1" for="node" attr.name="formula" attr.type="string"/>'
  arr_buf << '<key id="d2" for="node" attr.name="color" attr.type="string"/>'
  arr_buf << '<key id="d3" for="node" attr.name="value" attr.type="string"/>'
+ arr_buf << '<key attr.name="description" attr.type="string" for="edge" id="d12"/>'
+ arr_buf << '<key for="edge" id="d13" yfiles.type="edgegraphics"/>'
+ arr_buf << '<key for="node" id="d9" yfiles.type="nodegraphics"/>'
  arr_buf << "<graph id=\"g_#{_gml_string(tensor.name)}\" edgedefault=\"directed\">"
  arr_buf << "<node id=\"out\">"
  arr_buf << "<data key=\"d0\">out</data>"
  arr_buf << "<data key=\"d2\">red</data>"
+ arr_buf << "<data key=\"d9\">"
+ arr_buf << "<y:ShapeNode>"
+ arr_buf << " <y:Fill color=\"#FF0000\" transparent=\"false\"/>"
+ arr_buf << " <y:NodeLabel alignment=\"center\">out</y:NodeLabel>"
+ arr_buf << "</y:ShapeNode>"
+ arr_buf << "</data>"
  arr_buf << "</node>"
- to_graph_ml(tensor, arr_buf)
- arr_buf << "<edge source=\"#{_gml_string(tensor.name)}\" target=\"out\"/>"
+
+ to_graph_ml(tensor, arr_buf, {}, groups)
+ # dump groups
+ groups.each do |k, g|
+ arr_buf << create_group(k, k, g)
+ end
+
+ output_edge(tensor, "out", arr_buf)
  arr_buf << "</graph>"
  arr_buf << "</graphml>"
- File.write(filename, arr_buf.join("\n"))
+ arr_buf.flatten.join("\n")
  end

  private

+ def add_to_group(groups, name, arr_buf)
+ name_parts = name.split('/')
+ return false if name_parts.size < 2
+
+ prefix = name_parts.shift
+
+ ptr = find_or_create_group(prefix, groups)
+
+ Kernel.loop do
+ next_group = ptr[:group]
+ ptr = find_or_create_group(prefix, next_group)
+ break if name_parts.size < 2
+ prefix = name_parts.shift
+ end
+
+ ptr[:buf] << arr_buf
+ true
+ end
+
+ def find_or_create_group(prefix, groups)
+ if !groups[prefix]
+ groups[prefix] = { buf: [], group: {} }
+ end
+
+ return groups[prefix]
+ end
+
+ def create_group(id, title, group)
+ arr_buf = []
+ arr_buf << "<node id=\"#{id}\" yfiles.foldertype=\"group\">"
+ arr_buf << '<data key="d9">'
+ arr_buf << '<y:ProxyAutoBoundsNode>'
+ arr_buf << '<y:Realizers active="0">'
+ arr_buf << '<y:GroupNode>'
+ arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
+ arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
+ arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+ arr_buf << '<y:Shape type="roundrectangle"/>'
+ arr_buf << '</y:GroupNode>'
+ arr_buf << '</y:Realizers>'
+ arr_buf << '</y:ProxyAutoBoundsNode>'
+ arr_buf << '</data>'
+ arr_buf << '<graph edgedefault="directed" id="n105:">'
+ arr_buf << group[:buf]
+ group[:group].each do |k, g|
+ arr_buf << create_group(k, k, g)
+ end
+ arr_buf << '</graph>'
+ arr_buf << '</node>'
+ arr_buf
+ end
+
  def _val(tensor)
- JSON.pretty_generate(@last_session_context[tensor.name])
+ # JSON.pretty_generate(@last_session_context[tensor.name])
+ @last_session_context[tensor.name]
  end

- def to_graph_ml(tensor, arr_buf = [], added = {}, _id = 0)
+ def to_graph_ml(tensor, arr_buf = [], added = {}, groups = {}, _id = 0)
  puts tensor.name
+ return unless tensor.is_a?(Operation)
+
  added[tensor.name] = true
- arr_buf << "<node id=\"#{_gml_string(tensor.name)}\">"
- arr_buf << "<data key=\"d0\">#{tensor.operation}</data>"
- arr_buf << "<data key=\"d1\">#{tensor.to_math(true, 1)}</data>"
- arr_buf << "<data key=\"d2\">blue</data>"
+ node_buf = []
+ node_buf << "<node id=\"#{_gml_string(tensor.name)}\">"
+ node_buf << "<data key=\"d0\">#{tensor.operation}</data>"
+ node_buf << "<data key=\"d1\">#{tensor.to_math(true, 1)}</data>"
+ node_buf << "<data key=\"d2\">blue</data>"
+
  if @last_session_context[tensor.name]
  arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
  end
- arr_buf << "</node>"
+ node_buf << "<data key=\"d9\">"
+ node_buf << "<y:ShapeNode>"
+ if tensor.internal?
+ node_buf << " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
+ else
+ node_buf << " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
+ end
+ node_buf << " <y:NodeLabel alignment=\"center\">#{tensor.operation}</y:NodeLabel>"
+ node_buf << "</y:ShapeNode>"
+ node_buf << "</data>"
+ node_buf << "</node>"
+
+ if !add_to_group(groups, tensor.name, node_buf)
+ add_to_group(groups, "program/#{tensor.name}", node_buf)
+ end

  tensor.items.each do |item|
  next unless item
- next if _added[item.name]
+ next if added[item.name]
+
+ next to_graph_ml(item, arr_buf, added, groups) if item.is_a?(Operation)

- next to_graph_ml(item, arr_buf, added) if item.is_a?(Operation)
  added[item.name] = true
+ item_buf = []
  if item.is_a?(Variable)
- arr_buf << "<node id=\"#{_gml_string(item.name)}\">"
- arr_buf << "<data key=\"d0\">#{item.name}</data>"
- arr_buf << "<data key=\"d2\">green</data>"
+ item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+ item_buf << "<data key=\"d0\">#{item.name}</data>"
+ item_buf << "<data key=\"d2\">green</data>"
  if @last_session_context[item.name]
- arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+ item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
  end
- arr_buf << "</node>"
+ item_buf << "<data key=\"d9\">"
+ item_buf << "<y:ShapeNode>"
+ item_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
+ item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+ item_buf << "</y:ShapeNode>"
+ item_buf << "</data>"
+ item_buf << "</node>"
  elsif item.is_a?(Placeholder)
- arr_buf << "<node id=\"#{_gml_string(item.name)}\">"
- arr_buf << "<data key=\"d0\">#{item.name}</data>"
- arr_buf << "<data key=\"d2\">yellow</data>"
+ item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+ item_buf << "<data key=\"d9\">"
+ item_buf << "<y:ShapeNode>"
+ item_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
+ item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+ item_buf << "</y:ShapeNode>"
+ item_buf << "</data>"
  if @last_session_context[item.name]
- arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+ item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
  end
- arr_buf << "</node>"
- else
- arr_buf << "<node id=\"#{_gml_string(item.name)}\">"
- arr_buf << "<data key=\"d0\">#{item.name}</data>"
- arr_buf << "<data key=\"d2\">black</data>"
- if @last_session_context[item.name]
- arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+ item_buf << "</node>"
+ elsif item.is_a?(Tensor)
+ item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+ item_buf << "<data key=\"d0\">#{item.name}</data>"
+ item_buf << "<data key=\"d2\">black</data>"
+ item_buf << "<data key=\"d9\">"
+ item_buf << "<y:ShapeNode>"
+
+ if item.internal?
+ item_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+ else
+ item_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+ end
+
+
+ item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+
+ item_buf << "</y:ShapeNode>"
+ item_buf << "</data>"
+ item_buf << "</node>"
+ end
+
+ if !add_to_group(groups, item.name, item_buf)
+ if item.is_a?(Variable)
+ add_to_group(groups, "variable/#{item.name}", item_buf)
+ else
+ add_to_group(groups, "program/#{item.name}", item_buf)
  end
- arr_buf << "</node>"
  end
  end

- tensor.items.each do |item|
+ tensor.items.each_with_index do |item, index|
  next unless item
- arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(tensor.name)}\"/>"
+ output_edge(item, tensor, arr_buf, index)
  end
  end

  def _gml_string(str)
  str.gsub('/','-')
  end
+
+ def output_edge(item, tensor, arr_buf, index = 0)
+ target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
+ arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(target_name)}\">"
+ arr_buf << "<data key=\"d13\">"
+
+ arr_buf << "<y:PolyLineEdge>"
+ arr_buf << "<y:EdgeLabel >"
+ if !@last_session_context.empty?
+ arr_buf << "<![CDATA[ #{_val(item)} ]]>"
+ else
+ if item.shape.shape.nil?
+ arr_buf << "<![CDATA[ #{item.data_type.to_s} ? ]]>"
+ else
+ arr_buf << "<![CDATA[ #{item.data_type.to_s} #{item.shape.shape.empty? ? 'scalar' : item.shape.shape.to_json} ]]>"
+ end
+ end
+ arr_buf << "</y:EdgeLabel >"
+ arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
+ if index == 0
+ arr_buf << "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+ else
+ arr_buf << "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+ end
+ arr_buf << "</y:PolyLineEdge>"
+ arr_buf << "</data>"
+ arr_buf << "</edge>"
+ end
  end
  end
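Note: `Graphml` now extends `Serializer`, so it returns the GraphML document as a string (`get_string`) and inherits file output (`serialize`) instead of writing the file itself. A minimal usage sketch, assuming the gem's TensorFlow-style top-level helpers (`constant`, `matmul`, `session`); names like `sess` are illustrative:

    require 'tensor_stream'

    ts = TensorStream
    a = ts.constant([[1.0, 2.0], [3.0, 4.0]])
    b = ts.constant([[0.5, 0.5], [1.0, 1.0]])
    c = ts.matmul(a, b)

    serializer = TensorStream::Graphml.new
    xml = serializer.get_string(c)        # GraphML string with yEd/yWorks styling keys
    serializer.serialize('c.graphml', c)  # inherited: File.write(filename, get_string(...))

    # Passing a session after a run labels edges with evaluated values:
    sess = ts.session
    sess.run(c)
    puts serializer.get_string(c, sess)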
data/lib/tensor_stream/graph_serializers/pbtext.rb
@@ -1,54 +1,121 @@
  module TensorStream
- class Pbtext
- def initialize
- end
-
- def serialize(session, filename, tensor)
- end
+ class Pbtext < TensorStream::Serializer
+ include TensorStream::StringHelper
+ include TensorStream::OpHelper

- def get_string(graph)
+ def get_string(tensor_or_graph, session = nil)
+ graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
  @lines = []
  graph.nodes.each do |k, node|
  @lines << "node {"
  @lines << " name: #{node.name.to_json}"
  if node.is_a?(TensorStream::Operation)
- @lines << " op: #{node.operation.to_json}"
+ @lines << " op: #{camelize(node.operation.to_s).to_json}"
  node.items.each do |input|
  next unless input
  @lines << " input: #{input.name.to_json}"
  end
  # type
- pb_attr('T', sym_to_protobuf_type(node.data_type))
+ pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
+ process_options(node)
  elsif node.is_a?(TensorStream::Tensor) && node.is_const
  @lines << " op: \"Const\""
  # type
- pb_attr('T', sym_to_protobuf_type(node.data_type))
+ pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
  pb_attr('value', tensor_value(node))
+ elsif node.is_a?(TensorStream::Variable)
+ @lines << " op: \"VariableV2\""
+ pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
+ pb_attr('shape', shape_buf(node, 'shape'))
+ process_options(node)
  end
  @lines << "}"
  end
- @lines.join("\n")
+ @lines << "versions {"
+ @lines << " producer: 26"
+ @lines << "}"
+ @lines.flatten.join("\n")
  end

  private

+ def process_options(node)
+ node.options.each do |k, v|
+ next if %w[name].include?(k.to_s)
+ @lines << " attr {"
+ @lines << " key: \"#{k}\""
+ @lines << " value {"
+ @lines << " }"
+ @lines << " }"
+ end
+ end
+
+ def pack_arr_float(float_arr)
+ float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ end
+
+ def pack_arr_int(int_arr)
+ int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ end
+
+ def shape_buf(tensor, shape_type = 'tensor_shape')
+ arr = []
+ arr << " #{shape_type} {"
+ tensor.shape.shape.each do |dim|
+ arr << " dim {"
+ arr << " size: #{dim}"
+ arr << " }"
+ end if tensor.shape.shape
+ arr << " }"
+ arr
+ end
  def tensor_value(tensor)
  arr = []
  arr << "tensor {"
  arr << " dtype: #{sym_to_protobuf_type(tensor.data_type)}"
- arr << " float_val: #{tensor.value}"
+
+ arr += shape_buf(tensor)
+
+ if tensor.rank > 0
+ if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+ packed = pack_arr_float(tensor.value)
+ arr << " tensor_content: \"#{packed}\""
+ elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+ packed = pack_arr_int(tensor.value)
+ arr << " tensor_content: \"#{packed}\""
+ elsif tensor.data_type == :string
+ tensor.value.each do |v|
+ arr << " string_val: #{v.to_json}"
+ end
+ else
+ arr << " tensor_content: #{tensor.value.flatten}"
+ end
+ else
+ val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+ "int_val"
+ elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+ "float_val"
+ elsif tensor.data_type == :string
+ "string_val"
+ else
+ "val"
+ end
+ arr << " #{val_type}: #{tensor.value.to_json}"
+ end
  arr << "}"
  arr
  end

  def sym_to_protobuf_type(type)
  case type
- when :int32
+ when :int32, :int
  "DT_INT32"
  when :float, :float32
  "DT_FLOAT"
+ when :string
+ "DT_STRING"
  else
- "DT_UNKNOWN"
118
+ "UKNOWN"
  end
  end

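Note: `Pbtext#get_string` now accepts either a `Tensor` or a `Graph` and emits TensorFlow-compatible pbtxt, camelizing op symbols (e.g. `:reduce_sum` becomes `"ReduceSum"`) and packing const values into `tensor_content`. A rough sketch of the new entry points (graph construction as in the example above):

    pbtext = TensorStream::Pbtext.new
    puts pbtext.get_string(c)           # resolves c.graph internally
    puts pbtext.get_string(c.graph)     # or pass the graph directly
    pbtext.serialize('model.pbtxt', c)  # inherited from TensorStream::Serializer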
data/lib/tensor_stream/graph_serializers/serializer.rb
@@ -0,0 +1,13 @@
+ module TensorStream
+ class Serializer
+ def initialize
+ end
+
+ def serialize(filename, tensor, session = nil)
+ File.write(filename, get_string(tensor, session))
+ end
+
+ def get_string(tensor, session = nil)
+ end
+ end
+ end
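Note: the new base class is a small template method: `serialize` writes whatever `get_string` returns, so subclasses only implement `get_string`. A hypothetical subclass to illustrate the contract (`DotSerializer` is not part of the gem):

    # Hypothetical example, not shipped with tensor_stream:
    class DotSerializer < TensorStream::Serializer
      def get_string(tensor, session = nil)
        "digraph { /* nodes for #{tensor.name} would go here */ }"
      end
    end

    DotSerializer.new.serialize('graph.dot', c)  # file output comes from the base class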
data/lib/tensor_stream/helpers/string_helper.rb
@@ -0,0 +1,12 @@
+ module TensorStream
+ module StringHelper
+ def camelize(string, uppercase_first_letter = true)
+ string = if uppercase_first_letter
+ string.sub(/^[a-z\d]*/) { $&.capitalize }
+ else
+ string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
+ end
+ string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
+ end
+ end
+ end
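Note: `camelize` mirrors ActiveSupport's implementation; `Pbtext` uses it to map snake_case op symbols to TensorFlow-style op names. Expected behavior, inferred from the regexes above:

    include TensorStream::StringHelper

    camelize('reduce_sum')        # => "ReduceSum"
    camelize('matmul')            # => "Matmul"
    camelize('nn/softmax')        # => "Nn::Softmax"
    camelize('reduce_sum', false) # => "reduceSum"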
data/lib/tensor_stream/math_gradients.rb
@@ -3,201 +3,243 @@ module TensorStream
  class MathGradients
  extend TensorStream::OpHelper

+ def self.tf
+ TensorStream
+ end
+
  def self.derivative(tensor, wrt_dx, options = {})
- gradient_program_name = "_grad_#{tensor.name}_#{wrt_dx.name}"
-
- return options[:graph].get_node(gradient_program_name) if options[:graph] && options[:graph].node_added?(gradient_program_name)
+ return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
+ return i_op(:zeros_like, tensor) unless wrt_dx.consumers.include?(tensor.name)

- constant_options = { dtype: options[:dtype] }
- constant_options_1 = { dtype: options[:dtype] || tensor.data_type }
+ nodes_to_compute = wrt_dx.consumers.select do |t|
+ node = tensor.graph.nodes[t]
+ node.consumers.include?(tensor.name) || node.equal?(tensor)
+ end.compact + [wrt_dx.name]

- return i_op(:ones_like, wrt_dx, constant_options_1) if tensor.equal?(wrt_dx)
- return i_cons(0, constant_options) if options[:stop_gradients] && _include?(options[:stop_gradients], tensor)
+ grad = i_op(:ones_like, wrt_dx)

- if tensor.is_a?(Operation)
- grad = derivative(tensor.items[0], wrt_dx, options) if tensor.items[0]
- grad2 = derivative(tensor.items[1], wrt_dx, options) if tensor.items[1]
+ result = _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || [])
+ i_op(:truncate, result, tf.shape(wrt_dx))
+ end

- case tensor.operation
- when :zeros_like
- i_cons(0, constant_options)
- when :log1p
- grad * _op(:reciprocal, i_cons(1, constant_options_1) + tensor.items[0])
- when :max
- x_mask = i_op(:where, i_op(:ones_like, tensor.items[0]), i_op(:zeros_like, tensor.items[1]), pred: tensor.items[0] > tensor.items[1])
- y_mask = i_op(:where, i_op(:zeros_like, tensor.items[0]), i_op(:ones_like, tensor.items[1]), pred: tensor.items[0] < tensor.items[1])
- x_mask * grad + y_mask * grad2
- when :where
- x_mask = i_op(:where, i_op(:ones_like, tensor.items[0]), i_op(:zeros_like, tensor.items[1]), pred: tensor.options[:pred])
- y_mask = i_op(:where, i_op(:zeros_like, tensor.items[0]), i_op(:ones_like, tensor.items[1]), pred: tensor.options[:pred])
- x_mask * grad + y_mask * grad2
- when :cond
- i_op(:cond, grad, grad2, pred: tensor.options[:pred])
- when :identity, :print, :pad
- grad
- when :negate
- i_cons(-1, constant_options_1) * grad
- when :abs
- grad * i_op(:sign, _ds(tensor.items[0]))
- when :square
- i_cons(2, constant_options_1) * _ds(tensor.items[0]) * grad
- when :exp
- i_op(:exp, tensor.items[0]) * grad
- when :log
- (i_cons(1, constant_options_1) / _ds(tensor.items[0])) * grad
- when :tanh
- i_op(:mul, (i_cons(1, constant_options_1) - (i_op(:tanh, _ds(tensor.items[0]))**2)), grad, name: 'grad_tanh')
- when :tan
- (i_cons(1, constant_options_1) / (i_op(:cos, _ds(tensor.items[0]))**2)) * grad
- when :sin
- i_op(:mul, i_op(:cos, tensor.items[0]), grad, name: 'grad_sin')
- when :sqrt
- i_cons(1, constant_options_1) / (i_cons(2, constant_options_1) * i_op(:sqrt, _ds(tensor.items[0]))) * grad
- when :cos
- -i_op(:sin, tensor.items[0]) * grad
- when :add
- # rx = _op(:shape, tensor.items[0])
- # ry = _op(:shape, tensor.items[1])
+ def self._propagate(grad, tensor, stop_tensor, nodes_to_compute, stop_gradients = [])
+ return grad * i_op(:ones_like, stop_tensor) if stop_tensor.equal?(tensor)
+ return i_op(:zeros_like, stop_tensor) if stop_gradients && _include?(stop_gradients, tensor)
+ return i_op(:zeros_like, stop_tensor) unless tensor.is_a?(Operation)
+
+ computed_op = if _op_supports_broadcast?(tensor)
+ _compute_derivative(tensor, _broadcast_transform(tensor, grad)[1])
+ else
+ _compute_derivative(tensor, grad)
+ end
+
+ if computed_op.is_a?(Array)
+ partials = []
+ computed_op.each_with_index do |op_grad, index|
+ next if op_grad.nil?
+
+ if nodes_to_compute.include?(tensor.items[index].name)
+ partials << _propagate(op_grad, tensor.items[index], stop_tensor, nodes_to_compute, stop_gradients)
+ end
+ end

- # ones_a = _op(:ones_like, tensor.items[0])
- # ones_b = _op(:ones_like, tensor.items[1])
- # inputs = _broadcast_transform(grad * ones_a, grad2 * ones_b)
- # sx, sy = _broadcast_gradient_args(rx, ry)
+ partials.reduce(:+)
+ else
+ return tf.zeros_like(stop_tensor) if computed_op.nil?
+ _propagate(computed_op, tensor.items[0], stop_tensor, nodes_to_compute, stop_gradients)
+ end
+ end

- # keep_dims_x = _op(:rank, inputs[0]) == _op(:rank, tensor.items[0])
- # keep_dims_y = _op(:rank, inputs[1]) == _op(:rank, tensor.items[1])
+ def self._compute_derivative(node, grad)
+ node.graph.name_scope("#{node.name}_grad") do
+ x = node.items[0] if node.items[0]
+ y = node.items[1] if node.items[1]

- # add_x = _op(:reduce_sum, inputs[0], nil, axis: sy, keepdims: keep_dims_x)
- # add_y = _op(:reduce_sum, inputs[1], nil, axis: sx, keepdims: keep_dims_y)
- # _filtered_sum(add_x, add_y, wrt_dx)
- _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:add, a, b, name: 'grad_add') }, options)
- when :sub
- _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:sub, a, b, name: 'grad_sub') }, options)
- when :pow
- gx = _ds(tensor.items[1]) * (_ds(tensor.items[0])**(_ds(tensor.items[1]) - 1)) * grad
+ case node.operation
+ when :add
+ return [grad, grad] if _shapes_fully_specified_and_equal(x, y)

- log_x = i_op(:where, i_op(:log, tensor.items[0], nil, name: 'log_pow_grad'), i_op(:zeros_like, tensor.items[0]), pred: tensor.items[0] > 0)
- gy = _ds(tensor.items[0])**_ds(tensor.items[1]) * log_x * grad2
+ sx = tf.shape(x, name: 'add/shape_x')
+ sy = tf.shape(y, name: 'add/shape_y')
+ rx, ry = _broadcast_gradient_args(sx, sy)
+ keep_dims_x = tf.rank(x) == tf.rank(grad)
+ keep_dims_y = tf.rank(y) == tf.rank(grad)

- gx + gy
- when :div
- # apply the quotient rule
- gx = i_op(:div, grad, _ds(tensor.items[1]))
- gy = grad2 * i_op(:div, i_op(:div, -_ds(tensor.items[0]), _ds(tensor.items[1])), _ds(tensor.items[1]))
+ [tf.reduce_sum(grad, rx, name: 'add/reduce_sum_x', keepdims: keep_dims_x),
+ tf.reduce_sum(grad, ry, name: 'add/reduce_sum_y', keepdims: keep_dims_y)]
+ when :sub
+ return [grad, -grad] if _shapes_fully_specified_and_equal(x, y)

- _reduce_when_necessary(gx + gy, wrt_dx)
+ sx = tf.shape(x, name: 'sub/shape_x')
+ sy = tf.shape(y, name: 'sub/shape_y')
+ rx, ry = _broadcast_gradient_args(sx, sy)
+ [tf.reduce_sum(grad, rx), -tf.reduce_sum(grad, ry)]
  when :mul
- # apply the product rule
- rx = _op(:shape, tensor.items[0])
- ry = _op(:shape, tensor.items[1])
- sx, sy = _broadcast_gradient_args(rx, ry)
- inputs = _broadcast_transform(tensor.items[0], tensor.items[1])
- keep_dims_x = _op(:rank, inputs[0]) == _op(:rank, tensor.items[0])
- keep_dims_y = _op(:rank, inputs[1]) == _op(:rank, tensor.items[1])
-
- _filtered_sum(_op(:reduce_sum, grad * _ds(inputs[1]), nil, axis: sy, keepdims: keep_dims_x),
- _op(:reduce_sum, _ds(inputs[0]) * grad2, nil, axis: sx, keepdims: keep_dims_y), wrt_dx)
- when :reduce_mean
- input_size = i_op(:reduce_prod, i_op(:shape, tensor.items[0]))
- output_size = i_op(:reduce_prod, i_op(:shape, tensor))
- factor = input_size / output_size
-
- (grad / i_op(:cast, factor, data_type: grad.dtype))
- when :reduce_sum
- grad
- when :reciprocal
- -grad * (i_cons(1, constant_options_1) / _ds(tensor.items[0])**2)
- when :stop_gradient
- return i_cons(0, constant_options)
- when :matmul
- derivative_a = derivative(tensor.items[0], wrt_dx)
- derivative_b = derivative(tensor.items[1], wrt_dx)
-
- s0 = i_op(:shape, tensor.items[0])
- s1 = i_op(:shape, tensor.items[1])
+ sx = tf.shape(x)
+ sy = tf.shape(y)
+ rx, ry = _broadcast_gradient_args(sx, sy)

- identity_0 = i_op(:ones, [s0[0], s1[1]], nil, data_type: tensor.items[0].data_type)
- identity_1 = i_op(:ones, [s0[0], s1[1]], nil, data_type: tensor.items[1].data_type)
-
- matmul_da = i_op(:matmul, identity_0, tensor.items[1], transpose_b: true,
- pad_zeros: true,
- name: 'matrix_dx')
- matmul_db = i_op(:matmul, tensor.items[0], identity_1, transpose_a: true,
- pad_zeros: true,
- name: 'matrix_dy')
- # matmul_db = _op(:transpose, matmul_db, nil).first
-
- # begin_a = _op(:zeros, _op(:rank, matmul_db), nil, data_type: :int32, name: 'begin_a')
- # matmul_b_shape = _op(:shape, matmul_db)
- # end_a = [matmul_b_shape[0], 1]
+ [ tf.reduce_sum(tf.mul(grad, y), rx),
+ tf.reduce_sum(tf.mul(x, grad), ry)]
+ when :div
+ sx = i_op(:shape, x)
+ sy = i_op(:shape, y)
+ rx, ry = _broadcast_gradient_args(sx, sy)

- matmul_da = i_op(:cond, matmul_da[0], matmul_da, pred: _op(:rank, derivative_a) > 0)
+ [tf.reduce_sum(tf.div(grad, y), rx),
+ tf.reduce_sum(grad * tf.div(tf.div(-x, y), y),
+ ry)]
+ when :matmul
+ t_a = node.options[:transpose_a]
+ t_b = node.options[:transpose_b]
+
+ s0 = tf.shape(x)
+ s1 = tf.shape(y)
+
+ identity_0 = tf.ones([ s0[0], s1[1] ], dtype: x.data_type, name: 'matmul/identity0')
+ identity_1 = tf.ones([ s0[0], s1[1] ], dtype: y.data_type, name: 'matmul/identity1')
+
+ grad_a, grad_b = nil
+ if !t_a && !t_b
+ grad_a = tf.matmul(identity_0, y, transpose_b: true)
+ grad_b = tf.matmul(x, identity_1, transpose_a: true)
+ elsif !t_a && t_b
+ grad_a = tf.matmul(identity_0, y)
+ grad_b = tf.matmul(identity_1, x, transpose_a: true)
+ elsif t_a && !t_b
+ grad_a = tf.matmul(y, identity_0, transpose_b: true)
+ grad_b = tf.matmul(x, identity_1)
+ elsif t_a && t_b
+ grad_a = tf.matmul(y, identity_0, transpose_a: true, transpose_b: true)
+ grad_b = tf.matmul(identity_1, x, transpose_a: true, transpose_b: true)
+ end
+
+ grad_a = i_op(:mul, grad, grad_a, name: 'matmul/grad_a_norm_mul_da')
+ grad_b = i_op(:mul, grad, grad_b, name: 'matmul/grad_b_norm_mul_db')
+
+ [grad_a, grad_b]
+ when :sin
+ grad * tf.cos(x)
+ when :tanh
+ grad * i_op(:tanh_grad, x)
+ when :pow
+ z = node
+ sx = tf.shape(x)
+ sy = tf.shape(y)
+ rx, ry = _broadcast_gradient_args(sx, sy)
+ gx = tf.reshape(
+ tf.reduce_sum(grad * y * tf.pow(x, y - 1), rx), sx)

- # matmul_da = _op(:cond, matmul_da[0], matmul_da, pred: _op(:rank, derivative_a) > 0)
- norm_a = i_op(:mul, derivative_a, matmul_da, name: 'grad_a_norm_mul_da')
- norm_b = i_op(:mul, derivative_b, matmul_db, name: 'grad_b_norm_mul_db')
+ log_x = tf.where(x > 0, tf.log(x), tf.zeros_like(x))
+ gy = tf.reshape(tf.reduce_sum(grad * z * log_x, ry), sy)

- # norm_a = i_op(:cond, norm_a[0], norm_a, pred: i_op(:rank, matmul_da) > i_op(:rank, derivative_a))
- # norm_b = i_op(:cond, norm_b[0], norm_b, pred: i_op(:rank, matmul_db) > i_op(:rank, derivative_b))
- _filtered_sum(norm_a, norm_b, wrt_dx)
+ [gx, gy]
+ when :abs
+ grad * tf.sign(x)
+ when :log
+ grad * tf.reciprocal(x)
+ when :tanh
+ i_op(:tanh_grad, x) * grad
+ when :cos
+ -grad * tf.sin(x)
+ when :max
+ x_mask = tf.where(x > y, tf.ones_like(x), tf.zeros_like(y))
+ y_mask = tf.where(x < y, tf.zeros_like(x), tf.ones_like(y))
+ [x_mask * grad, y_mask * grad]
+ when :tan
+ secx = tf.reciprocal(tf.cos(x))
+ secx2 = tf.square(secx)
+ grad * secx2
+ when :negate
+ -grad
+ when :exp
+ grad * node
+ when :identity
+ grad
+ when :sum
+ _sum_grad(x, y, grad)
+ when :reciprocal
+ -grad * (tf.constant(1, dtype: x.dtype) / x**2)
+ when :sqrt
+ tf.constant(1, dtype: x.dtype) / (tf.constant(2, dtype: x.dtype) * tf.sqrt(x)) * grad
+ when :stop_gradient
+ tf.zeros_like(grad)
+ when :square
+ y = tf.constant(2.0, dtype: x.dtype)
+ tf.multiply(grad, tf.multiply(x, y))
+ when :where
+ x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
+ y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
+ [x_mask * grad, y_mask * grad]
+ when :cond
+ x_cond = i_op(:cond, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
+ y_cond = i_op(:cond, i_op(:zeros_like, x), i_op(:ones_like, x), pred: node.options[:pred])
+ [x_cond * grad, y_cond * grad]
+ when :mean
+ sum_grad = _sum_grad(x, y, grad)
+ input_shape = tf.shape(x)
+ output_shape = tf.shape(node)
+ factor = _safe_shape_div(tf.reduce_prod(input_shape), tf.reduce_prod(output_shape))
+ tf.div(sum_grad, tf.cast(factor, sum_grad.data_type))
+ when :log1p
+ grad * tf.reciprocal(i_cons(1, data_type: grad.data_type) + x)
+ when :sigmoid
+ i_op(:sigmoid_grad, x, grad)
+ when :zeros_like
+ # non differentiable
+ nil
  else
- raise "no derivative implementation found for op #{tensor.operation}"
+ raise "no derivative op for #{node.operation}"
  end
- elsif tensor.is_a?(TensorStream::Variable)
- i_cons(0, constant_options)
- elsif tensor.is_a?(TensorStream::Placeholder)
- i_cons(0, constant_options)
- else
- i_cons(0, constant_options)
- end.tap do |ops|
- options[:graph].add_node!(gradient_program_name, ops) if options[:graph]
  end
  end

- def self._ds(tensor)
- return tensor unless tensor.is_a?(Operation)
+ def self._broadcast_gradient_args(input_a, input_b)
+ [_op(:broadcast_gradient_args, input_b, input_a), _op(:broadcast_gradient_args, input_a, input_b)]
+ end

- case tensor.operation
- when :reduce_sum
- tensor.items[0]
- else
- tensor
- end
+ def self._broadcast_transform(input_a, input_b)
+ _op(:broadcast_transform, input_a, input_b)
  end

- def self._grad_with_broadcast(tensor, wrt_dx, func, options)
- grad = derivative(tensor.items[0], wrt_dx, options)
- grad2 = derivative(tensor.items[1], wrt_dx, options)
- elements1 = i_op(:reduce_prod, i_op(:shape, tensor.items[0]), data_type: :float32)
- elements2 = i_op(:reduce_prod, i_op(:shape, tensor.items[1]), data_type: :float32)
- multiplier = elements1 / elements2
- _reduce_when_necessary(func.call(grad, grad2 * multiplier), wrt_dx)
+ def self._safe_shape_div(x, y)
+ x / tf.maximum(y, 1)
  end

- def self._include?(arr, obj)
- arr.each { |a| return true if a.equal?(obj) }
+ def self._sum_grad(x, y, grad)
+ tf.ones_like(x) * grad
+ end
+
+ def self._op_supports_broadcast?(node)
+ return true if %i[add sub div mul pow].include?(node.operation)
  false
  end

- def self._reduce_when_necessary(tensor, wrt_dx)
- rank = _op(:rank, tensor)
- dx_rank = _op(:rank, wrt_dx)
- reduced = _op(:reduce_sum, tensor, nil, axis: 0)
- _op(:cond, ->{ reduced }, tensor, pred: rank > dx_rank)
+ def self._min_or_max_grad(op, grad)
+ y = op
+ indicators = tf.cast(tf.equal(y, op.items[0]), grad.data_type)
+ num_selected = tf.reduce_sum(indicators, op.items[1])
+ _safe_shape_div(indicators, num_selected) * grad
  end

- def self._broadcast_gradient_args(input_a, input_b)
- [_op(:broadcast_gradient_args, input_a, input_b), _op(:broadcast_gradient_args, input_b, input_a)]
+ def self._include?(arr, obj)
+ arr.each { |a| return true if a.equal?(obj) }
+ false
  end

- def self._broadcast_transform(input_a, input_b)
- _op(:broadcast_transform, input_a, input_b)
+ def self._shapes_fully_specified_and_equal(x, y)
+ return false if !_shape_full_specified(x) || !_shape_full_specified(y)
+ return false if x.shape.shape != y.shape.shape
+
+ true
  end

- # filter out zero arrays
- def self._filtered_sum(input_a, input_b, wrt_dx)
- zero_vect = _op(:zeros_like, wrt_dx)
- (i_op(:cond, input_a, zero_vect, pred: i_op(:reduce_sum, input_a) != 0) + i_op(:cond, input_b, zero_vect, pred: i_op(:reduce_sum, input_b) != 0))
+ def self._shape_full_specified(tensor)
+ return false if tensor.shape.nil?
+ return false if tensor.shape.shape.nil?
+
+ tensor.shape.shape.each { |s| return false if s.nil? }
+ true
  end
  end
- end
+ end
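Note: the gradient rewrite replaces the memoized per-pair "gradient programs" with reverse-mode propagation: `derivative` seeds a ones-like gradient at the output, `_propagate` walks only nodes that lie on a consumer path between `wrt_dx` and the target (summing partials for two-input ops), and `_compute_derivative` supplies the per-op local gradients inside a `"#{node.name}_grad"` name scope. A hedged sketch of driving it directly, assuming the gem's TensorFlow-style variable/session helpers:

    x = ts.variable(3.0, dtype: :float32, name: 'x')
    f = x * x + ts.sin(x)

    # df/dx = 2x + cos(x)
    df = TensorStream::MathGradients.derivative(f, x)

    sess = ts.session
    sess.run(ts.global_variables_initializer)
    sess.run(df)  # => ~5.01 (2 * 3.0 + cos(3.0))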