tensor_stream 0.4.1 → 0.5.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +38 -17
- data/benchmark/benchmark.rb +16 -20
- data/lib/tensor_stream/control_flow.rb +3 -3
- data/lib/tensor_stream/debugging/debugging.rb +4 -4
- data/lib/tensor_stream/device.rb +5 -2
- data/lib/tensor_stream/evaluator/base_evaluator.rb +138 -0
- data/lib/tensor_stream/evaluator/buffer.rb +7 -2
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_bool_operand.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_operand.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/abs.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/add.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmax.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmin.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cast.cl +0 -0
- data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +6 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cos.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/div.cl.erb +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/exp.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/gemm.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log1p.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/max.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/mul.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/negate.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/pow.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/reciprocal.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/round.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid_grad.cl +3 -3
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sign.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sin.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax_grad.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sqrt.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/square.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sub.cl +1 -1
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tan.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh_grad.cl +0 -0
- data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/where.cl +1 -1
- data/lib/tensor_stream/evaluator/{opencl_buffer.rb → opencl/opencl_buffer.rb} +1 -1
- data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +5 -0
- data/lib/tensor_stream/evaluator/{opencl_evaluator.rb → opencl/opencl_evaluator.rb} +404 -452
- data/lib/tensor_stream/evaluator/{opencl_template_helper.rb → opencl/opencl_template_helper.rb} +6 -6
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +21 -21
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +492 -398
- data/lib/tensor_stream/graph.rb +21 -1
- data/lib/tensor_stream/graph_serializers/graphml.rb +59 -59
- data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -1
- data/lib/tensor_stream/helpers/op_helper.rb +6 -2
- data/lib/tensor_stream/math_gradients.rb +7 -7
- data/lib/tensor_stream/operation.rb +100 -100
- data/lib/tensor_stream/session.rb +81 -8
- data/lib/tensor_stream/tensor.rb +7 -5
- data/lib/tensor_stream/utils.rb +32 -19
- data/lib/tensor_stream/version.rb +1 -1
- data/tensor_stream.gemspec +0 -1
- data/test_samples/raw_neural_net_sample.rb +7 -7
- metadata +41 -53
- data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +0 -5
data/lib/tensor_stream/graph.rb
CHANGED
@@ -43,6 +43,19 @@ module TensorStream
|
|
43
43
|
end
|
44
44
|
end
|
45
45
|
|
46
|
+
##
|
47
|
+
# Returns a context manager that specifies the default device to use.
|
48
|
+
def device(device_name)
|
49
|
+
Thread.current["ts_graph_#{object_id}"] ||= {}
|
50
|
+
Thread.current["ts_graph_#{object_id}"][:default_device] ||= []
|
51
|
+
Thread.current["ts_graph_#{object_id}"][:default_device] << device_name
|
52
|
+
begin
|
53
|
+
yield
|
54
|
+
ensure
|
55
|
+
Thread.current["ts_graph_#{object_id}"][:default_device].pop
|
56
|
+
end
|
57
|
+
end
|
58
|
+
|
46
59
|
def self.get_default_graph
|
47
60
|
Thread.current[:tensor_stream_current_graph] || create_default
|
48
61
|
end
|
@@ -69,6 +82,7 @@ module TensorStream
|
|
69
82
|
node.name
|
70
83
|
end
|
71
84
|
|
85
|
+
node.device = get_device_scope
|
72
86
|
@nodes[node.name] = node
|
73
87
|
@constants[node.name] = node if node.is_const
|
74
88
|
node.send(:propagate_outputs)
|
@@ -159,11 +173,17 @@ module TensorStream
|
|
159
173
|
|
160
174
|
def get_name_scope
|
161
175
|
graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
|
162
|
-
return nil if graph_thread_storage.nil?
|
176
|
+
return nil if graph_thread_storage.nil? || graph_thread_storage[:current_scope].nil?
|
163
177
|
|
164
178
|
graph_thread_storage[:current_scope].join('/')
|
165
179
|
end
|
166
180
|
|
181
|
+
def get_device_scope
|
182
|
+
graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
|
183
|
+
return :default if graph_thread_storage.nil? || graph_thread_storage[:default_device].nil?
|
184
|
+
graph_thread_storage[:default_device].last
|
185
|
+
end
|
186
|
+
|
167
187
|
def as_graph_def
|
168
188
|
TensorStream::Pbtext.new.get_string(self)
|
169
189
|
end
|
@@ -134,73 +134,73 @@ module TensorStream
|
|
134
134
|
add_to_group(groups, "program/#{tensor.name}", node_buf)
|
135
135
|
end
|
136
136
|
|
137
|
-
tensor.
|
138
|
-
next unless
|
139
|
-
next if added[
|
140
|
-
|
141
|
-
next to_graph_ml(
|
142
|
-
|
143
|
-
added[
|
144
|
-
|
145
|
-
if
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
if @last_session_context[
|
150
|
-
|
137
|
+
tensor.inputs.each do |input|
|
138
|
+
next unless input
|
139
|
+
next if added[input.name]
|
140
|
+
|
141
|
+
next to_graph_ml(input, arr_buf, added, groups) if input.is_a?(Operation)
|
142
|
+
|
143
|
+
added[input.name] = true
|
144
|
+
input_buf = []
|
145
|
+
if input.is_a?(Variable)
|
146
|
+
input_buf << "<node id=\"#{_gml_string(input.name)}\">"
|
147
|
+
input_buf << "<data key=\"d0\">#{input.name}</data>"
|
148
|
+
input_buf << "<data key=\"d2\">green</data>"
|
149
|
+
if @last_session_context[input.name]
|
150
|
+
input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
|
151
151
|
end
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
elsif
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
if @last_session_context[
|
168
|
-
|
152
|
+
input_buf << "<data key=\"d9\">"
|
153
|
+
input_buf << "<y:ShapeNode>"
|
154
|
+
input_buf << " <y:Fill color=\"#33CCCC\" transparent=\"false\"/>"
|
155
|
+
input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
|
156
|
+
input_buf << "</y:ShapeNode>"
|
157
|
+
input_buf << "</data>"
|
158
|
+
input_buf << "</node>"
|
159
|
+
elsif input.is_a?(Placeholder)
|
160
|
+
input_buf << "<node id=\"#{_gml_string(input.name)}\">"
|
161
|
+
input_buf << "<data key=\"d9\">"
|
162
|
+
input_buf << "<y:ShapeNode>"
|
163
|
+
input_buf << " <y:Fill color=\"#FFCC00\" transparent=\"false\"/>"
|
164
|
+
input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
|
165
|
+
input_buf << "</y:ShapeNode>"
|
166
|
+
input_buf << "</data>"
|
167
|
+
if @last_session_context[input.name]
|
168
|
+
input_buf << "<data key=\"d3\">#{_val(tensor)}</data>"
|
169
169
|
end
|
170
|
-
|
171
|
-
elsif
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
if
|
179
|
-
|
170
|
+
input_buf << "</node>"
|
171
|
+
elsif input.is_a?(Tensor)
|
172
|
+
input_buf << "<node id=\"#{_gml_string(input.name)}\">"
|
173
|
+
input_buf << "<data key=\"d0\">#{input.name}</data>"
|
174
|
+
input_buf << "<data key=\"d2\">black</data>"
|
175
|
+
input_buf << "<data key=\"d9\">"
|
176
|
+
input_buf << "<y:ShapeNode>"
|
177
|
+
|
178
|
+
if input.internal?
|
179
|
+
input_buf << " <y:Fill color=\"#C0C0C0\" transparent=\"false\"/>"
|
180
180
|
else
|
181
|
-
|
181
|
+
input_buf << " <y:Fill color=\"#FFFFFF\" transparent=\"false\"/>"
|
182
182
|
end
|
183
183
|
|
184
184
|
|
185
|
-
|
185
|
+
input_buf << " <y:NodeLabel alignment=\"center\">#{input.name}</y:NodeLabel>"
|
186
186
|
|
187
|
-
|
188
|
-
|
189
|
-
|
187
|
+
input_buf << "</y:ShapeNode>"
|
188
|
+
input_buf << "</data>"
|
189
|
+
input_buf << "</node>"
|
190
190
|
end
|
191
191
|
|
192
|
-
if !add_to_group(groups,
|
193
|
-
if
|
194
|
-
add_to_group(groups, "variable/#{
|
192
|
+
if !add_to_group(groups, input.name, input_buf)
|
193
|
+
if input.is_a?(Variable)
|
194
|
+
add_to_group(groups, "variable/#{input.name}", input_buf)
|
195
195
|
else
|
196
|
-
add_to_group(groups, "program/#{
|
196
|
+
add_to_group(groups, "program/#{input.name}", input_buf)
|
197
197
|
end
|
198
198
|
end
|
199
199
|
end
|
200
200
|
|
201
|
-
tensor.
|
202
|
-
next unless
|
203
|
-
output_edge(
|
201
|
+
tensor.inputs.each_with_index do |input, index|
|
202
|
+
next unless input
|
203
|
+
output_edge(input, tensor, arr_buf, index)
|
204
204
|
end
|
205
205
|
end
|
206
206
|
|
@@ -208,20 +208,20 @@ module TensorStream
|
|
208
208
|
str.gsub('/','-')
|
209
209
|
end
|
210
210
|
|
211
|
-
def output_edge(
|
211
|
+
def output_edge(input, tensor, arr_buf, index = 0)
|
212
212
|
target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
|
213
|
-
arr_buf << "<edge source=\"#{_gml_string(
|
213
|
+
arr_buf << "<edge source=\"#{_gml_string(input.name)}\" target=\"#{_gml_string(target_name)}\">"
|
214
214
|
arr_buf << "<data key=\"d13\">"
|
215
215
|
|
216
216
|
arr_buf << "<y:PolyLineEdge>"
|
217
217
|
arr_buf << "<y:EdgeLabel >"
|
218
218
|
if !@last_session_context.empty?
|
219
|
-
arr_buf << "<![CDATA[ #{_val(
|
219
|
+
arr_buf << "<![CDATA[ #{_val(input)} ]]>"
|
220
220
|
else
|
221
|
-
if
|
222
|
-
arr_buf << "<![CDATA[ #{
|
221
|
+
if input.shape.shape.nil?
|
222
|
+
arr_buf << "<![CDATA[ #{input.data_type.to_s} ? ]]>"
|
223
223
|
else
|
224
|
-
arr_buf << "<![CDATA[ #{
|
224
|
+
arr_buf << "<![CDATA[ #{input.data_type.to_s} #{input.shape.shape.empty? ? 'scalar' : input.shape.shape.to_json} ]]>"
|
225
225
|
end
|
226
226
|
end
|
227
227
|
arr_buf << "</y:EdgeLabel >"
|
@@ -11,7 +11,7 @@ module TensorStream
|
|
11
11
|
@lines << " name: #{node.name.to_json}"
|
12
12
|
if node.is_a?(TensorStream::Operation)
|
13
13
|
@lines << " op: #{camelize(node.operation.to_s).to_json}"
|
14
|
-
node.
|
14
|
+
node.inputs.each do |input|
|
15
15
|
next unless input
|
16
16
|
@lines << " input: #{input.name.to_json}"
|
17
17
|
end
|
@@ -33,8 +33,12 @@ module TensorStream
|
|
33
33
|
arr
|
34
34
|
end
|
35
35
|
|
36
|
-
def dtype_eval(rank, value)
|
37
|
-
dtype =
|
36
|
+
def dtype_eval(rank, value, data_type = nil)
|
37
|
+
dtype = if data_type.nil?
|
38
|
+
Tensor.detect_type(value[0])
|
39
|
+
else
|
40
|
+
data_type
|
41
|
+
end
|
38
42
|
|
39
43
|
rank += 1 if dtype == :array
|
40
44
|
|
@@ -38,22 +38,22 @@ module TensorStream
|
|
38
38
|
computed_op.each_with_index do |op_grad, index|
|
39
39
|
next if op_grad.nil?
|
40
40
|
|
41
|
-
if nodes_to_compute.include?(tensor.
|
42
|
-
partials << _propagate(op_grad, tensor.
|
41
|
+
if nodes_to_compute.include?(tensor.inputs[index].name)
|
42
|
+
partials << _propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
|
43
43
|
end
|
44
44
|
end
|
45
45
|
|
46
46
|
partials.reduce(:+)
|
47
47
|
else
|
48
48
|
return tf.zeros_like(stop_tensor) if computed_op.nil?
|
49
|
-
_propagate(computed_op, tensor.
|
49
|
+
_propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
|
50
50
|
end
|
51
51
|
end
|
52
52
|
|
53
53
|
def self._compute_derivative(node, grad)
|
54
54
|
node.graph.name_scope("#{node.name}_grad") do
|
55
|
-
x = node.
|
56
|
-
y = node.
|
55
|
+
x = node.inputs[0] if node.inputs[0]
|
56
|
+
y = node.inputs[1] if node.inputs[1]
|
57
57
|
|
58
58
|
case node.operation
|
59
59
|
when :add
|
@@ -221,8 +221,8 @@ module TensorStream
|
|
221
221
|
|
222
222
|
def self._min_or_max_grad(op, grad)
|
223
223
|
y = op
|
224
|
-
indicators = tf.cast(tf.equal(y, op.
|
225
|
-
num_selected = tf.reduce_sum(indicators, op.
|
224
|
+
indicators = tf.cast(tf.equal(y, op.inputs[0]), grad.data_type)
|
225
|
+
num_selected = tf.reduce_sum(indicators, op.inputs[1])
|
226
226
|
_safe_shape_div(indicators, num_selected) * grad
|
227
227
|
end
|
228
228
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# TensorStream class that defines an operation
|
3
3
|
class Operation < Tensor
|
4
|
-
attr_accessor :name, :operation, :
|
4
|
+
attr_accessor :name, :operation, :inputs, :rank, :options
|
5
5
|
attr_reader :outputs
|
6
6
|
|
7
7
|
def initialize(operation, input_a, input_b, options = {})
|
@@ -15,7 +15,7 @@ module TensorStream
|
|
15
15
|
|
16
16
|
@options = options
|
17
17
|
|
18
|
-
@
|
18
|
+
@inputs = [input_a, input_b].map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
|
19
19
|
@data_type = set_data_type(options[:data_type])
|
20
20
|
@is_const = infer_const
|
21
21
|
@shape = TensorShape.new(infer_shape)
|
@@ -30,16 +30,16 @@ module TensorStream
|
|
30
30
|
{
|
31
31
|
op: operation,
|
32
32
|
name: name,
|
33
|
-
operands: hashify_tensor(
|
33
|
+
operands: hashify_tensor(inputs)
|
34
34
|
}
|
35
35
|
end
|
36
36
|
|
37
37
|
def self.empty_matrix?(input)
|
38
38
|
if input.is_a?(Array)
|
39
|
-
input.each do |
|
40
|
-
if
|
41
|
-
return false unless empty_matrix?(
|
42
|
-
elsif
|
39
|
+
input.each do |input|
|
40
|
+
if input.is_a?(Array)
|
41
|
+
return false unless empty_matrix?(input)
|
42
|
+
elsif input != 0 || input != 0.0
|
43
43
|
return false
|
44
44
|
end
|
45
45
|
end
|
@@ -54,7 +54,7 @@ module TensorStream
|
|
54
54
|
when :random_normal, :random_uniform, :glorot_uniform, :print
|
55
55
|
false
|
56
56
|
else
|
57
|
-
non_const = @
|
57
|
+
non_const = @inputs.compact.find { |input| !input.is_const }
|
58
58
|
non_const ? false : true
|
59
59
|
end
|
60
60
|
end
|
@@ -68,23 +68,23 @@ module TensorStream
|
|
68
68
|
when :random_normal, :random_uniform, :glorot_uniform
|
69
69
|
passed_data_type || :float32
|
70
70
|
when :index
|
71
|
-
if @
|
71
|
+
if @inputs[0].is_a?(ControlFlow)
|
72
72
|
|
73
|
-
if @
|
74
|
-
@
|
73
|
+
if @inputs[1].is_const
|
74
|
+
@inputs[0].inputs[@inputs[1].value].data_type
|
75
75
|
else
|
76
76
|
:unknown
|
77
77
|
end
|
78
78
|
else
|
79
|
-
@
|
79
|
+
@inputs[0].data_type
|
80
80
|
end
|
81
81
|
else
|
82
82
|
return passed_data_type if passed_data_type
|
83
83
|
|
84
|
-
if @
|
85
|
-
@
|
86
|
-
elsif @
|
87
|
-
@
|
84
|
+
if @inputs[0]
|
85
|
+
@inputs[0].data_type
|
86
|
+
elsif @inputs[1]
|
87
|
+
@inputs[1].data_type
|
88
88
|
else
|
89
89
|
:unknown
|
90
90
|
end
|
@@ -94,119 +94,119 @@ module TensorStream
|
|
94
94
|
def to_math(name_only = false, max_depth = 99, _cur_depth = 0)
|
95
95
|
return @name if max_depth.zero?
|
96
96
|
|
97
|
-
|
98
|
-
|
97
|
+
sub_input = auto_math(inputs[0], name_only, max_depth - 1, _cur_depth + 1)
|
98
|
+
sub_input2 = auto_math(inputs[1], name_only, max_depth - 1, _cur_depth + 1) if inputs[1]
|
99
99
|
|
100
100
|
out = case operation
|
101
101
|
when :argmax
|
102
|
-
"argmax(#{
|
102
|
+
"argmax(#{sub_input},#{options[:axis]})"
|
103
103
|
when :negate
|
104
|
-
"-#{
|
104
|
+
"-#{sub_input}"
|
105
105
|
when :index
|
106
|
-
"#{
|
106
|
+
"#{sub_input}[#{sub_input2}]"
|
107
107
|
when :slice
|
108
|
-
"#{
|
108
|
+
"#{sub_input}[#{sub_input2}]"
|
109
109
|
when :assign_sub
|
110
|
-
"(#{
|
110
|
+
"(#{inputs[0] ? inputs[0].name : 'self'} -= #{auto_math(inputs[1], name_only, 1)})"
|
111
111
|
when :assign_add
|
112
|
-
"(#{
|
112
|
+
"(#{inputs[0] ? inputs[0].name : 'self'} += #{auto_math(inputs[1], name_only, 1)})"
|
113
113
|
when :assign
|
114
|
-
"(#{
|
114
|
+
"(#{inputs[0] ? inputs[0].name : 'self'} = #{auto_math(inputs[1], name_only, 1)})"
|
115
115
|
when :sin, :cos, :tanh
|
116
|
-
"#{operation}(#{
|
116
|
+
"#{operation}(#{sub_input})"
|
117
117
|
when :add
|
118
|
-
"(#{
|
118
|
+
"(#{sub_input} + #{sub_input2})"
|
119
119
|
when :sub
|
120
|
-
"(#{
|
120
|
+
"(#{sub_input} - #{sub_input2})"
|
121
121
|
when :pow
|
122
|
-
"(#{
|
122
|
+
"(#{sub_input}^#{sub_input2})"
|
123
123
|
when :div
|
124
|
-
"(#{
|
124
|
+
"(#{sub_input} / #{sub_input2})"
|
125
125
|
when :mul
|
126
|
-
if auto_math(
|
127
|
-
|
128
|
-
elsif auto_math(
|
129
|
-
|
126
|
+
if auto_math(inputs[0]) == 1
|
127
|
+
sub_input2
|
128
|
+
elsif auto_math(inputs[1]) == 1
|
129
|
+
sub_input
|
130
130
|
else
|
131
|
-
"(#{
|
131
|
+
"(#{sub_input} * #{sub_input2})"
|
132
132
|
end
|
133
133
|
when :sum
|
134
|
-
"sum(|#{
|
134
|
+
"sum(|#{sub_input}|, axis=#{sub_input2})"
|
135
135
|
when :mean
|
136
|
-
"mean(|#{
|
136
|
+
"mean(|#{sub_input}|, axis=#{sub_input2})"
|
137
137
|
when :prod
|
138
|
-
"prod(|#{
|
138
|
+
"prod(|#{sub_input}|, axis=#{sub_input2})"
|
139
139
|
when :gradients
|
140
|
-
"gradient(#{
|
140
|
+
"gradient(#{sub_input})"
|
141
141
|
when :stop_gradient
|
142
|
-
|
142
|
+
sub_input
|
143
143
|
when :matmul
|
144
|
-
"#{
|
144
|
+
"#{sub_input}.matmul(#{sub_input2})"
|
145
145
|
when :eye
|
146
|
-
"eye(#{
|
146
|
+
"eye(#{sub_input})"
|
147
147
|
when :transpose
|
148
|
-
"transpose(#{
|
148
|
+
"transpose(#{sub_input})"
|
149
149
|
when :shape
|
150
|
-
"#{
|
150
|
+
"#{sub_input}.shape"
|
151
151
|
when :exp
|
152
|
-
"e^#{
|
152
|
+
"e^#{sub_input})"
|
153
153
|
when :ones
|
154
|
-
"ones(#{
|
154
|
+
"ones(#{sub_input})"
|
155
155
|
when :ones_like
|
156
|
-
"ones_like(#{
|
156
|
+
"ones_like(#{sub_input})"
|
157
157
|
when :flow_group
|
158
|
-
"flow_group(#{
|
158
|
+
"flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, _cur_depth) }.join(',')})"
|
159
159
|
when :zeros
|
160
|
-
"zeros(#{
|
160
|
+
"zeros(#{sub_input})"
|
161
161
|
when :reshape
|
162
|
-
"reshape(#{
|
162
|
+
"reshape(#{sub_input},#{sub_input2})"
|
163
163
|
when :rank
|
164
|
-
"#{
|
164
|
+
"#{sub_input}.rank"
|
165
165
|
when :cond
|
166
|
-
"(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{
|
166
|
+
"(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)} ? #{sub_input} : #{sub_input2})"
|
167
167
|
when :less
|
168
|
-
"#{
|
168
|
+
"#{sub_input} < #{sub_input2}"
|
169
169
|
when :less_equal
|
170
|
-
"#{
|
170
|
+
"#{sub_input} <= #{sub_input2}"
|
171
171
|
when :greater
|
172
|
-
"#{
|
172
|
+
"#{sub_input} > #{sub_input2}"
|
173
173
|
when :greater_equal
|
174
|
-
"#{
|
174
|
+
"#{sub_input} >= #{sub_input2}"
|
175
175
|
when :square
|
176
|
-
"#{
|
176
|
+
"#{sub_input}\u00B2"
|
177
177
|
when :log
|
178
|
-
"log(#{
|
178
|
+
"log(#{sub_input})"
|
179
179
|
when :identity
|
180
|
-
"identity(#{
|
180
|
+
"identity(#{sub_input})"
|
181
181
|
when :print
|
182
|
-
"print(#{
|
182
|
+
"print(#{sub_input})"
|
183
183
|
when :pad
|
184
|
-
"pad(#{
|
184
|
+
"pad(#{sub_input},#{auto_math(options[:paddings])})"
|
185
185
|
when :equal
|
186
|
-
"#{
|
186
|
+
"#{sub_input} == #{sub_input2}"
|
187
187
|
when :not_equal
|
188
|
-
"#{
|
188
|
+
"#{sub_input} != #{sub_input2}"
|
189
189
|
when :logical_and
|
190
|
-
"#{
|
190
|
+
"#{sub_input} && #{sub_input2}"
|
191
191
|
when :sqrt
|
192
|
-
"sqrt(#{
|
192
|
+
"sqrt(#{sub_input})"
|
193
193
|
when :log1p
|
194
|
-
"log1p(#{
|
194
|
+
"log1p(#{sub_input})"
|
195
195
|
when :zeros_like
|
196
|
-
"zeros_like(#{
|
196
|
+
"zeros_like(#{sub_input})"
|
197
197
|
when :where
|
198
|
-
"where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{
|
198
|
+
"where(#{auto_math(options[:pred], name_only, max_depth - 1, _cur_depth)}, #{sub_input}, #{sub_input2})"
|
199
199
|
when :max
|
200
|
-
"max(#{
|
200
|
+
"max(#{sub_input},#{sub_input2})"
|
201
201
|
when :cast
|
202
|
-
"cast(#{
|
202
|
+
"cast(#{sub_input}, #{data_type})"
|
203
203
|
when :broadcast_transform
|
204
|
-
"broadcast_transform(#{
|
204
|
+
"broadcast_transform(#{sub_input},#{sub_input2})"
|
205
205
|
when :broadcast_gradient_args
|
206
|
-
"broadcast_transform(#{
|
206
|
+
"broadcast_transform(#{sub_input},#{sub_input2})"
|
207
207
|
else
|
208
|
-
"#{operation}(#{
|
209
|
-
"#{operation}(#{
|
208
|
+
"#{operation}(#{sub_input})" if sub_input
|
209
|
+
"#{operation}(#{sub_input}, #{sub_input2})" if sub_input && sub_input2
|
210
210
|
end
|
211
211
|
["\n",(_cur_depth + 1).times.collect { ' ' }, out].flatten.join
|
212
212
|
end
|
@@ -224,47 +224,47 @@ module TensorStream
|
|
224
224
|
def infer_shape
|
225
225
|
case operation
|
226
226
|
when :index
|
227
|
-
|
228
|
-
return nil if
|
229
|
-
return
|
227
|
+
input_shape = inputs[0].shape.shape
|
228
|
+
return nil if input_shape.nil?
|
229
|
+
return input_shape[1, input_shape.size]
|
230
230
|
when :mean, :prod, :sum
|
231
|
-
return [] if
|
232
|
-
return nil if
|
233
|
-
|
234
|
-
return nil if
|
235
|
-
return nil if
|
231
|
+
return [] if inputs[1].nil?
|
232
|
+
return nil if inputs[0].nil?
|
233
|
+
input_shape = inputs[0].shape.shape
|
234
|
+
return nil if input_shape.nil?
|
235
|
+
return nil if inputs[1].is_a?(Tensor) && inputs[1].value.nil?
|
236
236
|
|
237
|
-
axis =
|
237
|
+
axis = inputs[1].is_a?(Tensor) ? inputs[1].value : inputs[1]
|
238
238
|
|
239
239
|
axis = [ axis ] unless axis.is_a?(Array)
|
240
|
-
return
|
240
|
+
return input_shape.each_with_index.map do |s, index|
|
241
241
|
next nil if axis.include?(index)
|
242
242
|
s
|
243
243
|
end.compact
|
244
244
|
when :reshape
|
245
|
-
new_shape =
|
245
|
+
new_shape = inputs[1] && inputs[1].value ? inputs[1].value : nil
|
246
246
|
return nil if new_shape.nil?
|
247
247
|
|
248
|
-
|
249
|
-
return new_shape if
|
248
|
+
input_shape = inputs[0].shape.shape
|
249
|
+
return new_shape if input_shape.nil?
|
250
250
|
|
251
|
-
return TensorShape.fix_inferred_elements(new_shape,
|
251
|
+
return TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
|
252
252
|
when :flow_group
|
253
253
|
return []
|
254
254
|
when :zeros, :ones
|
255
|
-
return
|
255
|
+
return inputs[0] ? inputs[0].value : options[:shape]
|
256
256
|
when :zeros_like, :ones_like
|
257
|
-
|
257
|
+
inputs[0].shape.shape
|
258
258
|
when :shape
|
259
|
-
return
|
259
|
+
return inputs[0].shape.shape ? [inputs[0].shape.shape.size] : nil
|
260
260
|
when :matmul
|
261
|
-
shape1 =
|
262
|
-
shape2 =
|
261
|
+
shape1 = inputs[0].shape.shape.nil? ? nil : inputs[0].shape.shape[0]
|
262
|
+
shape2 = inputs[1].shape.shape.nil? ? nil : inputs[1].shape.shape[1]
|
263
263
|
return [shape1, shape2]
|
264
264
|
else
|
265
|
-
return
|
266
|
-
if
|
267
|
-
return TensorShape.infer_shape(
|
265
|
+
return inputs[0].shape.shape if inputs.size == 1
|
266
|
+
if inputs.size == 2 && inputs[0] && inputs[1]
|
267
|
+
return TensorShape.infer_shape(inputs[0].shape.shape, inputs[1].shape.shape)
|
268
268
|
end
|
269
269
|
end
|
270
270
|
|
@@ -273,14 +273,14 @@ module TensorStream
|
|
273
273
|
|
274
274
|
def propagate_consumer(consumer)
|
275
275
|
super
|
276
|
-
@
|
277
|
-
|
276
|
+
@inputs.compact.each do |input|
|
277
|
+
input.send(:propagate_consumer, consumer) if input.name != name
|
278
278
|
end
|
279
279
|
end
|
280
280
|
|
281
281
|
def propagate_outputs
|
282
|
-
@
|
283
|
-
|
282
|
+
@inputs.compact.each do |input|
|
283
|
+
input.send(:setup_output, self) if input.name != self.name
|
284
284
|
end
|
285
285
|
end
|
286
286
|
|