tensor_stream 0.4.1 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +38 -17
  4. data/benchmark/benchmark.rb +16 -20
  5. data/lib/tensor_stream/control_flow.rb +3 -3
  6. data/lib/tensor_stream/debugging/debugging.rb +4 -4
  7. data/lib/tensor_stream/device.rb +5 -2
  8. data/lib/tensor_stream/evaluator/base_evaluator.rb +138 -0
  9. data/lib/tensor_stream/evaluator/buffer.rb +7 -2
  10. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_bool_operand.cl +3 -3
  11. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_operand.cl +0 -0
  12. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/abs.cl +0 -0
  13. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/add.cl +1 -1
  14. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmax.cl +0 -0
  15. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmin.cl +0 -0
  16. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cast.cl +0 -0
  17. data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +6 -0
  18. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cos.cl +0 -0
  19. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/div.cl.erb +1 -1
  20. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/exp.cl +0 -0
  21. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/gemm.cl +0 -0
  22. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log.cl +0 -0
  23. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log1p.cl +0 -0
  24. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/max.cl +3 -3
  25. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/mul.cl +1 -1
  26. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/negate.cl +0 -0
  27. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/pow.cl +3 -3
  28. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/reciprocal.cl +0 -0
  29. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/round.cl +0 -0
  30. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid.cl +0 -0
  31. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid_grad.cl +3 -3
  32. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sign.cl +1 -1
  33. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sin.cl +0 -0
  34. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax.cl +0 -0
  35. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax_grad.cl +0 -0
  36. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sqrt.cl +0 -0
  37. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/square.cl +0 -0
  38. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sub.cl +1 -1
  39. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tan.cl +0 -0
  40. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh.cl +0 -0
  41. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh_grad.cl +0 -0
  42. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/where.cl +1 -1
  43. data/lib/tensor_stream/evaluator/{opencl_buffer.rb → opencl/opencl_buffer.rb} +1 -1
  44. data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +5 -0
  45. data/lib/tensor_stream/evaluator/{opencl_evaluator.rb → opencl/opencl_evaluator.rb} +404 -452
  46. data/lib/tensor_stream/evaluator/{opencl_template_helper.rb → opencl/opencl_template_helper.rb} +6 -6
  47. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +21 -21
  48. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +492 -398
  49. data/lib/tensor_stream/graph.rb +21 -1
  50. data/lib/tensor_stream/graph_serializers/graphml.rb +59 -59
  51. data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -1
  52. data/lib/tensor_stream/helpers/op_helper.rb +6 -2
  53. data/lib/tensor_stream/math_gradients.rb +7 -7
  54. data/lib/tensor_stream/operation.rb +100 -100
  55. data/lib/tensor_stream/session.rb +81 -8
  56. data/lib/tensor_stream/tensor.rb +7 -5
  57. data/lib/tensor_stream/utils.rb +32 -19
  58. data/lib/tensor_stream/version.rb +1 -1
  59. data/tensor_stream.gemspec +0 -1
  60. data/test_samples/raw_neural_net_sample.rb +7 -7
  61. metadata +41 -53
  62. data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +0 -5
@@ -43,6 +43,19 @@ module TensorStream
43
43
  end
44
44
  end
45
45
 
46
+ ##
47
+ # Returns a context manager that specifies the default device to use.
48
+ def device(device_name)
49
+ Thread.current["ts_graph_#{object_id}"] ||= {}
50
+ Thread.current["ts_graph_#{object_id}"][:default_device] ||= []
51
+ Thread.current["ts_graph_#{object_id}"][:default_device] << device_name
52
+ begin
53
+ yield
54
+ ensure
55
+ Thread.current["ts_graph_#{object_id}"][:default_device].pop
56
+ end
57
+ end
58
+
46
59
  def self.get_default_graph
47
60
  Thread.current[:tensor_stream_current_graph] || create_default
48
61
  end
@@ -69,6 +82,7 @@ module TensorStream
69
82
  node.name
70
83
  end
71
84
 
85
+ node.device = get_device_scope
72
86
  @nodes[node.name] = node
73
87
  @constants[node.name] = node if node.is_const
74
88
  node.send(:propagate_outputs)
@@ -159,11 +173,17 @@ module TensorStream
159
173
 
160
174
  def get_name_scope
161
175
  graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
