tensor_stream 0.4.1 → 0.5.0

Files changed (62)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +38 -17
  4. data/benchmark/benchmark.rb +16 -20
  5. data/lib/tensor_stream/control_flow.rb +3 -3
  6. data/lib/tensor_stream/debugging/debugging.rb +4 -4
  7. data/lib/tensor_stream/device.rb +5 -2
  8. data/lib/tensor_stream/evaluator/base_evaluator.rb +138 -0
  9. data/lib/tensor_stream/evaluator/buffer.rb +7 -2
  10. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_bool_operand.cl +3 -3
  11. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_operand.cl +0 -0
  12. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/abs.cl +0 -0
  13. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/add.cl +1 -1
  14. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmax.cl +0 -0
  15. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmin.cl +0 -0
  16. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cast.cl +0 -0
  17. data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +6 -0
  18. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cos.cl +0 -0
  19. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/div.cl.erb +1 -1
  20. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/exp.cl +0 -0
  21. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/gemm.cl +0 -0
  22. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log.cl +0 -0
  23. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log1p.cl +0 -0
  24. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/max.cl +3 -3
  25. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/mul.cl +1 -1
  26. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/negate.cl +0 -0
  27. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/pow.cl +3 -3
  28. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/reciprocal.cl +0 -0
  29. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/round.cl +0 -0
  30. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid.cl +0 -0
  31. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid_grad.cl +3 -3
  32. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sign.cl +1 -1
  33. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sin.cl +0 -0
  34. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax.cl +0 -0
  35. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax_grad.cl +0 -0
  36. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sqrt.cl +0 -0
  37. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/square.cl +0 -0
  38. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sub.cl +1 -1
  39. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tan.cl +0 -0
  40. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh.cl +0 -0
  41. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh_grad.cl +0 -0
  42. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/where.cl +1 -1
  43. data/lib/tensor_stream/evaluator/{opencl_buffer.rb → opencl/opencl_buffer.rb} +1 -1
  44. data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +5 -0
  45. data/lib/tensor_stream/evaluator/{opencl_evaluator.rb → opencl/opencl_evaluator.rb} +404 -452
  46. data/lib/tensor_stream/evaluator/{opencl_template_helper.rb → opencl/opencl_template_helper.rb} +6 -6
  47. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +21 -21
  48. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +492 -398
  49. data/lib/tensor_stream/graph.rb +21 -1
  50. data/lib/tensor_stream/graph_serializers/graphml.rb +59 -59
  51. data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -1
  52. data/lib/tensor_stream/helpers/op_helper.rb +6 -2
  53. data/lib/tensor_stream/math_gradients.rb +7 -7
  54. data/lib/tensor_stream/operation.rb +100 -100
  55. data/lib/tensor_stream/session.rb +81 -8
  56. data/lib/tensor_stream/tensor.rb +7 -5
  57. data/lib/tensor_stream/utils.rb +32 -19
  58. data/lib/tensor_stream/version.rb +1 -1
  59. data/tensor_stream.gemspec +0 -1
  60. data/test_samples/raw_neural_net_sample.rb +7 -7
  61. metadata +41 -53
  62. data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +0 -5

data/lib/tensor_stream/graph.rb
@@ -43,6 +43,19 @@ module TensorStream
       end
     end
 
+    ##
+    # Returns a context manager that specifies the default device to use.
+    def device(device_name)
+      Thread.current["ts_graph_#{object_id}"] ||= {}
+      Thread.current["ts_graph_#{object_id}"][:default_device] ||= []
+      Thread.current["ts_graph_#{object_id}"][:default_device] << device_name
+      begin
+        yield
+      ensure
+        Thread.current["ts_graph_#{object_id}"][:default_device].pop
+      end
+    end
+
     def self.get_default_graph
       Thread.current[:tensor_stream_current_graph] || create_default
     end
@@ -69,6 +82,7 @@ module TensorStream
         node.name
       end
 
+      node.device = get_device_scope
       @nodes[node.name] = node
       @constants[node.name] = node if node.is_const
       node.send(:propagate_outputs)
@@ -159,11 +173,17 @@ module TensorStream
 
     def get_name_scope
       graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
-      return nil if graph_thread_storage.nil?
+      return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?
 
       graph_thread_storage[:current_scope].join('/')
     end
 
+    def get_device_scope
+      graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
+      return :default if graph_thread_storage.nil? || graph_thread_storage[:default_device].nil?
+      graph_thread_storage[:default_device].last
+    end
+
     def as_graph_def
       TensorStream::Pbtext.new.get_string(self)
     end
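
Taken together, the graph.rb hunks add device scoping: Graph#device pushes a device name onto a thread-local stack and pops it when the block exits, and add_node (second hunk) stamps every new node via get_device_scope. A minimal usage sketch against the public TensorStream API; the device string '/cpu:0' is illustrative only, since valid names depend on the evaluator in use:

    require 'tensor_stream'

    graph = TensorStream.get_default_graph

    sum = nil
    graph.device('/cpu:0') do
      a = TensorStream.constant(3.0)
      b = TensorStream.constant(4.0)
      sum = a + b                    # nodes built here get device '/cpu:0' via add_node
    end

    c = TensorStream.constant(1.0)   # outside the block the device scope falls back to :default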

data/lib/tensor_stream/graph_serializers/graphml.rb
@@ -134,73 +134,73 @@ module TensorStream
         add_to_group(groups, "program/#{tensor.name}", node_buf)
       end
 
-      tensor.items.each do |item|
-        next unless item
-        next if added[item.name]
-
-        next to_graph_ml(item, arr_buf, added, groups) if item.is_a?(Operation)
-
-        added[item.name] = true
-        item_buf = []
-        if item.is_a?(Variable)
-          item_buf << "<node id=\"#{_gml_string(item.name)}\">"
-          item_buf << "<data key=\"d0\">#{item.name}</data>"
-          item_buf << "<data key=\"d2\">green</data>"
-          if @last_session_context[item.name]
-            item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+      tensor.inputs.each do |input|
+        next unless input
+        next if added[input.name]
+
+        next to_graph_ml(input, arr_buf, added, groups) if input.is_a?(Operation)
+
+        added[input.name] = true
+        input_buf = []
+        if input.is_a?(Variable)
+          input_buf << "<node id=\"#{_gml_string(input.name)}\">"
+          input_buf << "<data key=\"d0\">#{input.name}</data>"
+          input_buf << "<data key=\"d2\">green</data>"
+          if @last_session_context[input.name]
+            input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
           end
-          item_buf << "<data key=\"d9\">"
-          item_buf << "<y:ShapeNode>"
-          item_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
-          item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
-          item_buf << "</y:ShapeNode>"
-          item_buf << "</data>"
-          item_buf << "</node>"
-        elsif item.is_a?(Placeholder)
-          item_buf << "<node id=\"#{_gml_string(item.name)}\">"
-          item_buf << "<data key=\"d9\">"
-          item_buf << "<y:ShapeNode>"
-          item_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
-          item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
-          item_buf << "</y:ShapeNode>"
-          item_buf << "</data>"
-          if @last_session_context[item.name]
-            item_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
+          input_buf << "<data key=\"d9\">"
+          input_buf << "<y:ShapeNode>"
+          input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
+          input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
+          input_buf << "</y:ShapeNode>"
+          input_buf << "</data>"
+          input_buf << "</node>"
+        elsif input.is_a?(Placeholder)
+          input_buf << "<node id=\"#{_gml_string(input.name)}\">"
+          input_buf << "<data key=\"d9\">"
+          input_buf << "<y:ShapeNode>"
+          input_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
+          input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
+          input_buf << "</y:ShapeNode>"
+          input_buf << "</data>"
+          if @last_session_context[input.name]
+            input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
           end
-          item_buf << "</node>"
-        elsif item.is_a?(Tensor)
-          item_buf << "<node id=\"#{_gml_string(item.name)}\">"
-          item_buf << "<data key=\"d0\">#{item.name}</data>"
-          item_buf << "<data key=\"d2\">black</data>"
-          item_buf << "<data key=\"d9\">"
-          item_buf << "<y:ShapeNode>"
-
-          if item.internal?
-            item_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
+          input_buf << "</node>"
+        elsif input.is_a?(Tensor)
+          input_buf << "<node id=\"#{_gml_string(input.name)}\">"
+          input_buf << "<data key=\"d0\">#{input.name}</data>"
+          input_buf << "<data key=\"d2\">black</data>"
+          input_buf << "<data key=\"d9\">"
+          input_buf << "<y:ShapeNode>"
+
+          if input.internal?
+            input_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
           else
-            item_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
+            input_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
           end
 
 
-          item_buf << " <y:NodeLabel alignment=\"center\">#{item.name}</y:NodeLabel>"
+          input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
 
-          item_buf << "</y:ShapeNode>"
-          item_buf << "</data>"
-          item_buf << "</node>"
+          input_buf << "</y:ShapeNode>"
+          input_buf << "</data>"
+          input_buf << "</node>"
         end
 
-        if !add_to_group(groups, item.name, item_buf)
-          if item.is_a?(Variable)
-            add_to_group(groups, "variable/#{item.name}", item_buf)
+        if !add_to_group(groups, input.name, input_buf)
+          if input.is_a?(Variable)
+            add_to_group(groups, "variable/#{input.name}", input_buf)
           else
-            add_to_group(groups, "program/#{item.name}", item_buf)
+            add_to_group(groups, "program/#{input.name}", input_buf)
           end
         end
       end
 
-      tensor.items.each_with_index do |item, index|
-        next unless item
-        output_edge(item, tensor, arr_buf, index)
+      tensor.inputs.each_with_index do |input, index|
+        next unless input
+        output_edge(input, tensor, arr_buf, index)
       end
     end
 
@@ -208,20 +208,20 @@ module TensorStream
       str.gsub('/','-')
     end
 
-    def output_edge(item, tensor, arr_buf, index = 0)
+    def output_edge(input, tensor, arr_buf, index = 0)
       target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
-      arr_buf << "<edge source=\"#{_gml_string(item.name)}\" target=\"#{_gml_string(target_name)}\">"
+      arr_buf << "<edge source=\"#{_gml_string(input.name)}\" target=\"#{_gml_string(target_name)}\">"
       arr_buf << "<data key=\"d13\">"
 
       arr_buf << "<y:PolyLineEdge>"
       arr_buf << "<y:EdgeLabel >"
       if !@last_session_context.empty?
-        arr_buf << "<![CDATA[ #{_val(item)} ]]>"
+        arr_buf << "<![CDATA[ #{_val(input)} ]]>"
       else
-        if item.shape.shape.nil?
-          arr_buf << "<![CDATA[ #{item.data_type.to_s} ? ]]>"
+        if input.shape.shape.nil?
+          arr_buf << "<![CDATA[ #{input.data_type.to_s} ? ]]>"
         else
-          arr_buf << "<![CDATA[ #{item.data_type.to_s} #{item.shape.shape.empty? ? 'scalar' : item.shape.shape.to_json} ]]>"
+          arr_buf << "<![CDATA[ #{input.data_type.to_s} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
         end
       end
       arr_buf << "</y:EdgeLabel >"

data/lib/tensor_stream/graph_serializers/pbtext.rb
@@ -11,7 +11,7 @@ module TensorStream
         @lines << " name: #{node.name.to_json}"
         if node.is_a?(TensorStream::Operation)
           @lines << " op: #{camelize(node.operation.to_s).to_json}"
-          node.items.each do |input|
+          node.inputs.each do |input|
             next unless input
             @lines << " input: #{input.name.to_json}"
           end

data/lib/tensor_stream/helpers/op_helper.rb
@@ -33,8 +33,12 @@ module TensorStream
       arr
     end
 
-    def dtype_eval(rank, value)
-      dtype = Tensor.detect_type(value[0])
+    def dtype_eval(rank, value, data_type = nil)
+      dtype = if data_type.nil?
+                Tensor.detect_type(value[0])
+              else
+                data_type
+              end
 
       rank += 1 if dtype == :array
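
The new optional data_type argument lets callers pin the element type instead of inferring it from the first value. A rough sketch of the intent, assuming the helper is mixed in from TensorStream::OpHelper (the calls and literals below are illustrative, not taken from the diff):

    include TensorStream::OpHelper

    dtype_eval(1, [1.0, 2.0])            # dtype inferred via Tensor.detect_type(1.0)
    dtype_eval(1, [1.0, 2.0], :float64)  # detection skipped, dtype forced to :float64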
 

data/lib/tensor_stream/math_gradients.rb
@@ -38,22 +38,22 @@ module TensorStream
         computed_op.each_with_index do |op_grad, index|
           next if op_grad.nil?
 
-          if nodes_to_compute.include?(tensor.items[index].name)
-            partials << _propagate(op_grad, tensor.items[index], stop_tensor, nodes_to_compute, stop_gradients)
+          if nodes_to_compute.include?(tensor.inputs[index].name)
+            partials << _propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
           end
         end
 
         partials.reduce(:+)
       else
         return tf.zeros_like(stop_tensor) if computed_op.nil?
-        _propagate(computed_op, tensor.items[0], stop_tensor, nodes_to_compute, stop_gradients)
+        _propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
       end
     end
 
     def self._compute_derivative(node, grad)
       node.graph.name_scope("#{node.name}_grad") do
-        x = node.items[0] if node.items[0]
-        y = node.items[1] if node.items[1]
+        x = node.inputs[0] if node.inputs[0]
+        y = node.inputs[1] if node.inputs[1]
 
         case node.operation
         when :add
@@ -221,8 +221,8 @@ module TensorStream
 
     def self._min_or_max_grad(op, grad)
       y = op
-      indicators = tf.cast(tf.equal(y, op.items[0]), grad.data_type)
-      num_selected = tf.reduce_sum(indicators, op.items[1])
+      indicators = tf.cast(tf.equal(y, op.inputs[0]), grad.data_type)
+      num_selected = tf.reduce_sum(indicators, op.inputs[1])
       _safe_shape_div(indicators, num_selected) * grad
     end
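
Only the items-to-inputs rename changes here, but the logic is worth spelling out: the gradient of a min/max reduction flows only to the elements equal to the reduced result and is split evenly among ties. A plain-Ruby sketch of that arithmetic for a single reduced row (illustrative numbers, scalars standing in for tensors):

    x = [3.0, 7.0, 7.0]                                 # values along the reduced axis
    y = x.max                                           # 7.0, the op's output
    indicators = x.map { |v| v == y ? 1.0 : 0.0 }       # tf.cast(tf.equal(y, op.inputs[0]), ...) => [0.0, 1.0, 1.0]
    num_selected = indicators.sum                       # tf.reduce_sum(indicators, op.inputs[1]) => 2.0
    grad = 1.0                                          # incoming gradient
    dx = indicators.map { |i| i / num_selected * grad } # => [0.0, 0.5, 0.5]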
 

data/lib/tensor_stream/operation.rb
@@ -1,7 +1,7 @@
 module TensorStream
   # TensorStream class that defines an operation
   class Operation < Tensor
-    attr_accessor :name, :operation, :items, :rank, :options
+    attr_accessor :name, :operation, :inputs, :rank, :options
     attr_reader :outputs
 
     def initialize(operation, input_a, input_b, options = {})
@@ -15,7 +15,7 @@ module TensorStream
 
       @options = options
 
-      @items = [input_a, input_b].map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
+      @inputs = [input_a, input_b].map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
       @data_type = set_data_type(options[:data_type])
       @is_const = infer_const
       @shape = TensorShape.new(infer_shape)
@@ -30,16 +30,16 @@ module TensorStream
       {
         op: operation,
         name: name,
-        operands: hashify_tensor(items)
+        operands: hashify_tensor(inputs)
       }
     end
 
     def self.empty_matrix?(input)
       if input.is_a?(Array)
-        input.each do |item|
-          if item.is_a?(Array)
-            return false unless empty_matrix?(item)
-          elsif item != 0 || item != 0.0
+        input.each do |input|
+          if input.is_a?(Array)
+            return false unless empty_matrix?(input)
+          elsif input != 0 || input != 0.0
             return false
           end
         end
@@ -54,7 +54,7 @@ module TensorStream
       when :random_normal, :random_uniform, :glorot_uniform, :print
         false
       else
-        non_const = @items.compact.find { |item| !item.is_const }
+        non_const = @inputs.compact.find { |input| !input.is_const }
         non_const ? false : true
       end
     end
@@ -68,23 +68,23 @@ module TensorStream
       when :random_normal, :random_uniform, :glorot_uniform
         passed_data_type || :float32
       when :index
-        if @items[0].is_a?(ControlFlow)
+        if @inputs[0].is_a?(ControlFlow)
 
-          if @items[1].is_const
-            @items[0].items[@items[1].value].data_type
+          if @inputs[1].is_const
+            @inputs[0].inputs[@inputs[1].value].data_type
           else
             :unknown
           end
         else
-          @items[0].data_type
+          @inputs[0].data_type
         end
       else
         return passed_data_type if passed_data_type
 
-        if @items[0]
-          @items[0].data_type
-        elsif @items[1]
-          @items[1].data_type
+        if @inputs[0]
+          @inputs[0].data_type
+        elsif @inputs[1]
+          @inputs[1].data_type
         else
           :unknown
         end
@@ -94,119 +94,119 @@ module TensorStream
     def to_math(name_only = false, max_depth = 99, _cur_depth = 0)
       return @name if max_depth.zero?
 
-      sub_item = auto_math(items[0], name_only, max_depth - 1, _cur_depth + 1)
-      sub_item2 = auto_math(items[1], name_only, max_depth - 1, _cur_depth + 1) if items[1]
+      sub_input = auto_math(inputs[0], name_only, max_depth - 1, _cur_depth + 1)
+      sub_input2 = auto_math(inputs[1], name_only, max_depth - 1, _cur_depth + 1) if inputs[1]
 
       out = case operation
             when :argmax
-              "argmax(#{sub_item},#{options[:axis]})"
+              "argmax(#{sub_input},#{options[:axis]})"
             when :negate
-              "-#{sub_item}"
+              "-#{sub_input}"
             when :index
-              "#{sub_item}[#{sub_item2}]"
+              "#{sub_input}[#{sub_input2}]"
             when :slice
-              "#{sub_item}[#{sub_item2}]"
+              "#{sub_input}[#{sub_input2}]"
             when :assign_sub
-              "(#{items[0] ? items[0].name : 'self'} -= #{auto_math(items[1], name_only, 1)})"
+              "(#{inputs[0] ? inputs[0].name : 'self'} -= #{auto_math(inputs[1], name_only, 1)})"
             when :assign_add
-              "(#{items[0] ? items[0].name : 'self'} += #{auto_math(items[1], name_only, 1)})"
+              "(#{inputs[0] ? inputs[0].name : 'self'} += #{auto_math(inputs[1], name_only, 1)})"
             when :assign
-              "(#{items[0] ? items[0].name : 'self'} = #{auto_math(items[1], name_only, 1)})"
+              "(#{inputs[0] ? inputs[0].name : 'self'} = #{auto_math(inputs[1], name_only, 1)})"
             when :sin, :cos, :tanh
-              "#{operation}(#{sub_item})"
+              "#{operation}(#{sub_input})"
             when :add
-              "(#{sub_item} + #{sub_item2})"
+              "(#{sub_input} + #{sub_input2})"
             when :sub
-              "(#{sub_item} - #{sub_item2})"
+              "(#{sub_input} - #{sub_input2})"
             when :pow
-              "(#{sub_item}^#{sub_item2})"
+              "(#{sub_input}^#{sub_input2})"
             when :div
-              "(#{sub_item} / #{sub_item2})"
+              "(#{sub_input} / #{sub_input2})"
             when :mul
-              if auto_math(items[0]) == 1
-                sub_item2
-              elsif auto_math(items[1]) == 1
-                sub_item
+              if auto_math(inputs[0]) == 1
+                sub_input2
+              elsif auto_math(inputs[1]) == 1
+                sub_input
               else
-                "(#{sub_item} * #{sub_item2})"
+                "(#{sub_input} * #{sub_input2})"
              end
             when :sum
-              "sum(|#{sub_item}|, axis=#{sub_item2})"
+              "sum(|#{sub_input}|, axis=#{sub_input2})"
            when :mean
-              "mean(|#{sub_item}|, axis=#{sub_item2})"
+              "mean(|#{sub_input}|, axis=#{sub_input2})"
            when :prod
-              "prod(|#{sub_item}|, axis=#{sub_item2})"
+              "prod(|#{sub_input}|, axis=#{sub_input2})"
            when :gradients
-              "gradient(#{sub_item})"
+              "gradient(#{sub_input})"
            when :stop_gradient
-              sub_item
+              sub_input
            when :matmul
-              "#{sub_item}.matmul(#{sub_item2})"
+              "#{sub_input}.matmul(#{sub_input2})"
            when :eye
-              "eye(#{sub_item})"
+              "eye(#{sub_input})"
            when :transpose
-              "transpose(#{sub_item})"
+              "transpose(#{sub_input})"
            when :shape
-              "#{sub_item}.shape"
+              "#{sub_input}.shape"
            when :exp
-              "e^#{sub_item})"
+              "e^#{sub_input})"
            when :ones
-              "ones(#{sub_item})"
+              "ones(#{sub_input})"
            when :ones_like
-              "ones_like(#{sub_item})"
+              "ones_like(#{sub_input})"
            when :flow_group
-              "flow_group(#{items.collect { |i| auto_math(i, name_only, max_depth - 1, _cur_depth) }.join(',')})"
+              "flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, _cur_depth) }.join(',')})"
            when :zeros
-              "zeros(#{sub_item})"
+              "zeros(#{sub_input})"
            when :reshape
-              "reshape(#{sub_item},#{sub_item2})"
+              "reshape(#{sub_input},#{sub_input2})"
            when :rank
-              "#{sub_item}.rank"
+              "#{sub_input}.rank"
            when :cond
-              "(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{sub_item} : #{sub_item2})"
+              "(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{sub_input} : #{sub_input2})"
            when :less
-              "#{sub_item} < #{sub_item2}"
+              "#{sub_input} < #{sub_input2}"
            when :less_equal
-              "#{sub_item} <= #{sub_item2}"
+              "#{sub_input} <= #{sub_input2}"
            when :greater
-              "#{sub_item} > #{sub_item2}"
+              "#{sub_input} > #{sub_input2}"
            when :greater_equal
-              "#{sub_item} >= #{sub_item2}"
+              "#{sub_input} >= #{sub_input2}"
            when :square
-              "#{sub_item}\u00B2"
+              "#{sub_input}\u00B2"
            when :log
-              "log(#{sub_item})"
+              "log(#{sub_input})"
            when :identity
-              "identity(#{sub_item})"
+              "identity(#{sub_input})"
            when :print
-              "print(#{sub_item})"
+              "print(#{sub_input})"
            when :pad
-              "pad(#{sub_item},#{auto_math(options[:paddings])})"
+              "pad(#{sub_input},#{auto_math(options[:paddings])})"
            when :equal
-              "#{sub_item} == #{sub_item2}"
+              "#{sub_input} == #{sub_input2}"
            when :not_equal
-              "#{sub_item} != #{sub_item2}"
+              "#{sub_input} != #{sub_input2}"
            when :logical_and
-              "#{sub_item} && #{sub_item2}"
+              "#{sub_input} && #{sub_input2}"
            when :sqrt
-              "sqrt(#{sub_item})"
+              "sqrt(#{sub_input})"
            when :log1p
-              "log1p(#{sub_item})"
+              "log1p(#{sub_input})"
            when :zeros_like
-              "zeros_like(#{sub_item})"
+              "zeros_like(#{sub_input})"
            when :where
-              "where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{sub_item}, #{sub_item2})"
+              "where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{sub_input}, #{sub_input2})"
            when :max
-              "max(#{sub_item},#{sub_item2})"
+              "max(#{sub_input},#{sub_input2})"
            when :cast
-              "cast(#{sub_item}, #{data_type})"
+              "cast(#{sub_input}, #{data_type})"
            when :broadcast_transform
-              "broadcast_transform(#{sub_item},#{sub_item2})"
+              "broadcast_transform(#{sub_input},#{sub_input2})"
            when :broadcast_gradient_args
-              "broadcast_transform(#{sub_item},#{sub_item2})"
+              "broadcast_transform(#{sub_input},#{sub_input2})"
            else
-              "#{operation}(#{sub_item})" if sub_item
-              "#{operation}(#{sub_item}, #{sub_item2})" if sub_item && sub_item2
+              "#{operation}(#{sub_input})" if sub_input
+              "#{operation}(#{sub_input}, #{sub_input2})" if sub_input && sub_input2
            end
       ["\n",(_cur_depth + 1).times.collect { ' ' }, out].flatten.join
     end
@@ -224,47 +224,47 @@ module TensorStream
     def infer_shape
       case operation
       when :index
-        item_shape = items[0].shape.shape
-        return nil if item_shape.nil?
-        return item_shape[1, item_shape.size]
+        input_shape = inputs[0].shape.shape
+        return nil if input_shape.nil?
+        return input_shape[1, input_shape.size]
       when :mean, :prod, :sum
-        return [] if items[1].nil?
-        return nil if items[0].nil?
-        item_shape = items[0].shape.shape
-        return nil if item_shape.nil?
-        return nil if items[1].is_a?(Tensor) && items[1].value.nil?
+        return [] if inputs[1].nil?
+        return nil if inputs[0].nil?
+        input_shape = inputs[0].shape.shape
+        return nil if input_shape.nil?
+        return nil if inputs[1].is_a?(Tensor) && inputs[1].value.nil?
 
-        axis = items[1].is_a?(Tensor) ? items[1].value : items[1]
+        axis = inputs[1].is_a?(Tensor) ? inputs[1].value : inputs[1]
 
         axis = [ axis ] unless axis.is_a?(Array)
-        return item_shape.each_with_index.map do |s, index|
+        return input_shape.each_with_index.map do |s, index|
           next nil if axis.include?(index)
           s
         end.compact
       when :reshape
-        new_shape = items[1] && items[1].value ? items[1].value : nil
+        new_shape = inputs[1] && inputs[1].value ? inputs[1].value : nil
         return nil if new_shape.nil?
 
-        item_shape = items[0].shape.shape
-        return new_shape if item_shape.nil?
+        input_shape = inputs[0].shape.shape
+        return new_shape if input_shape.nil?
 
-        return TensorShape.fix_inferred_elements(new_shape, item_shape.reduce(:*))
+        return TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
       when :flow_group
         return []
       when :zeros, :ones
-        return items[0] ? items[0].value : options[:shape]
+        return inputs[0] ? inputs[0].value : options[:shape]
       when :zeros_like, :ones_like
-        items[0].shape.shape
+        inputs[0].shape.shape
       when :shape
-        return items[0].shape.shape ? [items[0].shape.shape.size] : nil
+        return inputs[0].shape.shape ? [inputs[0].shape.shape.size] : nil
       when :matmul
-        shape1 = items[0].shape.shape.nil? ? nil : items[0].shape.shape[0]
-        shape2 = items[1].shape.shape.nil? ? nil : items[1].shape.shape[1]
+        shape1 = inputs[0].shape.shape.nil? ? nil : inputs[0].shape.shape[0]
+        shape2 = inputs[1].shape.shape.nil? ? nil : inputs[1].shape.shape[1]
         return [shape1, shape2]
       else
-        return items[0].shape.shape if items.size == 1
-        if items.size == 2 && items[0] && items[1]
-          return TensorShape.infer_shape(items[0].shape.shape, items[1].shape.shape)
+        return inputs[0].shape.shape if inputs.size == 1
+        if inputs.size == 2 && inputs[0] && inputs[1]
+          return TensorShape.infer_shape(inputs[0].shape.shape, inputs[1].shape.shape)
         end
       end
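
As a concrete reading of the matmul branch above: the inferred shape keeps the row count of the first input and the column count of the second, and degrades to nil wherever a side's shape is unknown. Illustrative values, not taken from the diff:

    # inputs[0].shape.shape == [2, 3], inputs[1].shape.shape == [3, 4]
    shape1 = [2, 3][0]   # => 2
    shape2 = [3, 4][1]   # => 4
    [shape1, shape2]     # inferred matmul shape => [2, 4]
    # if inputs[1].shape.shape were nil, shape2 would be nil and the result [2, nil]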
 
@@ -273,14 +273,14 @@ module TensorStream
 
     def propagate_consumer(consumer)
       super
-      @items.compact.each do |item|
-        item.send(:propagate_consumer, consumer) if item.name != name
+      @inputs.compact.each do |input|
+        input.send(:propagate_consumer, consumer) if input.name != name
       end
     end
 
     def propagate_outputs
-      @items.compact.each do |item|
-        item.send(:setup_output, self) if item.name != self.name
+      @inputs.compact.each do |input|
+        input.send(:setup_output, self) if input.name != self.name
       end
     end