tensor_stream 0.4.1 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +38 -17
- data/benchmark/benchmark.rb +16 -20
- data/lib/tensor_stream/control_flow.rb +3 -3
- data/lib/tensor_stream/debugging/debugging.rb +4 -4
- data/lib/tensor_stream/device.rb +5 -2
- data/lib/tensor_stream/evaluator/base_evaluator.rb +138 -0
- data/lib/tensor_stream/evaluator/buffer.rb +7 -2
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_bool_operand.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_operand.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/abs.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/add.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmax.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmin.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cast.cl +0 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +6 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cos.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/div.cl.erb +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/exp.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/gemm.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log1p.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/max.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/mul.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/negate.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/pow.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/reciprocal.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/round.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid_grad.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sign.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sin.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax_grad.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sqrt.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/square.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sub.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tan.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh_grad.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/where.cl +1 -1
- data/lib/tensor_stream/evaluator/{opencl_buffer.rb → opencl/opencl_buffer.rb} +1 -1
- data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +5 -0
- data/lib/tensor_stream/evaluator/{opencl_evaluator.rb → opencl/opencl_evaluator.rb} +404 -452
- data/lib/tensor_stream/evaluator/{opencl_template_helper.rb → opencl/opencl_template_helper.rb} +6 -6
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +21 -21
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +492 -398
- data/lib/tensor_stream/graph.rb +21 -1
- data/lib/tensor_stream/graph_serializers/graphml.rb +59 -59
- data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -1
- data/lib/tensor_stream/helpers/op_helper.rb +6 -2
- data/lib/tensor_stream/math_gradients.rb +7 -7
- data/lib/tensor_stream/operation.rb +100 -100
- data/lib/tensor_stream/session.rb +81 -8
- data/lib/tensor_stream/tensor.rb +7 -5
- data/lib/tensor_stream/utils.rb +32 -19
- data/lib/tensor_stream/version.rb +1 -1
- data/tensor_stream.gemspec +0 -1
- data/test_samples/raw_neural_net_sample.rb +7 -7
- metadata +41 -53
- data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +0 -5
data/lib/tensor_stream/graph.rb
CHANGED
```diff
@@ -43,6 +43,19 @@ module TensorStream
       end
     end
 
+    ##
+    # Returns a context manager that specifies the default device to use.
+    def device(device_name)
+      Thread.current["ts_graph_#{object_id}"] ||= {}
+      Thread.current["ts_graph_#{object_id}"][:default_device] ||= []
+      Thread.current["ts_graph_#{object_id}"][:default_device] << device_name
+      begin
+        yield
+      ensure
+        Thread.current["ts_graph_#{object_id}"][:default_device].pop
+      end
+    end
+
     def self.get_default_graph
       Thread.current[:tensor_stream_current_graph] || create_default
     end
@@ -69,6 +82,7 @@ module TensorStream
         node.name
       end
 
+      node.device = get_device_scope
       @nodes[node.name] = node
       @constants[node.name] = node if node.is_const
       node.send(:propagate_outputs)
@@ -159,11 +173,17 @@ module TensorStream
 
     def get_name_scope
       graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
-      return nil if graph_thread_storage.nil?
+      return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?
 
       graph_thread_storage[:current_scope].join('/')
     end
 
+    def get_device_scope
+      graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
+      return :default if graph_thread_storage.nil? || graph_thread_storage[:default_device].nil?
+      graph_thread_storage[:default_device].last
+    end
+
     def as_graph_def
       TensorStream::Pbtext.new.get_string(self)
     end
```
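The new `Graph#device` block pushes a device name onto a per-thread stack and pops it when the block exits, and `add_node` now stamps each node with `get_device_scope`, so only operations created inside the block pick up the device. A minimal usage sketch (the device string below is a placeholder; check which identifiers the bundled evaluators accept):

```ruby
require 'tensor_stream'

graph = TensorStream::Graph.get_default_graph

# Nodes built inside the block receive node.device = 'cpu_0' (placeholder name);
# nodes built outside fall back to :default via get_device_scope.
graph.device('cpu_0') do
  a = TensorStream.constant(1.0)
  b = TensorStream.constant(2.0)
  c = a + b
end
```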
data/lib/tensor_stream/graph_serializers/graphml.rb
CHANGED
```diff
@@ -134,73 +134,73 @@ module TensorStream
         add_to_group(groups, "program/#{tensor.name}", node_buf)
       end
 
-      tensor.
-      next unless
-      next if added[
-
-      next to_graph_ml(
-
-      added[
-
-      if
-
-
-
-      if @last_session_context[
-
+      tensor.inputs.each do |input|
+        next unless input
+        next if added[input.name]
+
+        next to_graph_ml(input, arr_buf, added, groups) if input.is_a?(Operation)
+
+        added[input.name] = true
+        input_buf = []
+        if input.is_a?(Variable)
+          input_buf << "<node id=\"#{_gml_string(input.name)}\">"
+          input_buf << "<data key=\"d0\">#{input.name}</data>"
+          input_buf << "<data key=\"d2\">green</data>"
+          if @last_session_context[input.name]
+            input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
           end
-
-
-
-
-
-
-
-        elsif
-
-
-
-
-
-
-
-          if @last_session_context[
-
+          input_buf << "<data key=\"d9\">"
+          input_buf << "<y:ShapeNode>"
+          input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
+          input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
+          input_buf << "</y:ShapeNode>"
+          input_buf << "</data>"
+          input_buf << "</node>"
+        elsif input.is_a?(Placeholder)
+          input_buf << "<node id=\"#{_gml_string(input.name)}\">"
+          input_buf << "<data key=\"d9\">"
+          input_buf << "<y:ShapeNode>"
+          input_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
+          input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
+          input_buf << "</y:ShapeNode>"
+          input_buf << "</data>"
+          if @last_session_context[input.name]
+            input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
           end
-
-        elsif
-
-
-
-
-
-
-          if
-
+          input_buf << "</node>"
+        elsif input.is_a?(Tensor)
+          input_buf << "<node id=\"#{_gml_string(input.name)}\">"
+          input_buf << "<data key=\"d0\">#{input.name}</data>"
+          input_buf << "<data key=\"d2\">black</data>"
+          input_buf << "<data key=\"d9\">"
+          input_buf << "<y:ShapeNode>"
+
+          if input.internal?
+            input_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
           else
-
+            input_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
           end
 
 
-
+          input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
 
-
-
-
+          input_buf << "</y:ShapeNode>"
+          input_buf << "</data>"
+          input_buf << "</node>"
         end
 
-        if !add_to_group(groups,
-          if
-            add_to_group(groups, "variable/#{
+        if !add_to_group(groups, input.name, input_buf)
+          if input.is_a?(Variable)
+            add_to_group(groups, "variable/#{input.name}", input_buf)
           else
-            add_to_group(groups, "program/#{
+            add_to_group(groups, "program/#{input.name}", input_buf)
           end
         end
       end
 
-      tensor.
-      next unless
-      output_edge(
+      tensor.inputs.each_with_index do |input, index|
+        next unless input
+        output_edge(input, tensor, arr_buf, index)
       end
     end
 
@@ -208,20 +208,20 @@ module TensorStream
       str.gsub('/','-')
     end
 
-    def output_edge(
+    def output_edge(input, tensor, arr_buf, index = 0)
       target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
-      arr_buf << "<edge source=\"#{_gml_string(
+      arr_buf << "<edge source=\"#{_gml_string(input.name)}\" target=\"#{_gml_string(target_name)}\">"
       arr_buf << "<data key=\"d13\">"
 
       arr_buf << "<y:PolyLineEdge>"
       arr_buf << "<y:EdgeLabel >"
       if !@last_session_context.empty?
-        arr_buf << "<![CDATA[ #{_val(
+        arr_buf << "<![CDATA[ #{_val(input)} ]]>"
       else
-        if
-          arr_buf << "<![CDATA[ #{
+        if input.shape.shape.nil?
+          arr_buf << "<![CDATA[ #{input.data_type.to_s} ? ]]>"
        else
-          arr_buf << "<![CDATA[ #{
+          arr_buf << "<![CDATA[ #{input.data_type.to_s} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
        end
      end
      arr_buf << "</y:EdgeLabel >"
```
data/lib/tensor_stream/graph_serializers/pbtext.rb
CHANGED
```diff
@@ -11,7 +11,7 @@ module TensorStream
       @lines << " name: #{node.name.to_json}"
       if node.is_a?(TensorStream::Operation)
         @lines << " op: #{camelize(node.operation.to_s).to_json}"
-        node.
+        node.inputs.each do |input|
           next unless input
           @lines << " input: #{input.name.to_json}"
         end
```
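The pbtext serializer now walks `node.inputs` when listing an operation's inputs. Its output is reachable through `Graph#as_graph_def`, which (per the graph.rb hunk above) delegates to `TensorStream::Pbtext#get_string`; a quick sketch:

```ruby
require 'tensor_stream'

a = TensorStream.constant(1.0)
b = TensorStream.constant(2.0)
c = a * b

# Prints one entry per graph node with its name, op, and input names,
# in a TensorFlow-style pbtxt layout.
puts TensorStream::Graph.get_default_graph.as_graph_def
```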
data/lib/tensor_stream/helpers/op_helper.rb
CHANGED
```diff
@@ -33,8 +33,12 @@ module TensorStream
       arr
     end
 
-    def dtype_eval(rank, value)
-      dtype =
+    def dtype_eval(rank, value, data_type = nil)
+      dtype = if data_type.nil?
+                Tensor.detect_type(value[0])
+              else
+                data_type
+              end
 
       rank += 1 if dtype == :array
 
```
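`dtype_eval` now accepts an explicit `data_type` and only falls back to `Tensor.detect_type` when none is given. A standalone sketch of that decision (the `detect_type` helper below is a simplified stand-in for the gem's internal type detection, not its actual implementation):

```ruby
# Simplified stand-in for Tensor.detect_type: classify a Ruby value.
def detect_type(value)
  case value
  when Array   then :array
  when Integer then :int32
  when Float   then :float32
  else :unknown
  end
end

# Mirrors the new signature: an explicit data_type wins over detection,
# and nested arrays still bump the rank.
def dtype_eval(rank, value, data_type = nil)
  dtype = data_type.nil? ? detect_type(value[0]) : data_type
  rank += 1 if dtype == :array
  [dtype, rank]
end

p dtype_eval(1, [[1, 2], [3, 4]])    # => [:array, 2]   (detected from the value)
p dtype_eval(1, [1, 2, 3], :float32) # => [:float32, 1] (explicit type respected)
```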
data/lib/tensor_stream/math_gradients.rb
CHANGED
```diff
@@ -38,22 +38,22 @@ module TensorStream
         computed_op.each_with_index do |op_grad, index|
           next if op_grad.nil?
 
-          if nodes_to_compute.include?(tensor.
-            partials << _propagate(op_grad, tensor.
+          if nodes_to_compute.include?(tensor.inputs[index].name)
+            partials << _propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
           end
         end
 
         partials.reduce(:+)
       else
         return tf.zeros_like(stop_tensor) if computed_op.nil?
-        _propagate(computed_op, tensor.
+        _propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
       end
     end
 
     def self._compute_derivative(node, grad)
       node.graph.name_scope("#{node.name}_grad") do
-        x = node.
-        y = node.
+        x = node.inputs[0] if node.inputs[0]
+        y = node.inputs[1] if node.inputs[1]
 
         case node.operation
         when :add
@@ -221,8 +221,8 @@ module TensorStream
 
     def self._min_or_max_grad(op, grad)
       y = op
-      indicators = tf.cast(tf.equal(y, op.
-      num_selected = tf.reduce_sum(indicators, op.
+      indicators = tf.cast(tf.equal(y, op.inputs[0]), grad.data_type)
+      num_selected = tf.reduce_sum(indicators, op.inputs[1])
       _safe_shape_div(indicators, num_selected) * grad
     end
 
```
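The gradient walker now recurses through `tensor.inputs[index]` instead of the old accessor; nothing changes from the caller's side. A small end-to-end sketch, assuming the TensorFlow-style helpers (`variable`, `gradients`, `global_variables_initializer`) exposed by the gem:

```ruby
require 'tensor_stream'

tf = TensorStream

x = tf.variable(2.0, name: 'x')
f = x * x + x * 3.0

# _propagate walks f's op.inputs down to x and sums the partial derivatives.
grad = tf.gradients(f, [x]).first

sess = tf.session
sess.run(tf.global_variables_initializer)
puts sess.run(grad) # d/dx (x^2 + 3x) = 2x + 3 = 7.0 at x = 2.0
```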
data/lib/tensor_stream/operation.rb
CHANGED
```diff
@@ -1,7 +1,7 @@
 module TensorStream
   # TensorStream class that defines an operation
   class Operation < Tensor
-    attr_accessor :name, :operation, :
+    attr_accessor :name, :operation, :inputs, :rank, :options
     attr_reader :outputs
 
     def initialize(operation, input_a, input_b, options = {})
@@ -15,7 +15,7 @@ module TensorStream
 
       @options = options
 
-      @
+      @inputs = [input_a, input_b].map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
       @data_type = set_data_type(options[:data_type])
       @is_const = infer_const
       @shape = TensorShape.new(infer_shape)
@@ -30,16 +30,16 @@ module TensorStream
       {
         op: operation,
         name: name,
-        operands: hashify_tensor(
+        operands: hashify_tensor(inputs)
       }
     end
 
     def self.empty_matrix?(input)
       if input.is_a?(Array)
-        input.each do |
-          if
-            return false unless empty_matrix?(
-          elsif
+        input.each do |input|
+          if input.is_a?(Array)
+            return false unless empty_matrix?(input)
+          elsif input != 0 || input != 0.0
             return false
           end
         end
@@ -54,7 +54,7 @@ module TensorStream
       when :random_normal, :random_uniform, :glorot_uniform, :print
         false
       else
-        non_const = @
+        non_const = @inputs.compact.find { |input| !input.is_const }
         non_const ? false : true
       end
     end
@@ -68,23 +68,23 @@ module TensorStream
       when :random_normal, :random_uniform, :glorot_uniform
         passed_data_type || :float32
       when :index
-        if @
+        if @inputs[0].is_a?(ControlFlow)
 
-          if @
-            @
+          if @inputs[1].is_const
+            @inputs[0].inputs[@inputs[1].value].data_type
           else
            :unknown
          end
        else
-          @
+          @inputs[0].data_type
        end
      else
        return passed_data_type if passed_data_type
 
-        if @
-          @
-        elsif @
-          @
+        if @inputs[0]
+          @inputs[0].data_type
+        elsif @inputs[1]
+          @inputs[1].data_type
        else
          :unknown
        end
@@ -94,119 +94,119 @@ module TensorStream
     def to_math(name_only = false, max_depth = 99, _cur_depth = 0)
       return @name if max_depth.zero?
 
-
-
+      sub_input = auto_math(inputs[0], name_only, max_depth - 1, _cur_depth + 1)
+      sub_input2 = auto_math(inputs[1], name_only, max_depth - 1, _cur_depth + 1) if inputs[1]
 
       out = case operation
       when :argmax
-        "argmax(#{
+        "argmax(#{sub_input},#{options[:axis]})"
       when :negate
-        "-#{
+        "-#{sub_input}"
       when :index
-        "#{
+        "#{sub_input}[#{sub_input2}]"
       when :slice
-        "#{
+        "#{sub_input}[#{sub_input2}]"
       when :assign_sub
-        "(#{
+        "(#{inputs[0] ? inputs[0].name : 'self'} -= #{auto_math(inputs[1], name_only, 1)})"
       when :assign_add
-        "(#{
+        "(#{inputs[0] ? inputs[0].name : 'self'} += #{auto_math(inputs[1], name_only, 1)})"
       when :assign
-        "(#{
+        "(#{inputs[0] ? inputs[0].name : 'self'} = #{auto_math(inputs[1], name_only, 1)})"
       when :sin, :cos, :tanh
-        "#{operation}(#{
+        "#{operation}(#{sub_input})"
       when :add
-        "(#{
+        "(#{sub_input} + #{sub_input2})"
       when :sub
-        "(#{
+        "(#{sub_input} - #{sub_input2})"
       when :pow
-        "(#{
+        "(#{sub_input}^#{sub_input2})"
       when :div
-        "(#{
+        "(#{sub_input} / #{sub_input2})"
       when :mul
-        if auto_math(
-
-        elsif auto_math(
-
+        if auto_math(inputs[0]) == 1
+          sub_input2
+        elsif auto_math(inputs[1]) == 1
+          sub_input
         else
-          "(#{
+          "(#{sub_input} * #{sub_input2})"
        end
       when :sum
-        "sum(|#{
+        "sum(|#{sub_input}|, axis=#{sub_input2})"
       when :mean
-        "mean(|#{
+        "mean(|#{sub_input}|, axis=#{sub_input2})"
       when :prod
-        "prod(|#{
+        "prod(|#{sub_input}|, axis=#{sub_input2})"
       when :gradients
-        "gradient(#{
+        "gradient(#{sub_input})"
       when :stop_gradient
-
+        sub_input
       when :matmul
-        "#{
+        "#{sub_input}.matmul(#{sub_input2})"
       when :eye
-        "eye(#{
+        "eye(#{sub_input})"
       when :transpose
-        "transpose(#{
+        "transpose(#{sub_input})"
       when :shape
-        "#{
+        "#{sub_input}.shape"
       when :exp
-        "e^#{
+        "e^#{sub_input})"
       when :ones
-        "ones(#{
+        "ones(#{sub_input})"
       when :ones_like
-        "ones_like(#{
+        "ones_like(#{sub_input})"
       when :flow_group
-        "flow_group(#{
+        "flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, _cur_depth) }.join(',')})"
       when :zeros
-        "zeros(#{
+        "zeros(#{sub_input})"
       when :reshape
-        "reshape(#{
+        "reshape(#{sub_input},#{sub_input2})"
       when :rank
-        "#{
+        "#{sub_input}.rank"
       when :cond
-        "(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{
+        "(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{sub_input} : #{sub_input2})"
       when :less
-        "#{
+        "#{sub_input} < #{sub_input2}"
       when :less_equal
-        "#{
+        "#{sub_input} <= #{sub_input2}"
       when :greater
-        "#{
+        "#{sub_input} > #{sub_input2}"
       when :greater_equal
-        "#{
+        "#{sub_input} >= #{sub_input2}"
       when :square
-        "#{
+        "#{sub_input}\u00B2"
       when :log
-        "log(#{
+        "log(#{sub_input})"
       when :identity
-        "identity(#{
+        "identity(#{sub_input})"
       when :print
-        "print(#{
+        "print(#{sub_input})"
       when :pad
-        "pad(#{
+        "pad(#{sub_input},#{auto_math(options[:paddings])})"
       when :equal
-        "#{
+        "#{sub_input} == #{sub_input2}"
       when :not_equal
-        "#{
+        "#{sub_input} != #{sub_input2}"
       when :logical_and
-        "#{
+        "#{sub_input} && #{sub_input2}"
       when :sqrt
-        "sqrt(#{
+        "sqrt(#{sub_input})"
       when :log1p
-        "log1p(#{
+        "log1p(#{sub_input})"
       when :zeros_like
-        "zeros_like(#{
+        "zeros_like(#{sub_input})"
       when :where
-        "where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{
+        "where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{sub_input}, #{sub_input2})"
       when :max
-        "max(#{
+        "max(#{sub_input},#{sub_input2})"
       when :cast
-        "cast(#{
+        "cast(#{sub_input}, #{data_type})"
       when :broadcast_transform
-        "broadcast_transform(#{
+        "broadcast_transform(#{sub_input},#{sub_input2})"
       when :broadcast_gradient_args
-        "broadcast_transform(#{
+        "broadcast_transform(#{sub_input},#{sub_input2})"
       else
-        "#{operation}(#{
-        "#{operation}(#{
+        "#{operation}(#{sub_input})" if sub_input
+        "#{operation}(#{sub_input}, #{sub_input2})" if sub_input && sub_input2
       end
       ["\n",(_cur_depth + 1).times.collect { ' ' }, out].flatten.join
     end
@@ -224,47 +224,47 @@ module TensorStream
     def infer_shape
       case operation
       when :index
-
-        return nil if
-        return
+        input_shape = inputs[0].shape.shape
+        return nil if input_shape.nil?
+        return input_shape[1, input_shape.size]
       when :mean, :prod, :sum
-        return [] if
-        return nil if
-
-        return nil if
-        return nil if
+        return [] if inputs[1].nil?
+        return nil if inputs[0].nil?
+        input_shape = inputs[0].shape.shape
+        return nil if input_shape.nil?
+        return nil if inputs[1].is_a?(Tensor) && inputs[1].value.nil?
 
-        axis =
+        axis = inputs[1].is_a?(Tensor) ? inputs[1].value : inputs[1]
 
        axis = [ axis ] unless axis.is_a?(Array)
-        return
+        return input_shape.each_with_index.map do |s, index|
          next nil if axis.include?(index)
          s
        end.compact
       when :reshape
-        new_shape =
+        new_shape = inputs[1] && inputs[1].value ? inputs[1].value : nil
        return nil if new_shape.nil?
 
-
-        return new_shape if
+        input_shape = inputs[0].shape.shape
+        return new_shape if input_shape.nil?
 
-        return TensorShape.fix_inferred_elements(new_shape,
+        return TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
       when :flow_group
        return []
       when :zeros, :ones
-        return
+        return inputs[0] ? inputs[0].value : options[:shape]
       when :zeros_like, :ones_like
-
+        inputs[0].shape.shape
       when :shape
-        return
+        return inputs[0].shape.shape ? [inputs[0].shape.shape.size] : nil
       when :matmul
-        shape1 =
-        shape2 =
+        shape1 = inputs[0].shape.shape.nil? ? nil : inputs[0].shape.shape[0]
+        shape2 = inputs[1].shape.shape.nil? ? nil : inputs[1].shape.shape[1]
        return [shape1, shape2]
       else
-        return
-        if
-          return TensorShape.infer_shape(
+        return inputs[0].shape.shape if inputs.size == 1
+        if inputs.size == 2 && inputs[0] && inputs[1]
+          return TensorShape.infer_shape(inputs[0].shape.shape, inputs[1].shape.shape)
        end
      end
 
@@ -273,14 +273,14 @@ module TensorStream
 
     def propagate_consumer(consumer)
       super
-      @
-
+      @inputs.compact.each do |input|
+        input.send(:propagate_consumer, consumer) if input.name != name
       end
     end
 
     def propagate_outputs
-      @
-
+      @inputs.compact.each do |input|
+        input.send(:setup_output, self) if input.name != self.name
       end
     end
 
```
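With the accessor renamed to `inputs` (and raw operands wrapped by `TensorStream.convert_to_tensor` in the constructor), graph-walking code reads directly off the operation. A small traversal sketch:

```ruby
require 'tensor_stream'

a = TensorStream.constant(1.0)
b = TensorStream.constant(2.0)
sum = a + b # an Operation node

# Each operand is reachable through the renamed accessor; plain Ruby values
# passed to an op are wrapped into tensors by Operation#initialize.
sum.inputs.each do |input|
  puts "#{input.name} -> #{input.data_type}"
end

puts sum.to_math # textual form of the expression, built from the same inputs
```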