162
- return nil if graph_thread_storage.nil?
176
+ return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?
163
177
 
164
178
  graph_thread_storage[:current_scope].join('/')
165
179
  end
166
180
 
181
+ def get_device_scope
182
+ graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
183
+ return :default if graph_thread_storage.nil? || graph_thread_storage[:default_device].nil?
184
+ graph_thread_storage[:default_device].last
185
+ end
186
+
167
187
  def as_graph_def
168
188
  TensorStream::Pbtext.new.get_string(self)
169
189
  end
@@ -134,73 +134,73 @@ module TensorStream
134
134
  add_to_group(groups, "program/#{tensor.name}", node_buf)
135
135
  end
136
136
 
137
- tensor.items.each do |item|
138
- next unless item
139
- next if added[item.name]
140
-
141
- next to_graph_ml(item, arr_buf, added, groups) if item.is_a?(Operation)
142
-
143
- added[item.name] = true
144
- item_buf = []
145
- if item.is_a?(Variable)
146
- item_buf << "<node id=\"#{_gml_string(item.name)}\">"
147
- item_buf << "<data key=\"d0\">#{item.name}</data>"
148
- item_buf << "<data key=\"d2\">green</data>"
149
- if @last_session_context[item.name]
150
- item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
137
+ tensor.inputs.each do |input|
138
+ next unless input
139
+ next if added[input.name]
140
+
141
+ next to_graph_ml(input, arr_buf, added, groups) if input.is_a?(Operation)
142
+
143
+ added[input.name] = true
144
+ input_buf = []
145
+ if input.is_a?(Variable)
146
+ input_buf << "<node id=\"#{_gml_string(input.name)}\">"
147
+ input_buf << "<data key=\"d0\">#{input.name}</data>"
148
+ input_buf << "<data key=\"d2\">green</data>"
149
+ if @last_session_context[input.name]
150
+ input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
151
151
  end
152
- item_buf << "<data key=\"d9\">"
153
- item_buf << "<y:ShapeNode>"
154
- item_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
155
- item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
156
- item_buf << "</y:ShapeNode>"
157
- item_buf << "</data>"
158
- item_buf << "</node>"
159
- elsif item.is_a?(Placeholder)
160
- item_buf << "<node id=\"#{_gml_string(item.name)}\">"
161
- item_buf << "<data key=\"d9\">"
162
- item_buf << "<y:ShapeNode>"
163
- item_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
164
- item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
165
- item_buf << "</y:ShapeNode>"
166
- item_buf << "</data>"
167
- if @last_session_context[item.name]
168
- item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
152
+ input_buf << "<data key=\"d9\">"
153
+ input_buf << "<y:ShapeNode>"
154
+ input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
155
+ input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
156
+ input_buf << "</y:ShapeNode>"
157
+ input_buf << "</data>"
158
+ input_buf << "</node>"
159
+ elsif input.is_a?(Placeholder)
160
+ input_buf << "<node id=\"#{_gml_string(input.name)}\">"
161
+ input_buf << "<data key=\"d9\">"
162
+ input_buf << "<y:ShapeNode>"
163
+ input_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
164
+ input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
165
+ input_buf << "</y:ShapeNode>"
166
+ input_buf << "</data>"
167
+ if @last_session_context[input.name]
168
+ input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
169
169
  end
170
- item_buf << "</node>"
171
- elsif item.is_a?(Tensor)
172
- item_buf << "<node id=\"#{_gml_string(item.name)}\">"
173
- item_buf << "<data key=\"d0\">#{item.name}</data>"
174
- item_buf << "<data key=\"d2\">black</data>"
175
- item_buf << "<data key=\"d9\">"
176
- item_buf << "<y:ShapeNode>"
177
-
178
- if item.internal?
179
- item_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
170
+ input_buf << "</node>"
171
+ elsif input.is_a?(Tensor)
172
+ input_buf << "<node id=\"#{_gml_string(input.name)}\">"
173
+ input_buf << "<data key=\"d0\">#{input.name}</data>"
174
+ input_buf << "<data key=\"d2\">black</data>"
175
+ input_buf << "<data key=\"d9\">"
176
+ input_buf << "<y:ShapeNode>"
177
+
178
+ if input.internal?
179
+ input_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
180
180
  else
181
- item_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
181
+ input_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
182
182
  end
183
183
 
184
184
 
