tensor_stream 0.1.5 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. checksums.yaml +5 -5
  2. data/CHANGELOG.md +13 -0
  3. data/README.md +34 -0
  4. data/lib/tensor_stream.rb +7 -3
  5. data/lib/tensor_stream/control_flow.rb +1 -2
  6. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +44 -3
  7. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +9 -0
  8. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +70 -36
  9. data/lib/tensor_stream/graph.rb +15 -7
  10. data/lib/tensor_stream/graph_serializers/graphml.rb +183 -35
  11. data/lib/tensor_stream/graph_serializers/pbtext.rb +81 -14
  12. data/lib/tensor_stream/graph_serializers/serializer.rb +13 -0
  13. data/lib/tensor_stream/helpers/string_helper.rb +12 -0
  14. data/lib/tensor_stream/math_gradients.rb +203 -161
  15. data/lib/tensor_stream/operation.rb +30 -16
  16. data/lib/tensor_stream/ops.rb +29 -19
  17. data/lib/tensor_stream/placeholder.rb +2 -3
  18. data/lib/tensor_stream/session.rb +7 -13
  19. data/lib/tensor_stream/tensor.rb +22 -5
  20. data/lib/tensor_stream/tensor_shape.rb +2 -0
  21. data/lib/tensor_stream/trainer.rb +6 -1
  22. data/lib/tensor_stream/variable.rb +4 -3
  23. data/lib/tensor_stream/version.rb +1 -1
  24. data/samples/gradient_sample.graphml +1255 -0
  25. data/samples/linear_regression.rb +1 -1
  26. data/samples/logistic_regression.rb +9 -2
  27. data/tensor_stream.gemspec +1 -1
  28. data/test_samples/error.graphml +120 -0
  29. data/test_samples/gradient_sample.graphml +1255 -0
  30. data/{samples → test_samples}/iris.rb +0 -0
  31. data/{samples → test_samples}/raw_neural_net_sample.rb +0 -0
  32. data/{samples → test_samples}/test.py +2 -0
  33. data/test_samples/test2.py +41 -0
  34. metadata +41 -47
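Before the per-file diffs below: 0.2.0 reorganizes graph export around a small Serializer base class (see data/lib/tensor_stream/graph_serializers/serializer.rb further down), with Graphml and Pbtext as subclasses that implement get_string. A minimal usage sketch, assuming the TensorFlow-style constant/matmul helpers that the bundled samples use; the file names here are illustrative only:

    require 'tensor_stream'

    tf = TensorStream
    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf.constant([[2.0], [1.0]])
    f = tf.matmul(a, b)

    # write a yEd-flavored GraphML view of the expression graph rooted at f
    TensorStream::Graphml.new.serialize('f_graph.graphml', f)

    # or get the protobuf-text form of the graph as a string
    puts TensorStream::Pbtext.new.get_string(f)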
data/lib/tensor_stream/graph_serializers/graphml.rb
@@ -1,91 +1,239 @@
  module TensorStream
- class Graphml
+ class Graphml < Serializer
  def initialize
  end

- def serialize(session, tensor, filename)
+ def get_string(tensor, session = nil)
+ tensor = TensorStream.convert_to_tensor(tensor) unless tensor.is_a?(Tensor)
  @session = session
- @last_session_context = session.last_session_context
+ @name = tensor.name
+ @last_session_context = session ? session.last_session_context : {}
+ groups = {}

  arr_buf = []
  arr_buf << '<?xml version="1.0" encoding="UTF-8"?>'
- arr_buf << '<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ arr_buf << '<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:y="http://www.yworks.com/xml/graphml"
  xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">'
  arr_buf << '<key id="d0" for="node" attr.name="label" attr.type="string"/>'
  arr_buf << '<key id="d1" for="node" attr.name="formula" attr.type="string"/>'
  arr_buf << '<key id="d2" for="node" attr.name="color" attr.type="string"/>'
  arr_buf << '<key id="d3" for="node" attr.name="value" attr.type="string"/>'
+ arr_buf << '<key attr.name="description" attr.type="string" for="edge" id="d12"/>'
+ arr_buf << '<key for="edge" id="d13" yfiles.type="edgegraphics"/>'
+ arr_buf << '<key for="node" id="d9" yfiles.type="nodegraphics"/>'
  arr_buf << "<graph id=\"g_#{_gml_string(tensor.name)}\" edgedefault=\"directed\">"
  arr_buf << "<node id=\"out\">"
  arr_buf << "<data key=\"d0\">out</data>"
  arr_buf << "<data key=\"d2\">red</data>"
+ arr_buf << "<data key=\"d9\">"
+ arr_buf << "<y:ShapeNode>"
+ arr_buf << " <y:Fill color=\"#FF0000\" transparent=\"false\"/>"
+ arr_buf << " <y:NodeLabel alignment=\"center\">out</y:NodeLabel>"
+ arr_buf << "</y:ShapeNode>"
+ arr_buf << "</data>"
  arr_buf << "</node>"
- to_graph_ml(tensor, arr_buf)
- arr_buf << "<edge source=\"#{_gml_string(tensor.name)}\" target=\"out\"/>"
+
+ to_graph_ml(tensor, arr_buf, {}, groups)
+ #dump groups
+ groups.each do |k, g|
+ arr_buf << create_group(k, k, g)
+ end
+
+ output_edge(tensor, "out", arr_buf)
  arr_buf << "</graph>"
  arr_buf << "</graphml>"
- File.write(filename, arr_buf.join("\n"))
+ arr_buf.flatten.join("\n")
  end

  private

+ def add_to_group(groups, name, arr_buf)
+ name_parts = name.split('/')
+ return false if name_parts.size < 2
+
+ prefix = name_parts.shift
+
+ ptr = find_or_create_group(prefix, groups)
+
+ Kernel.loop do
+ next_group = ptr[:group]
+ ptr = find_or_create_group(prefix, next_group)
+ break if name_parts.size < 2
+ prefix = name_parts.shift
+ end
+
+ ptr[:buf] << arr_buf
+ true
+ end
+
+ def find_or_create_group(prefix, groups)
+ if !groups[prefix]
+ groups[prefix] = { buf: [], group: {} }
+ end
+
+ return groups[prefix]
+ end
+
+ def create_group(id, title, group)
+ arr_buf = []
+ arr_buf << "<node id=\"#{id}\" yfiles.foldertype=\"group\">"
+ arr_buf << '<data key="d9">'
+ arr_buf << '<y:ProxyAutoBoundsNode>'
+ arr_buf << '<y:Realizers active="0">'
+ arr_buf << '<y:GroupNode>'
+ arr_buf << '<y:Fill color="#CAECFF84" transparent="false"/>'
+ arr_buf << '<y:BorderStyle color="#666699" type="dotted" width="1.0"/>'
+ arr_buf << '<y:NodeLabel alignment="right" autoSizePolicy="node_width" backgroundColor="#99CCFF" borderDistance="0.0" fontFamily="Dialog" fontSize="15" fontStyle="plain" hasLineColor="false" height="21.4609375" horizontalTextPosition="center" iconTextGap="4" modelName="internal" modelPosition="t" textColor="#000000" verticalTextPosition="bottom" visible="true" width="67.18603515625" x="-8.593017578125" y="0.0">'+ title + '</y:NodeLabel>'
+ arr_buf << '<y:Shape type="roundrectangle"/>'
+ arr_buf << '</y:GroupNode>'
+ arr_buf << '</y:Realizers>'
+ arr_buf << '</y:ProxyAutoBoundsNode>'
+ arr_buf << '</data>'
+ arr_buf << '<graph edgedefault="directed" id="n105:">'
+ arr_buf << group[:buf]
+ group[:group].each do |k, g|
+ arr_buf << create_group(k, k, g)
+ end
+ arr_buf << '</graph>'
+ arr_buf << '</node>'
+ arr_buf
+ end
+
  def _val(tensor)
- JSON.pretty_generate(@last_session_context[tensor.name])
+ # JSON.pretty_generate(@last_session_context[tensor.name])
+ @last_session_context[tensor.name]
  end

- def to_graph_ml(tensor, arr_buf = [], added = {}, _id = 0)
+ def to_graph_ml(tensor, arr_buf = [], added = {}, groups = {}, _id = 0)
  puts tensor.name
+ return unless tensor.is_a?(Operation)
+
  added[tensor.name] = true
- arr_buf << "<node id=\"#{_gml_string(tensor.name)}\">"
- arr_buf << "<data key=\"d0\">#{tensor.operation}</data>"
- arr_buf << "<data key=\"d1\">#{tensor.to_math(true, 1)}</data>"
- arr_buf << "<data key=\"d2\">blue</data>"
+ node_buf = []
+ node_buf << "<node id=\"#{_gml_string(tensor.name)}\">"
+ node_buf << "<data key=\"d0\">#{tensor.operation}</data>"
+ node_buf << "<data key=\"d1\">#{tensor.to_math(true, 1)}</data>"
+ node_buf << "<data key=\"d2\">blue</data>"
+
  if @last_session_context[tensor.name]
  arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
  end
- arr_buf << "</node>"
+ node_buf << "<data key=\"d9\">"
+ node_buf << "<y:ShapeNode>"
+ if tensor.internal?
+ node_buf << " <y:Fill color=\"#FFFF99\" transparent=\"false\"/>"
+ else
+ node_buf << " <y:Fill color=\"#99CC00\" transparent=\"false\"/>"
+ end
+ node_buf << " <y:NodeLabel alignment=\"center\">#{tensor.operation}</y:NodeLabel>"
+ node_buf << "</y:ShapeNode>"
+ node_buf << "</data>"
+ node_buf << "</node>"
+
+ if !add_to_group(groups, tensor.name, node_buf)
+ add_to_group(groups, "program/#{tensor.name}", node_buf)
+ end

  tensor.items.each do |item|
  next unless item
- next if _added[item.name]
+ next if added[item.name]
+
+ next to_graph_ml(item, arr_buf, added, groups) if item.is_a?(Operation)

- next to_graph_ml(item, arr_buf, added) if item.is_a?(Operation)
  added[item.name] = true
+ item_buf = []
  if item.is_a?(Variable)
- arr_buf << "<node id=\"#{_gml_string(item.name)}\">"
- arr_buf << "<data key=\"d0\">#{item.name}</data>"
- arr_buf << "<data key=\"d2\">green</data>"
+ item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+ item_buf << "<data key=\"d0\">#{item.name}</data>"
+ item_buf << "<data key=\"d2\">green</data>"
  if @last_session_context[item.name]
- arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+ item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
  end
- arr_buf << "</node>"
+ item_buf << "<data key=\"d9\">"
+ item_buf << "<y:ShapeNode>"
+ item_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
+ item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+ item_buf << "</y:ShapeNode>"
+ item_buf << "</data>"
+ item_buf << "</node>"
  elsif item.is_a?(Placeholder)
- arr_buf << "<node id=\"#{_gml_string(item.name)}\">"
- arr_buf << "<data key=\"d0\">#{item.name}</data>"
- arr_buf << "<data key=\"d2\">yellow</data>"
+ item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+ item_buf << "<data key=\"d9\">"
+ item_buf << "<y:ShapeNode>"
+ item_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
+ item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+ item_buf << "</y:ShapeNode>"
+ item_buf << "</data>"
  if @last_session_context[item.name]
- arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+ item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
  end
- arr_buf << "</node>"
- else
- arr_buf << "<node id=\"#{_gml_string(item.name)}\">"
- arr_buf << "<data key=\"d0\">#{item.name}</data>"
- arr_buf << "<data key=\"d2\">black</data>"
- if @last_session_context[item.name]
- arr_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+ item_buf << "</node>"
+ elsif item.is_a?(Tensor)
+ item_buf << "<node id=\"#{_gml_string(item.name)}\">"
+ item_buf << "<data key=\"d0\">#{item.name}</data>"
+ item_buf << "<data key=\"d2\">black</data>"
+ item_buf << "<data key=\"d9\">"
+ item_buf << "<y:ShapeNode>"
+
+ if item.internal?
+ item_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+ else
+ item_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+ end
+
+
+ item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+
+ item_buf << "</y:ShapeNode>"
+ item_buf << "</data>"
+ item_buf << "</node>"
+ end
+
+ if !add_to_group(groups, item.name, item_buf)
+ if item.is_a?(Variable)
+ add_to_group(groups, "variable/#{item.name}", item_buf)
+ else
+ add_to_group(groups, "program/#{item.name}", item_buf)
  end
- arr_buf << "</node>"
  end
  end

- tensor.items.each do |item|
+ tensor.items.each_with_index do |item, index|
  next unless item
- arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(tensor.name)}\"/>"
+ output_edge(item, tensor, arr_buf, index)
  end
  end

  def _gml_string(str)
  str.gsub('/','-')
  end
+
+ def output_edge(item, tensor, arr_buf, index = 0)
+ target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
+ arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(target_name)}\">"
+ arr_buf << "<data key=\"d13\">"
+
+ arr_buf << "<y:PolyLineEdge>"
+ arr_buf << "<y:EdgeLabel >"
+ if !@last_session_context.empty?
+ arr_buf << "<![CDATA[ #{_val(item)} ]]>"
+ else
+ if item.shape.shape.nil?
+ arr_buf << "<![CDATA[ #{item.data_type.to_s} ? ]]>"
+ else
+ arr_buf << "<![CDATA[ #{item.data_type.to_s} #{item.shape.shape.empty? ? 'scalar' : item.shape.shape.to_json} ]]>"
+ end
+ end
+ arr_buf << "</y:EdgeLabel >"
+ arr_buf << "<y:Arrows source=\"none\" target=\"standard\"/>"
+ if index == 0
+ arr_buf << "<y:LineStyle color=\"#FF0000\" type=\"line\" width=\"1.0\"/>"
+ else
+ arr_buf << "<y:LineStyle color=\"#0000FF\" type=\"line\" width=\"1.0\"/>"
+ end
+ arr_buf << "</y:PolyLineEdge>"
+ arr_buf << "</data>"
+ arr_buf << "</edge>"
+ end
  end
  end
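The serializer above can also annotate the exported graph with evaluated values: get_string takes an optional session, and output_edge labels edges from last_session_context when one is supplied, falling back to dtype/shape labels otherwise. A hedged sketch, assuming the session helper used by the bundled samples; names are illustrative:

    sess = TensorStream.session
    sess.run(f)                                           # populate last_session_context
    xml = TensorStream::Graphml.new.get_string(f, sess)
    File.write('f_graph_annotated.graphml', xml)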
data/lib/tensor_stream/graph_serializers/pbtext.rb
@@ -1,54 +1,121 @@
  module TensorStream
- class Pbtext
- def initialize
- end
-
- def serialize(session, filename, tensor)
- end
+ class Pbtext < TensorStream::Serializer
+ include TensorStream::StringHelper
+ include TensorStream::OpHelper

- def get_string(graph)
+ def get_string(tensor_or_graph, session = nil)
+ graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
  @lines = []
  graph.nodes.each do |k, node|
  @lines << "node {"
  @lines << " name: #{node.name.to_json}"
  if node.is_a?(TensorStream::Operation)
- @lines << " op: #{node.operation.to_json}"
+ @lines << " op: #{camelize(node.operation.to_s).to_json}"
  node.items.each do |input|
  next unless input
  @lines << " input: #{input.name.to_json}"
  end
  # type
- pb_attr('T', sym_to_protobuf_type(node.data_type))
+ pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
+ process_options(node)
  elsif node.is_a?(TensorStream::Tensor) && node.is_const
  @lines << " op: \"Const\""
  # type
- pb_attr('T', sym_to_protobuf_type(node.data_type))
+ pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
  pb_attr('value', tensor_value(node))
+ elsif node.is_a?(TensorStream::Variable)
+ @lines << " op: \"VariableV2\""
+ pb_attr('T', "dtype: #{sym_to_protobuf_type(node.data_type)}")
+ pb_attr('shape', shape_buf(node, 'shape'))
+ process_options(node)
  end
  @lines << "}"
  end
- @lines.join("\n")
+ @lines << "versions {"
+ @lines << " producer: 26"
+ @lines << "}"
+ @lines.flatten.join("\n")
  end

  private

+ def process_options(node)
+ node.options.each do |k, v|
+ next if %w[name].include?(k.to_s)
+ @lines << " attr {"
+ @lines << " key: \"#{k}\""
+ @lines << " value {"
+ @lines << " }"
+ @lines << " }"
+ end
+ end
+
+ def pack_arr_float(float_arr)
+ float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ end
+
+ def pack_arr_int(int_arr)
+ int_arr.flatten.pack('l*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
+ end
+
+ def shape_buf(tensor, shape_type = 'tensor_shape')
+ arr = []
+ arr << " #{shape_type} {"
+ tensor.shape.shape.each do |dim|
+ arr << " dim {"
+ arr << " size: #{dim}"
+ arr << " }"
+ end if tensor.shape.shape
+ arr << " }"
+ arr
+ end
  def tensor_value(tensor)
  arr = []
  arr << "tensor {"
  arr << " dtype: #{sym_to_protobuf_type(tensor.data_type)}"
- arr << " float_val: #{tensor.value}"
+
+ arr += shape_buf(tensor)
+
+ if tensor.rank > 0
+ if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+ packed = pack_arr_float(tensor.value)
+ arr << " tensor_content: \"#{packed}\""
+ elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+ packed = pack_arr_int(tensor.value)
+ arr << " tensor_content: \"#{packed}\""
+ elsif tensor.data_type == :string
+ tensor.value.each do |v|
+ arr << " string_val: #{v.to_json}"
+ end
+ else
+ arr << " tensor_content: #{tensor.value.flatten}"
+ end
+ else
+ val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
+ "int_val"
+ elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
+ "float_val"
+ elsif tensor.data_type == :string
+ "string_val"
+ else
+ "val"
+ end
+ arr << " #{val_type}: #{tensor.value.to_json}"
+ end
  arr << "}"
  arr
  end

  def sym_to_protobuf_type(type)
  case type
- when :int32
+ when :int32, :int
  "DT_INT32"
  when :float, :float32
  "DT_FLOAT"
+ when :string
+ "DT_STRING"
  else
- "DT_UNKNOWN"
+ "UKNOWN"
  end
  end

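For reference, the tensor_content packing added in pack_arr_float/pack_arr_int above boils down to: pack the flattened values to binary, then escape each non-printable byte as a backslash plus a 3-digit octal sequence. A standalone Ruby illustration of the same idea (not library code):

    values = [1.0, 2.5]
    packed = values.pack('f*').bytes.map do |b|
      # escape non-printable bytes as \NNN octal, keep printable bytes as-is
      b.chr =~ /[^[:print:]]/ ? "\\#{format('%o', b).rjust(3, '0')}" : b.chr
    end.join
    puts packed   # => "\000\000\200?\000\000 @" for [1.0, 2.5] on a little-endian machine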
data/lib/tensor_stream/graph_serializers/serializer.rb
@@ -0,0 +1,13 @@
+ module TensorStream
+ class Serializer
+ def initialize
+ end
+
+ def serialize(filename, tensor, session = nil)
+ File.write(filename, get_string(tensor, session))
+ end
+
+ def get_string(tensor, session = nil)
+ end
+ end
+ end
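The base class is intentionally thin: serialize just writes whatever get_string produces, so a custom exporter only needs to override get_string. A rough sketch of a hypothetical subclass (illustrative only, not part of the gem):

    module TensorStream
      class EdgeList < Serializer
        # emit one "source -> target" line per input edge in the graph
        def get_string(tensor, session = nil)
          graph = tensor.is_a?(Tensor) ? tensor.graph : tensor
          graph.nodes.flat_map do |_name, node|
            next [] unless node.is_a?(Operation)
            node.items.compact.map { |input| "#{input.name} -> #{node.name}" }
          end.join("\n")
        end
      end
    end

    TensorStream::EdgeList.new.serialize('edges.txt', f)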
data/lib/tensor_stream/helpers/string_helper.rb
@@ -0,0 +1,12 @@
+ module TensorStream
+ module StringHelper
+ def camelize(string, uppercase_first_letter = true)
+ string = if uppercase_first_letter
+ string.sub(/^[a-z\d]*/) { $&.capitalize }
+ else
+ string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
+ end
+ string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
+ end
+ end
+ end
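camelize is what Pbtext now uses to map op symbols to TensorFlow-style op names. Expected behaviour of the helper above (plain Ruby, easy to verify in irb):

    include TensorStream::StringHelper

    camelize('reduce_sum')          # => "ReduceSum"
    camelize('matmul')              # => "Matmul"
    camelize('reduce_sum', false)   # => "reduceSum"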
data/lib/tensor_stream/math_gradients.rb
@@ -3,201 +3,243 @@ module TensorStream
  class MathGradients
  extend TensorStream::OpHelper

+ def self.tf
+ TensorStream
+ end
+
  def self.derivative(tensor, wrt_dx, options = {})
- gradient_program_name = "_grad_#{tensor.name}_#{wrt_dx.name}"
-
- return options[:graph].get_node(gradient_program_name) if options[:graph] && options[:graph].node_added?(gradient_program_name)
+ return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
+ return i_op(:zeros_like, tensor) unless wrt_dx.consumers.include?(tensor.name)

- constant_options = { dtype: options[:dtype] }
- constant_options_1 = { dtype: options[:dtype] || tensor.data_type }
+ nodes_to_compute = wrt_dx.consumers.select do |t|
+ node = tensor.graph.nodes[t]
+ node.consumers.include?(tensor.name) || node.equal?(tensor)
+ end.compact + [wrt_dx.name]

- return i_op(:ones_like, wrt_dx, constant_options_1) if tensor.equal?(wrt_dx)
- return i_cons(0, constant_options) if options[:stop_gradients] && _include?(options[:stop_gradients], tensor)
+ grad = i_op(:ones_like, wrt_dx)

- if tensor.is_a?(Operation)
- grad = derivative(tensor.items[0], wrt_dx, options) if tensor.items[0]
- grad2 = derivative(tensor.items[1], wrt_dx, options) if tensor.items[1]
+ result = _propagate(grad, tensor, wrt_dx, nodes_to_compute, options[:stop_gradients] || [])
+ i_op(:truncate, result, tf.shape(wrt_dx))
+ end

- case tensor.operation
- when :zeros_like
- i_cons(0, constant_options)
- when :log1p
- grad * _op(:reciprocal, i_cons(1, constant_options_1) + tensor.items[0])
- when :max
- x_mask = i_op(:where, i_op(:ones_like, tensor.items[0]), i_op(:zeros_like, tensor.items[1]), pred: tensor.items[0] > tensor.items[1])
- y_mask = i_op(:where, i_op(:zeros_like, tensor.items[0]), i_op(:ones_like, tensor.items[1]), pred: tensor.items[0] < tensor.items[1])
- x_mask * grad + y_mask * grad2
- when :where
- x_mask = i_op(:where, i_op(:ones_like, tensor.items[0]), i_op(:zeros_like, tensor.items[1]), pred: tensor.options[:pred])
- y_mask = i_op(:where, i_op(:zeros_like, tensor.items[0]), i_op(:ones_like, tensor.items[1]), pred: tensor.options[:pred])
- x_mask * grad + y_mask * grad2
- when :cond
- i_op(:cond, grad, grad2, pred: tensor.options[:pred])
- when :identity, :print, :pad
- grad
- when :negate
- i_cons(-1, constant_options_1) * grad
- when :abs
- grad * i_op(:sign, _ds(tensor.items[0]))
- when :square
- i_cons(2, constant_options_1) * _ds(tensor.items[0]) * grad
- when :exp
- i_op(:exp, tensor.items[0]) * grad
- when :log
- (i_cons(1, constant_options_1) / _ds(tensor.items[0])) * grad
- when :tanh
- i_op(:mul, (i_cons(1, constant_options_1) - (i_op(:tanh, _ds(tensor.items[0]))**2)), grad, name: 'grad_tanh')
- when :tan
- (i_cons(1, constant_options_1) / (i_op(:cos, _ds(tensor.items[0]))**2)) * grad
- when :sin
- i_op(:mul, i_op(:cos, tensor.items[0]), grad, name: 'grad_sin')
- when :sqrt
- i_cons(1, constant_options_1) / (i_cons(2, constant_options_1) * i_op(:sqrt, _ds(tensor.items[0]))) * grad
- when :cos
- -i_op(:sin, tensor.items[0]) * grad
- when :add
- # rx = _op(:shape, tensor.items[0])
- # ry = _op(:shape, tensor.items[1])
+ def self._propagate(grad, tensor, stop_tensor, nodes_to_compute, stop_gradients = [])
+ return grad * i_op(:ones_like, stop_tensor) if stop_tensor.equal?(tensor)
+ return i_op(:zeros_like, stop_tensor) if stop_gradients && _include?(stop_gradients, tensor)
+ return i_op(:zeros_like, stop_tensor) unless tensor.is_a?(Operation)
+
+ computed_op = if _op_supports_broadcast?(tensor)
+ _compute_derivative(tensor, _broadcast_transform(tensor, grad)[1])
+ else
+ _compute_derivative(tensor, grad)
+ end
+
+ if computed_op.is_a?(Array)
+ partials = []
+ computed_op.each_with_index do |op_grad, index|
+ next if op_grad.nil?
+
+ if nodes_to_compute.include?(tensor.items[index].name)
+ partials << _propagate(op_grad, tensor.items[index], stop_tensor, nodes_to_compute, stop_gradients)
+ end
+ end

- # ones_a = _op(:ones_like, tensor.items[0])
- # ones_b = _op(:ones_like, tensor.items[1])
- # inputs = _broadcast_transform(grad * ones_a, grad2 * ones_b)
- # sx, sy = _broadcast_gradient_args(rx, ry)
+ partials.reduce(:+)
+ else
+ return tf.zeros_like(stop_tensor) if computed_op.nil?
+ _propagate(computed_op, tensor.items[0], stop_tensor, nodes_to_compute, stop_gradients)
+ end
+ end

- # keep_dims_x = _op(:rank, inputs[0]) == _op(:rank, tensor.items[0])
- # keep_dims_y = _op(:rank, inputs[1]) == _op(:rank, tensor.items[1])
+ def self._compute_derivative(node, grad)
+ node.graph.name_scope("#{node.name}_grad") do
+ x = node.items[0] if node.items[0]
+ y = node.items[1] if node.items[1]

- # add_x = _op(:reduce_sum, inputs[0], nil, axis: sy, keepdims: keep_dims_x)
- # add_y = _op(:reduce_sum, inputs[1], nil, axis: sx, keepdims: keep_dims_y)
- # _filtered_sum(add_x, add_y, wrt_dx)
- _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:add, a, b, name: 'grad_add') }, options)
- when :sub
- _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:sub, a, b, name: 'grad_sub') }, options)
- when :pow
- gx = _ds(tensor.items[1]) * (_ds(tensor.items[0])**(_ds(tensor.items[1]) - 1)) * grad
+ case node.operation
+ when :add
+ return [grad, grad] if _shapes_fully_specified_and_equal(x, y)

- log_x = i_op(:where, i_op(:log, tensor.items[0], nil, name: 'log_pow_grad'), i_op(:zeros_like, tensor.items[0]), pred: tensor.items[0] > 0)
- gy = _ds(tensor.items[0])**_ds(tensor.items[1]) * log_x * grad2
+ sx = tf.shape(x, name: 'add/shape_x')
+ sy = tf.shape(y, name: 'add/shape_y')
+ rx, ry = _broadcast_gradient_args(sx, sy)
+ keep_dims_x = tf.rank(x) == tf.rank(grad)
+ keep_dims_y = tf.rank(y) == tf.rank(grad)

- gx + gy
- when :div
- # apply the quotient rule
- gx = i_op(:div, grad, _ds(tensor.items[1]))
- gy = grad2 * i_op(:div, i_op(:div, -_ds(tensor.items[0]), _ds(tensor.items[1])), _ds(tensor.items[1]))
+ [tf.reduce_sum(grad, rx, name: 'add/reduce_sum_x', keepdims: keep_dims_x),
+ tf.reduce_sum(grad, ry, name: 'add/reduce_sum_y', keepdims: keep_dims_y)]
+ when :sub
+ return [grad, -grad] if _shapes_fully_specified_and_equal(x, y)

- _reduce_when_necessary(gx + gy, wrt_dx)
+ sx = tf.shape(x, name: 'sub/shape_x')
+ sy = tf.shape(y, name: 'sub/shape_y')
+ rx, ry = _broadcast_gradient_args(sx, sy)
+ [tf.reduce_sum(grad, rx), -tf.reduce_sum(grad, ry)]
  when :mul
- # apply the product rule
- rx = _op(:shape, tensor.items[0])
- ry = _op(:shape, tensor.items[1])
- sx, sy = _broadcast_gradient_args(rx, ry)
- inputs = _broadcast_transform(tensor.items[0], tensor.items[1])
- keep_dims_x = _op(:rank, inputs[0]) == _op(:rank, tensor.items[0])
- keep_dims_y = _op(:rank, inputs[1]) == _op(:rank, tensor.items[1])
-
- _filtered_sum(_op(:reduce_sum, grad * _ds(inputs[1]), nil, axis: sy, keepdims: keep_dims_x),
- _op(:reduce_sum, _ds(inputs[0]) * grad2, nil, axis: sx, keepdims: keep_dims_y), wrt_dx)
- when :reduce_mean
- input_size = i_op(:reduce_prod, i_op(:shape, tensor.items[0]))
- output_size = i_op(:reduce_prod, i_op(:shape, tensor))
- factor = input_size / output_size
-
- (grad / i_op(:cast, factor, data_type: grad.dtype))
- when :reduce_sum
- grad
- when :reciprocal
- -grad * (i_cons(1, constant_options_1) / _ds(tensor.items[0])**2)
- when :stop_gradient
- return i_cons(0, constant_options)
- when :matmul
- derivative_a = derivative(tensor.items[0], wrt_dx)
- derivative_b = derivative(tensor.items[1], wrt_dx)
-
- s0 = i_op(:shape, tensor.items[0])
- s1 = i_op(:shape, tensor.items[1])
+ sx = tf.shape(x)
+ sy = tf.shape(y)
+ rx, ry = _broadcast_gradient_args(sx, sy)

- identity_0 = i_op(:ones, [s0[0], s1[1]], nil, data_type: tensor.items[0].data_type)
- identity_1 = i_op(:ones, [s0[0], s1[1]], nil, data_type: tensor.items[1].data_type)
-
- matmul_da = i_op(:matmul, identity_0, tensor.items[1], transpose_b: true,
- pad_zeros: true,
- name: 'matrix_dx')
- matmul_db = i_op(:matmul, tensor.items[0], identity_1, transpose_a: true,
- pad_zeros: true,
- name: 'matrix_dy')
- # matmul_db = _op(:transpose, matmul_db, nil).first
-
- # begin_a = _op(:zeros, _op(:rank, matmul_db), nil, data_type: :int32, name: 'begin_a')
- # matmul_b_shape = _op(:shape, matmul_db)
- # end_a = [matmul_b_shape[0], 1]
+ [ tf.reduce_sum(tf.mul(grad, y), rx),
+ tf.reduce_sum(tf.mul(x, grad), ry)]
+ when :div
+ sx = i_op(:shape, x)
+ sy = i_op(:shape, y)
+ rx, ry = _broadcast_gradient_args(sx, sy)

- matmul_da = i_op(:cond, matmul_da[0], matmul_da, pred: _op(:rank, derivative_a) > 0)
+ [tf.reduce_sum(tf.div(grad, y), rx),
+ tf.reduce_sum(grad * tf.div(tf.div(-x, y), y),
+ ry)]
+ when :matmul
+ t_a = node.options[:transpose_a]
+ t_b = node.options[:transpose_b]
+
+ s0 = tf.shape(x)
+ s1 = tf.shape(y)
+
+ identity_0 = tf.ones([ s0[0], s1[1] ], dtype: x.data_type, name: 'matmul/identity0')
+ identity_1 = tf.ones([ s0[0], s1[1] ], dtype: y.data_type, name: 'matmul/identity1')
+
+ grad_a, grad_b = nil
+ if !t_a && !t_b
+ grad_a = tf.matmul(identity_0, y, transpose_b: true)
+ grad_b = tf.matmul(x, identity_1, transpose_a: true)
+ elsif !ta && tb
+ grad_a = tf.matmul(identity_0, y)
+ grad_b = tf.matmul(identity_1, x, transpose_a: true)
+ elsif t_a && !t_b
+ grad_a = tf.matmul(y, identity_0, transpose_b: true)
+ grad_b = tf.matmul(x, identity_1)
+ elsif t_a && t_b
+ grad_a = tf.matmul(y, identity_0, transpose_a: true, transpose_b: true)
+ grad_b = tf.matmul(identity_1, x, transpose_a: true, transpose_b: true)
+ end
+
+ grad_a = i_op(:mul, grad, grad_a, name: 'matmul/grad_a_norm_mul_da')
+ grad_b = i_op(:mul, grad, grad_b, name: 'matmul/grad_b_norm_mul_db')
+
+ [grad_a, grad_b]
+ when :sin
+ grad * tf.cos(x)
+ when :tanh
+ grad * i_op(:tanh_grad, x)
+ when :pow
+ z = node
+ sx = tf.shape(x)
+ sy = tf.shape(y)
+ rx, ry = _broadcast_gradient_args(sx, sy)
+ gx = tf.reshape(
+ tf.reduce_sum(grad * y * tf.pow(x, y - 1), rx), sx)

- # matmul_da = _op(:cond, matmul_da[0], matmul_da, pred: _op(:rank, derivative_a) > 0)
- norm_a = i_op(:mul, derivative_a, matmul_da, name: 'grad_a_norm_mul_da')
- norm_b = i_op(:mul, derivative_b, matmul_db, name: 'grad_b_norm_mul_db')
+ log_x = tf.where(x > 0, tf.log(x), tf.zeros_like(x))
+ gy = tf.reshape(tf.reduce_sum(grad * z * log_x, ry), sy)

- # norm_a = i_op(:cond, norm_a[0], norm_a, pred: i_op(:rank, matmul_da) > i_op(:rank, derivative_a))
- # norm_b = i_op(:cond, norm_b[0], norm_b, pred: i_op(:rank, matmul_db) > i_op(:rank, derivative_b))
- _filtered_sum(norm_a, norm_b, wrt_dx)
+ [gx, gy]
+ when :abs
+ grad * tf.sign(x)
+ when :log
+ grad * tf.reciprocal(x)
+ when :tanh
+ i_op(:tanh_grad, x) * grad
+ when :cos
+ -grad * tf.sin(x)
+ when :max
+ x_mask = tf.where(x > y, tf.ones_like(x), tf.zeros_like(y))
+ y_mask = tf.where(x < y, tf.zeros_like(x), tf.ones_like(y))
+ [x_mask * grad, y_mask * grad]
+ when :tan
+ secx = tf.reciprocal(tf.cos(x))
+ secx2 = tf.square(secx)
+ grad * secx2
+ when :negate
+ -grad
+ when :exp
+ grad * node
+ when :identity
+ grad
+ when :sum
+ _sum_grad(x, y, grad)
+ when :reciprocal
+ -grad * (tf.constant(1, dtype: x.dtype) / x**2)
+ when :sqrt
+ tf.constant(1, dtype: x.dtype) / (tf.constant(2, dtype: x.dtype) * tf.sqrt(x)) * grad
+ when :stop_gradient
+ tf.zeros_like(grad)
+ when :square
+ y = tf.constant(2.0, dtype: x.dtype)
+ tf.multiply(grad, tf.multiply(x, y))
+ when :where
+ x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
+ y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
+ [x_mask * grad, y_mask * grad]
+ when :cond
+ x_cond = i_op(:cond, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
+ y_cond = i_op(:cond, i_op(:zeros_like, x), i_op(:ones_like, x), pred: node.options[:pred])
+ [x_cond * grad, y_cond * grad]
+ when :mean
+ sum_grad = _sum_grad(x, y, grad)
+ input_shape = tf.shape(x)
+ output_shape = tf.shape(node)
+ factor = _safe_shape_div(tf.reduce_prod(input_shape), tf.reduce_prod(output_shape))
+ tf.div(sum_grad, tf.cast(factor, sum_grad.data_type))
+ when :log1p
+ grad * tf.reciprocal(i_cons(1, data_type: grad.data_type) + x)
+ when :sigmoid
+ i_op(:sigmoid_grad, x, grad)
+ when :zeros_like
+ # non differentiable
+ nil
  else
- raise "no derivative implementation found for op #{tensor.operation}"
+ raise "no derivative op for #{node.operation}"
  end
- elsif tensor.is_a?(TensorStream::Variable)
- i_cons(0, constant_options)
- elsif tensor.is_a?(TensorStream::Placeholder)
- i_cons(0, constant_options)
- else
- i_cons(0, constant_options)
- end.tap do |ops|
- options[:graph].add_node!(gradient_program_name, ops) if options[:graph]
  end
  end

- def self._ds(tensor)
- return tensor unless tensor.is_a?(Operation)
+ def self._broadcast_gradient_args(input_a, input_b)
+ [_op(:broadcast_gradient_args, input_b, input_a), _op(:broadcast_gradient_args, input_a, input_b)]
+ end

- case tensor.operation
- when :reduce_sum
- tensor.items[0]
- else
- tensor
- end
+ def self._broadcast_transform(input_a, input_b)
+ _op(:broadcast_transform, input_a, input_b)
  end

- def self._grad_with_broadcast(tensor, wrt_dx, func, options)
- grad = derivative(tensor.items[0], wrt_dx, options)
- grad2 = derivative(tensor.items[1], wrt_dx, options)
- elements1 = i_op(:reduce_prod, i_op(:shape, tensor.items[0]), data_type: :float32)
- elements2 = i_op(:reduce_prod, i_op(:shape, tensor.items[1]), data_type: :float32)
- multiplier = elements1 / elements2
- _reduce_when_necessary(func.call(grad, grad2 * multiplier), wrt_dx)
+ def self._safe_shape_div(x, y)
+ x / tf.maximum(y, 1)
  end

- def self._include?(arr, obj)
- arr.each { |a| return true if a.equal?(obj) }
+ def self._sum_grad(x, y, grad)
+ tf.ones_like(x) * grad
+ end
+
+ def self._op_supports_broadcast?(node)
+ return true if %i[add sub div mul pow].include?(node.operation)
  false
  end

- def self._reduce_when_necessary(tensor, wrt_dx)
- rank = _op(:rank, tensor)
- dx_rank = _op(:rank, wrt_dx)
- reduced = _op(:reduce_sum, tensor, nil, axis: 0)
- _op(:cond, ->{ reduced }, tensor, pred: rank > dx_rank)
+ def self._min_or_max_grad(op, grad)
+ y = op
+ indicators = tf.cast(tf.equal(y, op.items[0]), grad.data_type)
+ num_selected = tf.reduce_sum(indicators, op.items[1])
+ _safe_shape_div(indicators, num_selected) * grad
  end

- def self._broadcast_gradient_args(input_a, input_b)
- [_op(:broadcast_gradient_args, input_a, input_b), _op(:broadcast_gradient_args, input_b, input_a)]
+ def self._include?(arr, obj)
+ arr.each { |a| return true if a.equal?(obj) }
+ false
  end

- def self._broadcast_transform(input_a, input_b)
- _op(:broadcast_transform, input_a, input_b)
+ def self._shapes_fully_specified_and_equal(x, y)
+ return false if !_shape_full_specified(x) || !_shape_full_specified(y)
+ return false if x.shape.shape != y.shape.shape
+
+ true
  end

- # filter out zero arrays
- def self._filtered_sum(input_a, input_b, wrt_dx)
- zero_vect = _op(:zeros_like, wrt_dx)
- (i_op(:cond, input_a, zero_vect, pred: i_op(:reduce_sum, input_a) != 0) + i_op(:cond, input_b, zero_vect, pred: i_op(:reduce_sum, input_b) != 0))
+ def self._shape_full_specified(tensor)
+ return false if tensor.shape.nil?
+ return false if tensor.shape.shape.nil?
+
+ tensor.shape.shape.each { |s| return false if s.nil? }
+ true
  end
  end
- end
+ end
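The gradient rewrite above replaces the old per-op derivative recursion (and its _grad_with_broadcast/_filtered_sum helpers) with a reverse-mode pass: derivative seeds a ones_like gradient, and _propagate walks consumer links back toward wrt_dx, asking _compute_derivative for each op's local gradient. A hedged end-to-end sketch using the entry point shown in the diff; constants, operators, and the exact call pattern are illustrative and not taken from the gem's tests:

    tf = TensorStream

    x = tf.constant(3.0)
    y = x * x + x * 2.0          # dy/dx = 2x + 2

    grad = TensorStream::MathGradients.derivative(y, x)

    sess = tf.session
    puts sess.run(grad)          # expect 8.0 at x = 3.0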