185
- item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
185
+ input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
186
186
 
187
- item_buf << "</y:ShapeNode>"
188
- item_buf << "</data>"
189
- item_buf << "</node>"
187
+ input_buf << "</y:ShapeNode>"
188
+ input_buf << "</data>"
189
+ input_buf << "</node>"
190
190
  end
191
191
 
192
- if !add_to_group(groups, item.name, item_buf)
193
- if item.is_a?(Variable)
194
- add_to_group(groups, "variable/#{item.name}", item_buf)
192
+ if !add_to_group(groups, input.name, input_buf)
193
+ if input.is_a?(Variable)
194
+ add_to_group(groups, "variable/#{input.name}", input_buf)
195
195
  else
196
- add_to_group(groups, "program/#{item.name}", item_buf)
196
+ add_to_group(groups, "program/#{input.name}", input_buf)
197
197
  end
198
198
  end
199
199
  end
200
200
 
201
- tensor.items.each_with_index do |item, index|
202
- next unless item
203
- output_edge(item, tensor, arr_buf, index)
201
+ tensor.inputs.each_with_index do |input, index|
202
+ next unless input
203
+ output_edge(input, tensor, arr_buf, index)
204
204
  end
205
205
  end
206
206
 
@@ -208,20 +208,20 @@ module TensorStream
208
208
  str.gsub('/','-')
209
209
  end
210
210
 
211
- def output_edge(item, tensor, arr_buf, index = 0)
211
+ def output_edge(input, tensor, arr_buf, index = 0)
212
212
  target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
213
- arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(target_name)}\">"
213
+ arr_buf << "<edge source=\"#{_gml_string(input.name)}\" target=\"#{_gml_string(target_name)}\">"
214
214
  arr_buf << "<data key=\"d13\">"
215
215
 
216
216
  arr_buf << "<y:PolyLineEdge>"
217
217
  arr_buf << "<y:EdgeLabel >"
218
218
  if !@last_session_context.empty?
219
- arr_buf << "<![CDATA[ #{_val(item)} ]]>"
219
+ arr_buf << "<![CDATA[ #{_val(input)} ]]>"
220
220
  else
221
- if item.shape.shape.nil?
222
- arr_buf << "<![CDATA[ #{item.data_type.to_s} ? ]]>"
221
+ if input.shape.shape.nil?
222
+ arr_buf << "<![CDATA[ #{input.data_type.to_s} ? ]]>"
223
223
  else
224
- arr_buf << "<![CDATA[ #{item.data_type.to_s} #{item.shape.shape.empty? ? 'scalar' : item.shape.shape.to_json} ]]>"
224
+ arr_buf << "<![CDATA[ #{input.data_type.to_s} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
225
225
  end
226
226
  end
227
227
  arr_buf << "</y:EdgeLabel >"
@@ -11,7 +11,7 @@ module TensorStream
11
11
  @lines << " name: #{node.name.to_json}"
12
12
  if node.is_a?(TensorStream::Operation)
13
13
  @lines << " op: #{camelize(node.operation.to_s).to_json}"
14
- node.items.each do |input|
14
+ node.inputs.each do |input|
15
15
  next unless input
16
16
  @lines << " input: #{input.name.to_json}"
17
17
  end
@@ -33,8 +33,12 @@ module TensorStream
33
33
  arr
34
34
  end
35
35
 
36
- def dtype_eval(rank, value)
37
- dtype = Tensor.detect_type(value[0])
36
+ def dtype_eval(rank, value, data_type = nil)
37
+ dtype = if data_type.nil?
38
+ Tensor.detect_type(value[0])
39
+ else
40
+ data_type
41
+ end
38
42
 
39
43
  rank += 1 if dtype == :array
40
44
 
@@ -38,22 +38,22 @@ module TensorStream
38
38
  computed_op.each_with_index do |op_grad, index|
39
39
  next if op_grad.nil?
40
40
 
41
- if nodes_to_compute.include?(tensor.items[index].name)
42
- partials << _propagate(op_grad, tensor.items[index], stop_tensor, nodes_to_compute, stop_gradients)
41
+ if nodes_to_compute.include?(tensor.inputs[index].name)
42
+ partials << _propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
43
43
  end
44
44
  end
45
45
 
46
46
  partials.reduce(:+)
47
47
  else
48
48
  return tf.zeros_like(stop_tensor) if computed_op.nil?
49
- _propagate(computed_op, tensor.items[0], stop_tensor, nodes_to_compute, stop_gradients)
49
+ _propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
50
50
  end
51
51
  end
52
52
 
53
53
  def self._compute_derivative(node, grad)
54
54
  node.graph.name_scope("#{node.name}_grad") do
55
- x = node.items[0] if node.items[0]
56
- y = node.items[1] if node.items[1]
55
+ x = node.inputs[0] if node.inputs[0]
56
+ y = node.inputs[1] if node.inputs[1]
57
57
 
58
58
  case node.operation
59
59
  when :add
@@ -221,8 +221,8 @@ module TensorStream
221
221
 
222
222
  def self._min_or_max_grad(op, grad)
223
223
  y = op
224
- indicators = tf.cast(tf.equal(y, op.items[0]), grad.data_type)
225
- num_selected = tf.reduce_sum(indicators, op.items[1])
224
+ indicators = tf.cast(tf.equal(y, op.inputs[0]), grad.data_type)
225
+ num_selected = tf.reduce_sum(indicators, op.inputs[1])
226
226
  _safe_shape_div(indicators, num_selected) * grad
227
227
  end
228
228
 
@@ -1,7 +1,7 @@
1
1
  module TensorStream
2
2
  # TensorStream class that defines an operation
3
3
  class Operation < Tensor
4
- attr_accessor :name, :operation, :items, :rank, :options
4
+ attr_accessor :name, :operation, :inputs, :rank, :options
5
5
  attr_reader :outputs
6
6
 
7
7
  def initialize(operation, input_a, input_b, options = {})
@@ -15,7 +15,7 @@ module TensorStream
15
15
 
16
16
  @options = options
17
17
 
18
- @items = [input_a, input_b].map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
18
+ @inputs = [input_a, input_b].map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
19
19
  @data_type = set_data_type(options[:data_type])
20
20
  @is_const = infer_const
21
21
  @shape = TensorShape.new(infer_shape)
@@ -30,16 +30,16 @@ module TensorStream
30
30
  {
31
31
  op: operation,
32
32
  name: name,
33
- operands: hashify_tensor(items)
33
+ operands: hashify_tensor(inputs)
34
34
  }
35
35
  end
36
36
 
37
37
  def self.empty_matrix?(input)
38
38
  if input.is_a?(Array)
39
- input.each do |item|
40
- if item.is_a?(Array)
41
- return false unless empty_matrix?(item)
42
- elsif item != 0 || item != 0.0
39
+ input.each do |input|
40
+ if input.is_a?(Array)
41
+ return false unless empty_matrix?(input)
42
+ elsif input != 0 || input != 0.0
43
43
  return false
44
44
  end
45
45
  end
@@ -54,7 +54,7 @@ module TensorStream
54
54
  when :random_normal, :random_uniform, :glorot_uniform, :print
55
55
  false
56
56
  else
57
- non_const = @items.compact.find { |item| !item.is_const }
57
+ non_const = @inputs.compact.find { |input| !input.is_const }
58
58
  non_const ? false : true
59
59
  end
60
60
  end
@@ -68,23 +68,23 @@ module TensorStream
68
68
  when :random_normal, :random_uniform, :glorot_uniform
69
69
  passed_data_type || :float32
70
70
  when :index
71
- if @items[0].is_a?(ControlFlow)
71
+ if @inputs[0].is_a?(ControlFlow)
72
72
 
73
- if @items[1].is_const
74
- @items[0].items[@items[1].value].data_type
73
+ if @inputs[1].is_const
74
+ @inputs[0].inputs[@inputs[1].value].data_type
75
75
  else
76
76
  :unknown
77
77
  end
78
78
  else
79
- @items[0].data_type
79
+ @inputs[0].data_type
80
80
  end
81
81
  else
82
82
  return passed_data_type if passed_data_type
83
83
 
84
- if @items[0]
85
- @items[0].data_type
86
- elsif @items[1]
87
- @items[1].data_type
84
+ if @inputs[0]
85
+ @inputs[0].data_type
86
+ elsif @inputs[1]
87
+ @inputs[1].data_type
88
88
  else
89
89
  :unknown
90
90
  end
@@ -94,119 +94,119 @@ module TensorStream
94
94
  def to_math(name_only = false, max_depth = 99, _cur_depth = 0)
95
95
  return @name if max_depth.zero?
96
96
 
97
- sub_item = auto_math(items[0], name_only, max_depth - 1, _cur_depth + 1)
98
- sub_item2 = auto_math(items[1], name_only, max_depth - 1, _cur_depth + 1) if items[1]
97
+ sub_input = auto_math(inputs[0], name_only, max_depth - 1, _cur_depth + 1)
98
+ sub_input2 = auto_math(inputs[1], name_only, max_depth - 1, _cur_depth + 1) if inputs[1]
99
99
 
100
100
  out = case operation
101
101
  when :argmax
102
- "argmax(#{sub_item},#{options[:axis]})"
102
+ "argmax(#{sub_input},#{options[:axis]})"
103
103
  when :negate
104
- "-#{sub_item}"
104
+ "-#{sub_input}"
105
105
  when :index
106
- "#{sub_item}[#{sub_item2}]"
106
+ "#{sub_input}[#{sub_input2}]"
107
107
  when :slice
108
- "#{sub_item}[#{sub_item2}]"
108
+ "#{sub_input}[#{sub_input2}]"
109
109
  when :assign_sub
110
- "(#{items[0] ? items[0].name : 'self'} -= #{auto_math(items[1], name_only, 1)})"
110
+ "(#{inputs[0] ? inputs[0].name : 'self'} -= #{auto_math(inputs[1], name_only, 1)})"
111
111
  when :assign_add
112
- "(#{items[0] ? items[0].name : 'self'} += #{auto_math(items[1], name_only, 1)})"
112
+ "(#{inputs[0] ? inputs[0].name : 'self'} += #{auto_math(inputs[1], name_only, 1)})"
113
113
  when :assign
114
- "(#{items[0] ? items[0].name : 'self'} = #{auto_math(items[1], name_only, 1)})"
114
+ "(#{inputs[0] ? inputs[0].name : 'self'} = #{auto_math(inputs[1], name_only, 1)})"
115
115
  when :sin, :cos, :tanh
116
- "#{operation}(#{sub_item})"
116
+ "#{operation}(#{sub_input})"
117
117
  when :add
118
- "(#{sub_item} + #{sub_item2})"
118
+ "(#{sub_input} + #{sub_input2})"
119
119
  when :sub
120
- "(#{sub_item} - #{sub_item2})"
120
+ "(#{sub_input} - #{sub_input2})"
121
121
  when :pow
122
- "(#{sub_item}^#{sub_item2})"
122
+ "(#{sub_input}^#{sub_input2})"
123
123
  when :div
124
- "(#{sub_item} / #{sub_item2})"
124
+ "(#{sub_input} / #{sub_input2})"
125
125
  when :mul
126
- if auto_math(items[0]) == 1
127
- sub_item2
128
- elsif auto_math(items[1]) == 1
129
- sub_item
126
+ if auto_math(inputs[0]) == 1
127
+ sub_input2
128
+ elsif auto_math(inputs[1]) == 1
129
+ sub_input
130
130
  else
131
- "(#{sub_item} * #{sub_item2})"
131
+ "(#{sub_input} * #{sub_input2})"
132
132
  end
133
133
  when :sum
134
- "sum(|#{sub_item}|, axis=#{sub_item2})"
134
+ "sum(|#{sub_input}|, axis=#{sub_input2})"
135
135
  when :mean
136
- "mean(|#{sub_item}|, axis=#{sub_item2})"
136
+ "mean(|#{sub_input}|, axis=#{sub_input2})"
137
137
  when :prod
138
- "prod(|#{sub_item}|, axis=#{sub_item2})"
138
+ "prod(|#{sub_input}|, axis=#{sub_input2})"
139
139
  when :gradients
140
- "gradient(#{sub_item})"
140
+ "gradient(#{sub_input})"
141
141
  when :stop_gradient
142
- sub_item
142
+ sub_input
143
143
  when :matmul
144
- "#{sub_item}.matmul(#{sub_item2})"
144
+ "#{sub_input}.matmul(#{sub_input2})"
145
145
  when :eye
146
- "eye(#{sub_item})"
146
+ "eye(#{sub_input})"
147
147
  when :transpose
148
- "transpose(#{sub_item})"
148
+ "transpose(#{sub_input})"
149
149
  when :shape
150
- "#{sub_item}.shape"
150
+ "#{sub_input}.shape"
151
151
  when :exp
152
- "e^#{sub_item})"
152
+ "e^#{sub_input})"
153
153
  when :ones
154
- "ones(#{sub_item})"
154
+ "ones(#{sub_input})"
155
155
  when :ones_like
156
- "ones_like(#{sub_item})"
156
+ "ones_like(#{sub_input})"
157
157
  when :flow_group
158
- "flow_group(#{items.collect { |i| auto_math(i, name_only, max_depth - 1, _cur_depth) }.join(',')})"
158
+ "flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, _cur_depth) }.join(',')})"
159
159
  when :zeros
160
- "zeros(#{sub_item})"
160
+ "zeros(#{sub_input})"
161
161
  when :reshape
162
- "reshape(#{sub_item},#{sub_item2})"
162
+ "reshape(#{sub_input},#{sub_input2})"
163
163
  when :rank
164
- "#{sub_item}.rank"
164
+ "#{sub_input}.rank"
165
165
  when :cond
166
- "(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{sub_item} : #{sub_item2})"
166
+ "(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{sub_input} : #{sub_input2})"
167
167
  when :less
168
- "#{sub_item} < #{sub_item2}"
168
+ "#{sub_input} < #{sub_input2}"
169
169
  when :less_equal
170
- "#{sub_item} <= #{sub_item2}"
170
+ "#{sub_input} <= #{sub_input2}"
171
171
  when :greater
172
- "#{sub_item} > #{sub_item2}"
172
+ "#{sub_input} > #{sub_input2}"
173
173
  when :greater_equal
174
- "#{sub_item} >= #{sub_item2}"
174
+ "#{sub_input} >= #{sub_input2}"
175
175
  when :square
176
- "#{sub_item}\u00B2"
176
+ "#{sub_input}\u00B2"
177
177
  when :log
178
- "log(#{sub_item})"
178
+ "log(#{sub_input})"
179
179
  when :identity
180
- "identity(#{sub_item})"
180
+ "identity(#{sub_input})"
181
181
  when :print
182
- "print(#{sub_item})"
182
+ "print(#{sub_input})"
183
183
  when :pad
184
- "pad(#{sub_item},#{auto_math(options[:paddings])})"
184
+ "pad(#{sub_input},#{auto_math(options[:paddings])})"
185
185
  when :equal
186
- "#{sub_item} == #{sub_item2}"
186
+ "#{sub_input} == #{sub_input2}"
187
187
  when :not_equal
188
- "#{sub_item} != #{sub_item2}"
188
+ "#{sub_input} != #{sub_input2}"
189
189
  when :logical_and
190
- "#{sub_item} && #{sub_item2}"
190
+ "#{sub_input} && #{sub_input2}"
191
191
  when :sqrt
192
- "sqrt(#{sub_item})"
192
+ "sqrt(#{sub_input})"
193
193
  when :log1p
194
- "log1p(#{sub_item})"
194
+ "log1p(#{sub_input})"
195
195
  when :zeros_like
196
- "zeros_like(#{sub_item})"
196
+ "zeros_like(#{sub_input})"
197
197
  when :where
198
- "where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{sub_item}, #{sub_item2})"
198
+ "where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{sub_input}, #{sub_input2})"
199
199
  when :max
200
- "max(#{sub_item},#{sub_item2})"
200
+ "max(#{sub_input},#{sub_input2})"
201
201
  when :cast
202
- "cast(#{sub_item}, #{data_type})"
202
+ "cast(#{sub_input}, #{data_type})"
203
203
  when :broadcast_transform
204
- "broadcast_transform(#{sub_item},#{sub_item2})"
204
+ "broadcast_transform(#{sub_input},#{sub_input2})"
205
205
  when :broadcast_gradient_args
206
- "broadcast_transform(#{sub_item},#{sub_item2})"
206
+ "broadcast_transform(#{sub_input},#{sub_input2})"
207
207
  else
208
- "#{operation}(#{sub_item})" if sub_item
209
- "#{operation}(#{sub_item}, #{sub_item2})" if sub_item && sub_item2
208
+ "#{operation}(#{sub_input})" if sub_input
209
+ "#{operation}(#{sub_input}, #{sub_input2})" if sub_input && sub_input2
210
210
  end
211
211
  ["\n",(_cur_depth + 1).times.collect { ' ' }, out].flatten.join
212
212
  end
@@ -224,47 +224,47 @@ module TensorStream
224
224
  def infer_shape
225
225
  case operation
226
226
  when :index
227
- item_shape = items[0].shape.shape
228
- return nil if item_shape.nil?
229
- return item_shape[1, item_shape.size]
227
+ input_shape = inputs[0].shape.shape
228
+ return nil if input_shape.nil?
229
+ return input_shape[1, input_shape.size]
230
230
  when :mean, :prod, :sum
231
- return [] if items[1].nil?
232
- return nil if items[0].nil?
233
- item_shape = items[0].shape.shape
234
- return nil if item_shape.nil?
235
- return nil if items[1].is_a?(Tensor) && items[1].value.nil?
231
+ return [] if inputs[1].nil?
232
+ return nil if inputs[0].nil?
233
+ input_shape = inputs[0].shape.shape
234
+ return nil if input_shape.nil?
235
+ return nil if inputs[1].is_a?(Tensor) && inputs[1].value.nil?
236
236
 
237
- axis = items[1].is_a?(Tensor) ? items[1].value : items[1]
237
+ axis = inputs[1].is_a?(Tensor) ? inputs[1].value : inputs[1]
238
238
 
239
239
  axis = [ axis ] unless axis.is_a?(Array)
240
- return item_shape.each_with_index.map do |s, index|
240
+ return input_shape.each_with_index.map do |s, index|
241
241
  next nil if axis.include?(index)
242
242
  s
243
243
  end.compact
244
244
  when :reshape
245
- new_shape = items[1] && items[1].value ? items[1].value : nil
245
+ new_shape = inputs[1] && inputs[1].value ? inputs[1].value : nil
246
246
  return nil if new_shape.nil?
247
247
 
248
- item_shape = items[0].shape.shape
249
- return new_shape if item_shape.nil?
248
+ input_shape = inputs[0].shape.shape
249
+ return new_shape if input_shape.nil?
250
250
 
251
- return TensorShape.fix_inferred_elements(new_shape, item_shape.reduce(:*))
251
+ return TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
252
252
  when :flow_group
253
253
  return []
254
254
  when :zeros, :ones
255
- return items[0] ? items[0].value : options[:shape]
255
+ return inputs[0] ? inputs[0].value : options[:shape]
256
256
  when :zeros_like, :ones_like
257
- items[0].shape.shape
257
+ inputs[0].shape.shape
258
258
  when :shape
259
- return items[0].shape.shape ? [items[0].shape.shape.size] : nil
259
+ return inputs[0].shape.shape ? [inputs[0].shape.shape.size] : nil
260
260
  when :matmul
261
- shape1 = items[0].shape.shape.nil? ? nil : items[0].shape.shape[0]
262
- shape2 = items[1].shape.shape.nil? ? nil : items[1].shape.shape[1]
261
+ shape1 = inputs[0].shape.shape.nil? ? nil : inputs[0].shape.shape[0]
262
+ shape2 = inputs[1].shape.shape.nil? ? nil : inputs[1].shape.shape[1]
263
263
  return [shape1, shape2]
264
264
  else
265
- return items[0].shape.shape if items.size == 1
266
- if items.size == 2 && items[0] && items[1]
267
- return TensorShape.infer_shape(items[0].shape.shape, items[1].shape.shape)
265
+ return inputs[0].shape.shape if inputs.size == 1
266
+ if inputs.size == 2 && inputs[0] && inputs[1]
267
+ return TensorShape.infer_shape(inputs[0].shape.shape, inputs[1].shape.shape)
268
268
  end
269
269
  end
270
270
 
@@ -273,14 +273,14 @@ module TensorStream
273
273
 
274
274
  def propagate_consumer(consumer)
275
275
  super
276
- @items.compact.each do |item|
277
- item.send(:propagate_consumer, consumer) if item.name != name
276
+ @inputs.compact.each do |input|
277
+ input.send(:propagate_consumer, consumer) if input.name != name
278
278
  end
279
279
  end
280
280
 
281
281
  def propagate_outputs
282
- @items.compact.each do |item|
283
- item.send(:setup_output, self) if item.name != self.name
282
+ @inputs.compact.each do |input|
283
+ input.send(:setup_output, self) if input.name != self.name
284
284
  end
285
285
  end
286
286