tensor_stream 0.9.2 → 0.9.5
- checksums.yaml +4 -4
- data/lib/tensor_stream/evaluator/base_evaluator.rb +3 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +25 -0
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +24 -24
- data/lib/tensor_stream/evaluator/ruby/check_ops.rb +8 -0
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +16 -18
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +20 -4
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +9 -5
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +16 -61
- data/lib/tensor_stream/graph_builder.rb +1 -0
- data/lib/tensor_stream/graph_serializers/graphml.rb +1 -1
- data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -0
- data/lib/tensor_stream/helpers/infer_shape.rb +182 -0
- data/lib/tensor_stream/helpers/op_helper.rb +2 -2
- data/lib/tensor_stream/images.rb +1 -1
- data/lib/tensor_stream/math_gradients.rb +1 -1
- data/lib/tensor_stream/monkey_patches/array.rb +15 -0
- data/lib/tensor_stream/monkey_patches/float.rb +3 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +3 -0
- data/lib/tensor_stream/monkey_patches/patch.rb +70 -0
- data/lib/tensor_stream/nn/nn_ops.rb +43 -9
- data/lib/tensor_stream/operation.rb +2 -153
- data/lib/tensor_stream/ops.rb +71 -56
- data/lib/tensor_stream/profile/report_tool.rb +3 -3
- data/lib/tensor_stream/tensor_shape.rb +9 -6
- data/lib/tensor_stream/train/adadelta_optimizer.rb +1 -1
- data/lib/tensor_stream/train/adagrad_optimizer.rb +1 -1
- data/lib/tensor_stream/train/adam_optimizer.rb +2 -2
- data/lib/tensor_stream/train/learning_rate_decay.rb +29 -0
- data/lib/tensor_stream/train/optimizer.rb +7 -6
- data/lib/tensor_stream/train/saver.rb +1 -0
- data/lib/tensor_stream/train/slot_creator.rb +2 -2
- data/lib/tensor_stream/trainer.rb +3 -0
- data/lib/tensor_stream/utils.rb +2 -2
- data/lib/tensor_stream/version.rb +1 -1
- data/lib/tensor_stream.rb +5 -1
- data/samples/rnn.rb +108 -0
- metadata +8 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 51cb6686663dece94714073ff12b13ad7e57de1aadbee44506897c69a2d5fd67
+  data.tar.gz: 47152a908b2cc966cba6721da8f0c3170e3198a70f416fd224caee7d35a1fca2
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 7fff67042fd35c651409e04e890feafab25e618d88d4c3f69c97b39e839a943c64a4e05be388a4464e32328a777f6ee56dc5ede0a6c69c977cc94bcb7cd38bad
+  data.tar.gz: a30230994e81062626eaae5ca7a13ece36d2a5ea6210cfe1e3e465b5e5e24adad525547223f27e15991f74728d53d8b02e7cac0710fe16bf3c602402f5b65b5c
@@ -140,6 +140,8 @@ module TensorStream
         @context[:profile][:operations][tensor.name] = { op: tensor.operation,
                                                          step: @context[:profile][:step],
                                                          eval_time: end_time - start_time,
+                                                         shape: tensor.shape ? tensor.shape.shape : nil,
+                                                         data_type: tensor.data_type,
                                                          tensor: tensor }
       end
     end
@@ -166,6 +168,7 @@ module TensorStream
     def global_eval(tensor, input, execution_context, op_options = {})
       return nil unless input
       return input unless input.is_a?(Tensor)
+
       # puts "global eval #{tensor.name}"
       @context[:_cache][:placement][input.name] = @session.assign_evaluator(input) if @context[:_cache][:placement][input.name].nil?
       if !on_same_device?(input) # tensor is on another device or evaluator
@@ -310,5 +310,30 @@ module TensorStream

       reduce_axis(0, axis, val, keep_dims, func)
     end
+
+    def arr_pad(arr, paddings, data_type = :float32, rank = 0)
+      raise "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2
+
+      before = paddings[rank][0]
+      after = paddings[rank][1]
+      pad_value = fp_type?(data_type) ? 0.0 : 0
+      if arr[0].is_a?(Array)
+        next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
+        padding = deep_dup_array(next_dim_elem[0], pad_value)
+        Array.new(before) { padding } + next_dim_elem + Array.new(after) { padding }
+      else
+        Array.new(before) { pad_value } + arr + Array.new(after) { pad_value }
+      end
+    end
+
+    def deep_dup_array(arr, value = nil)
+      if arr.is_a?(Array)
+        arr.dup.collect do |a|
+          deep_dup_array(a, value)
+        end
+      else
+        value.nil? ? arr : value
+      end
+    end
   end
 end
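For reference, a minimal standalone sketch (plain Ruby, not the gem's code) of the padding idea the new arr_pad helper implements; pad2d is a hypothetical name, and the padding spec is a [before, after] pair per dimension:

def pad2d(arr, row_pad, col_pad, value = 0)
  # pad each row on the left/right, then add full rows of padding above/below
  rows = arr.map { |row| Array.new(col_pad[0], value) + row + Array.new(col_pad[1], value) }
  width = rows.first.length
  Array.new(row_pad[0]) { Array.new(width, value) } + rows + Array.new(row_pad[1]) { Array.new(width, value) }
end

p pad2d([[1, 2], [3, 4]], [1, 1], [2, 0])
# => [[0, 0, 0, 0], [0, 0, 1, 2], [0, 0, 3, 4], [0, 0, 0, 0]]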
@@ -132,9 +132,9 @@ module TensorStream
       axis = !tensor.options[:axis].is_a?(Array) ? [tensor.options[:axis]] : tensor.options[:axis]

       if !axis.empty?
-        axis.each do |
-          if shape[
-            shape[
+        axis.each do |x|
+          if shape[x] == 1
+            shape[x] = nil
           else
             raise TensorStream::ValueError, "unable to squeeze dimension that does not have a size of 1"
           end
@@ -201,6 +201,7 @@ module TensorStream
       out = []
       input.each_with_index do |x, index|
         next if remove.include?(x)
+
         out << x
         idx << index
       end
@@ -225,6 +226,7 @@ module TensorStream
         break if start == limit
         break if (start < limit) && (cur_step >= limit)
         break if (start > limit) && (cur_step <= limit)
+
         r << cur_step
         cur_step += delta
       end
@@ -281,34 +283,34 @@ module TensorStream
         get_rank(inputs[0])
       end

-      register_op :split do |
+      register_op :split do |_context, tensor, inputs|
         value, num_split, axis = inputs

         value_shape = shape_eval(value)
         res = if num_split.is_a?(Array)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                begin_index = 0
+                num_split.collect do |num|
+                  end_index = begin_index + num
+                  arr = split_tensor(value, begin_index, end_index, axis)
+                  begin_index = end_index
+                  arr
+                end
+              else
+                raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
+
+                piece_sizes = value_shape[axis] / num_split
+                Array.new(num_split) do |num|
+                  begin_index = num * piece_sizes
+                  end_index = begin_index + piece_sizes
+                  split_tensor(value, begin_index, end_index, axis)
+                end
+              end
         TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
       end

       register_op :reshape do |_context, _tensor, inputs|
         arr, new_shape = inputs
         arr = [arr] unless arr.is_a?(Array)
-
         flat_arr = arr.flatten
         if new_shape.size.zero? && flat_arr.size == 1
           flat_arr[0]
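Illustration only: an even split along axis 0 done with plain Ruby slicing, roughly what the rewritten :split op delegates to split_tensor for (that helper is not shown in this hunk):

value     = [[1, 2], [3, 4], [5, 6], [7, 8]]
num_split = 2
raise "uneven split" unless (value.length % num_split).zero?

piece_size = value.length / num_split
pieces = Array.new(num_split) { |i| value[i * piece_size, piece_size] }
p pieces  # => [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]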
@@ -318,9 +320,7 @@ module TensorStream
       end

       register_op :pad do |context, tensor, inputs|
-
-
-        arr_pad(inputs[0], p, tensor.data_type)
+        arr_pad(inputs[0], inputs[1], tensor.data_type)
       end

       register_op :tile do |_context, _tensor, inputs|
@@ -2,7 +2,15 @@ module TensorStream
   module CheckOps
     def CheckOps.included(klass)
       klass.class_eval do
+        register_op :assert_equal do |context, tensor, inputs|
+          result = call_vector_op(tensor, :equal, inputs[0], inputs[1], context, ->(t, u) { t == u })

+          result = result.is_a?(Array) ? result.flatten.uniq : [result]
+          prefix = tensor.options[:message] || ""
+          raise TensorStream::InvalidArgumentError, "#{prefix} #{tensor.inputs[0].name} != #{tensor.inputs[1].name}" if result != [true]
+
+          nil
+        end
       end
     end
   end
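A plain-Ruby sketch of the check the new :assert_equal op performs (this mirrors the flatten/uniq comparison above and is not the gem's public API; assert_equal_values! is a hypothetical name):

def assert_equal_values!(a, b, message: "")
  # element-wise comparison, collapsed to a single pass/fail
  result = a.flatten.zip(b.flatten).map { |x, y| x == y }.uniq
  raise ArgumentError, "#{message} tensors differ" unless result == [true]
  nil
end

assert_equal_values!([1, 2, 3], [1, 2, 3])    # passes, returns nil
# assert_equal_values!([1, 2, 3], [1, 2, 4])  # would raise ArgumentError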
@@ -14,30 +14,28 @@ module TensorStream
         image.grayscale! if channels == 1
         image_data = image.pixels.collect do |pixel|
           color_values = if channels == 4
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                           if fp_type?(tensor.data_type)
-                             color_values.map! { |v| v.to_f }
-                           end
+                           [ChunkyPNG::Color.r(pixel),
+                            ChunkyPNG::Color.g(pixel),
+                            ChunkyPNG::Color.b(pixel),
+                            ChunkyPNG::Color.a(pixel)]
+                         elsif channels == 3
+                           [ChunkyPNG::Color.r(pixel),
+                            ChunkyPNG::Color.g(pixel),
+                            ChunkyPNG::Color.b(pixel)]
+                         elsif channels == 1
+                           [ChunkyPNG::Color.r(pixel)]
+                         else
+                           raise "Invalid channel value #{channels}"
+                         end
+
+          color_values.map!(&:to_f) if fp_type?(tensor.data_type)

           color_values
         end
         TensorShape.reshape(image_data, [image.height, image.width, channels])
       end

-      register_op :encode_png do |_context,
+      register_op :encode_png do |_context, _tensor, inputs|
         image_data = inputs[0]
         height, width, channels = shape_eval(image_data)
@@ -140,17 +140,33 @@ module TensorStream
       end

       register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
-        axis =
+        axis = inputs[1] || 0
         rank = get_rank(inputs[0])
         raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
-
+
+        new_shape = shape_eval(inputs[0])
+        ns = new_shape.each_with_index.collect do |shape, index|
+          next nil if index == axis
+
+          shape
+        end.compact
+
+        Tensor.cast_dtype(TensorShape.reshape(get_op_with_axis(inputs[0], axis, 0, :max), ns), tensor.data_type)
       end

       register_op(%i[argmin arg_min]) do |_context, tensor, inputs|
-        axis =
+        axis = inputs[1] || 0
         rank = get_rank(inputs[0])
         raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
-
+
+        new_shape = shape_eval(inputs[0])
+        ns = new_shape.each_with_index.collect do |shape, index|
+          next nil if index == axis
+
+          shape
+        end.compact
+
+        Tensor.cast_dtype(TensorShape.reshape(get_op_with_axis(inputs[0], axis, 0, :min), ns), tensor.data_type)
       end

       register_op :cumprod do |context, tensor, inputs|
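Worked example of the shape handling added above: argmax along an axis drops that dimension, so a [2, 3] input yields a [3] result for axis 0 and a [2] result for axis 1 (plain Ruby):

input = [[1, 5, 3],
         [4, 2, 6]]

p input.transpose.map { |col| col.index(col.max) }  # axis 0 => [1, 0, 1]
p input.map { |row| row.index(row.max) }            # axis 1 => [1, 2]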
@@ -16,11 +16,11 @@ module TensorStream
         assign = tensor.inputs[0] || tensor
         assign_acc = tensor.inputs[1]
         assign_acc.value = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
-        if tensor.options[:use_nesterov]
-
-
-
-
+        assign.value = if tensor.options[:use_nesterov]
+                         multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
+                       else
+                         multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
+                       end
         assign.value
       end

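A rough element-wise sketch of the two update rules selected above, using illustrative scalar values (in the evaluator these run through multi_array_op over whole arrays, and the accumulator ordering follows the gem's code, not this sketch):

learning_rate = 0.1
momentum      = 0.9
v, g, acc     = 1.0, 0.5, 0.2   # variable, gradient, momentum accumulator

plain    = v - acc * learning_rate                                   # => 0.98
nesterov = v - (g * learning_rate + acc * momentum * learning_rate)  # => 0.932
p [plain, nesterov]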
@@ -207,6 +207,10 @@ module TensorStream
           TensorShape.reshape(arr, input_shape)
         end
       end
+
+      register_op :relu6 do |context, tensor, inputs|
+        call_vector_op(tensor, :relu6, inputs[0], inputs[1], context, ->(t, u) { [[t, 0].max, 6].min })
+      end
     end
   end
 end
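The new :relu6 op clamps each element to the range [0, 6], exactly the lambda registered above:

relu6 = ->(t) { [[t, 0].max, 6].min }
p [-3, 0.5, 4, 9].map(&relu6)  # => [0, 0.5, 4, 6]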
@@ -25,24 +25,24 @@ module TensorStream
         generate_vector(shape, generator: generator)
       end

-      register_op :random_uniform, no_eval: true do |_context, tensor,
+      register_op :random_uniform, no_eval: true do |_context, tensor, inputs|
         maxval = tensor.options.fetch(:maxval, 1)
         minval = tensor.options.fetch(:minval, 0)
         seed = tensor.options[:seed]

         random = _get_randomizer(tensor, seed)
         generator = -> { random.rand * (maxval - minval) + minval }
-        shape =
+        shape = inputs[0] || tensor.shape.shape
         generate_vector(shape, generator: generator)
       end

-      register_op :random_standard_normal, no_eval: true do |_context, tensor,
+      register_op :random_standard_normal, no_eval: true do |_context, tensor, inputs|
         seed = tensor.options[:seed]
         random = _get_randomizer(tensor, seed)
         r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
         random = _get_randomizer(tensor, seed)
         generator = -> { r.rand }
-        shape =
+        shape = inputs[0] || tensor.shape.shape
         generate_vector(shape, generator: generator)
       end
     end
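Sketch of the uniform generator wired up above: rand is scaled into [minval, maxval), and the shape now falls back to the tensor's own shape when no shape input is given (plain Ruby, seeded only for repeatability of the sketch):

minval, maxval = -1.0, 1.0
random    = Random.new(42)
generator = -> { random.rand * (maxval - minval) + minval }
samples   = Array.new(3) { generator.call }
p samples.all? { |s| s >= minval && s < maxval }  # => true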
@@ -289,7 +289,7 @@ module TensorStream
       # assertions to make sure inferred shapes == actual evaluated shapes
       if tensor.shape.known? && (result.is_a?(Array) || result.is_a?(Float) || result.is_a?(Integer))
         if shape_eval(result) != tensor.shape.shape
-          raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
+          # raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
         end
       end

@@ -313,21 +313,20 @@ module TensorStream
         @context[tensor.name] = result
       end
     rescue EvaluatorExcecutionException => e
-      raise e, "error #{e.message} while evaluating #{tensor.name}
+      raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
     rescue TensorStreamError => e
-      raise e, "error #{e.message} while evaluating #{tensor.name}
+      raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
     rescue StandardError => e
-
-
-      b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
+      # a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
+      # b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
       puts e.message
       puts e.backtrace.join("\n")
       # shape_a = a.shape.shape if a
       # shape_b = b.shape.shape if b
       # dtype_a = a.data_type if a
       # dtype_b = b.data_type if b
-      a = complete_eval(a, child_context)
-      b = complete_eval(b, child_context)
+      # a = complete_eval(a, child_context)
+      # b = complete_eval(b, child_context)
       # puts "name: #{tensor.given_name}"
       # # puts "op: #{tensor.to_math(true, 1)}"
       # puts "A #{shape_a} #{dtype_a}: #{a}" if a
363
362
|
|
364
363
|
private
|
365
364
|
|
366
|
-
def get_op_with_axis(a, target_axis, current_axis,
|
367
|
-
|
368
|
-
|
369
|
-
(0...a[0].size).each.collect do |column_index|
|
370
|
-
max = nil
|
371
|
-
max_index = 0
|
372
|
-
a.each_with_index do |row, row_index|
|
373
|
-
if max.nil? || op.call(row[column_index], max)
|
374
|
-
max = row[column_index]
|
375
|
-
max_index = row_index
|
376
|
-
end
|
377
|
-
end
|
365
|
+
def get_op_with_axis(a, target_axis, current_axis, op)
|
366
|
+
rank = get_rank(a)
|
367
|
+
return a.index(a.send(:"#{op}")) if rank == 1
|
378
368
|
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
a.each_with_index do |x, index|
|
385
|
-
if max.nil? || op.call(x, max)
|
386
|
-
max = x
|
387
|
-
max_index = index
|
388
|
-
end
|
389
|
-
end
|
390
|
-
Tensor.cast_dtype(max_index, output_type)
|
391
|
-
end
|
369
|
+
if current_axis == target_axis
|
370
|
+
compare_items = a.collect(&:flatten).transpose
|
371
|
+
compare_items.map { |item| item.index(item.send(:"#{op}")) }
|
372
|
+
elsif a[0].is_a?(Array)
|
373
|
+
a.map { |item| get_op_with_axis(item, target_axis, current_axis + 1, op) }
|
392
374
|
else
|
393
|
-
a.
|
394
|
-
get_op_with_axis(row, target_axis, current_axis + 1, output_type, op)
|
395
|
-
end
|
375
|
+
return a.index(a.send(:"#{op}"))
|
396
376
|
end
|
397
377
|
end
|
398
378
|
|
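How the rewritten get_op_with_axis picks per-position winners once the target axis is reached: flatten the trailing dimensions, transpose, then take the index of the max (or min) in each column. A plain-Ruby walk-through of that step alone:

a = [[[1, 2], [3, 4]],
     [[5, 0], [2, 9]]]   # shape [2, 2, 2], argmax along axis 0

compare_items = a.collect(&:flatten).transpose
p compare_items.map { |item| item.index(item.max) }  # => [1, 0, 0, 1]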
@@ -403,31 +383,6 @@ module TensorStream
       reduce(val, axis, keep_dims, func)
     end

-    def arr_pad(arr, paddings, data_type = :float32, rank = 0)
-      raise "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2
-
-      before = paddings[rank][0]
-      after = paddings[rank][1]
-      pad_value = fp_type?(data_type) ? 0.0 : 0
-      if arr[0].is_a?(Array)
-        next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
-        padding = deep_dup_array(next_dim_elem[0], pad_value)
-        Array.new(before) { padding } + next_dim_elem + Array.new(after) { padding }
-      else
-        Array.new(before) { pad_value } + arr + Array.new(after) { pad_value }
-      end
-    end
-
-    def deep_dup_array(arr, value = nil)
-      if arr.is_a?(Array)
-        arr.dup.collect do |a|
-          deep_dup_array(a, value)
-        end
-      else
-        value.nil? ? arr : value
-      end
-    end
-
     def call_op(op, a, child_context, func)
       a = complete_eval(a, child_context)
       process_function_op(a, func)
@@ -0,0 +1,182 @@
+require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
+module TensorStream
+  ##
+  # Convenience class for guessing the shape of a tensor
+  #
+  class InferShape
+    extend TensorStream::ArrayOpsHelper
+    extend TensorStream::OpHelper
+
+    def self.infer_shape(tensor)
+      case tensor.operation
+      when :assign
+        possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
+                           tensor.inputs[0].shape.shape
+                         else
+                           tensor.inputs[1].shape.shape
+                         end
+
+        possible_shape
+      when :index
+        return nil unless tensor.inputs[0].is_a?(Tensor)
+        return nil unless tensor.inputs[0].const_value
+
+        input_shape = tensor.inputs[0].shape
+        return nil unless input_shape.known?
+
+        s = input_shape.shape.dup
+        s.shift
+        s
+      when :arg_min, :argmax, :argmin
+        return nil unless tensor.inputs[0].shape.known?
+        return nil if tensor.inputs[1] && tensor.inputs[1].value.nil?
+
+        axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].value
+        new_shape = tensor.inputs[0].shape.shape
+        new_shape.each_with_index.collect do |shape, index|
+          next nil if index == axis
+
+          shape
+        end.compact
+      when :mean, :prod, :sum, :arg_max
+        return [] if tensor.inputs[1].nil?
+        return nil if tensor.inputs[0].nil?
+        return nil unless tensor.inputs[0].shape.known?
+
+        input_shape = tensor.inputs[0].shape.shape
+        rank = input_shape.size
+
+        axis = tensor.inputs[1].const_value
+        return nil if axis.nil?
+
+        axis = [axis] unless axis.is_a?(Array)
+        axis = axis.map { |a| a < 0 ? rank - a.abs : a }
+
+        input_shape.each_with_index.map do |item, index|
+          if axis.include?(index)
+            next 1 if tensor.options[:keepdims]
+
+            next nil
+          end
+          item
+        end.compact
+      when :reshape
+        new_shape = tensor.inputs[1] && tensor.inputs[1].value ? tensor.inputs[1].value : nil
+        return nil if new_shape.nil?
+        return nil if tensor.inputs[0].shape.nil?
+
+        input_shape = tensor.inputs[0].shape.shape
+        return new_shape if input_shape.nil?
+        return nil if input_shape.include?(nil)
+        TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
+      when :flow_group
+        []
+      when :zeros, :ones, :fill, :random_standard_normal, :random_uniform
+        a_shape = tensor.inputs[0] ? tensor.inputs[0].const_value : tensor.options[:shape]
+        return nil if a_shape.nil?
+        a_shape.is_a?(Array) ? a_shape : [a_shape]
+      when :zeros_like, :ones_like
+        tensor.inputs[0].shape.shape
+      when :shape
+        tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
+      when :pad
+        return nil unless tensor.inputs[0].shape.known?
+        return nil unless tensor.inputs[1].value
+
+        size = tensor.inputs[0].shape.shape.reduce(:*) || 1
+        dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
+        shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].value))
+      when :mat_mul
+        return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
+        return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
+        return nil if tensor.inputs[0].shape.shape.size != 2 || tensor.inputs[1].shape.shape.size != 2
+
+        shape1, m = if tensor.options[:transpose_a]
+                      [tensor.inputs[0].shape.shape[0], tensor.inputs[0].shape.shape[1]]
+                    else
+                      [tensor.inputs[0].shape.shape[1], tensor.inputs[0].shape.shape[0]]
+                    end
+
+        shape2, n = if tensor.options[:transpose_b]
+                      [tensor.inputs[1].shape.shape[1], tensor.inputs[1].shape.shape[0]]
+                    else
+                      [tensor.inputs[1].shape.shape[0], tensor.inputs[1].shape.shape[1]]
+                    end
+
+        return nil if shape1.nil? || shape2.nil? || shape1 < 0 || shape2 < 0
+
+        raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{shape1} != #{shape2}) #{tensor.inputs[0].shape.shape} vs #{tensor.inputs[1].shape.shape}" if shape1 != shape2
+
+        [m, n]
+      when :transpose
+        return nil unless shape_full_specified(tensor.inputs[0])
+        return nil if tensor.inputs[1].is_a?(Tensor)
+
+        rank = tensor.inputs[0].shape.shape.size
+        perm = tensor.inputs[1] || (0...rank).to_a.reverse
+        perm.map { |p| tensor.inputs[0].shape.shape[p] }
+      when :stack
+        return nil unless shape_full_specified(tensor.inputs[0])
+
+        axis = tensor.options[:axis] || 0
+        new_shape = [tensor.inputs.size]
+        tensor.inputs[0].shape.shape.inject(new_shape) { |ns, i| ns << i }
+        rank = tensor.inputs[0].shape.shape.size + 1
+        axis = rank + axis if axis < 0
+        rotated_shape = Array.new(axis + 1) { new_shape.shift }
+        rotated_shape.rotate! + new_shape
+      when :concat
+        return nil if tensor.inputs[0].value.nil?
+
+        axis = tensor.inputs[0].value # get axis
+
+        axis_size = 0
+
+        tensor.inputs[1..tensor.inputs.size].each do |input_item|
+          return nil if input_item.shape.shape.nil?
+          return nil if input_item.shape.shape[axis].nil?
+
+          axis_size += input_item.shape.shape[axis]
+        end
+
+        new_shape = tensor.inputs[1].shape.shape.dup
+        new_shape[axis] = axis_size
+        new_shape
+      when :slice, :squeeze
+        nil
+      when :tile
+        nil
+      when :expand_dims
+        nil
+      when :broadcast_gradient_args
+        nil
+      when :no_op
+        nil
+      when :softmax_cross_entropy_with_logits_v2, :sparse_softmax_cross_entropy_with_logits
+        nil
+      when :decode_png, :flow_dynamic_stitch, :dynamic_stitch, :gather
+        nil
+      when :eye
+        return [tensor.inputs[0].const_value, tensor.inputs[1].const_value] if tensor.inputs[0].const_value && tensor.inputs[1].const_value
+
+        nil
+      when :size
+        []
+      when :unstack
+        return nil unless tensor.inputs[0].shape.known?
+
+        new_shape = tensor.inputs[0].shape.shape.dup
+        rank = new_shape.size - 1
+        axis = tensor.options[:axis] || 0
+        axis = rank + axis if axis < 0
+        rotated_shape = Array.new(axis + 1) { new_shape.shift }
+        rotated_shape.rotate!(-1) + new_shape
+      else
+        return nil if tensor.inputs[0].nil?
+        return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
+
+        TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
+      end
+    end
+  end
+end
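A quick check of the :stack shape rule in the new InferShape class: stacking four [2, 3] tensors along axis 1 should infer [2, 4, 3]. The snippet below mirrors the rotate-based computation with plain Ruby values (an illustration, not the class itself):

input_shape = [2, 3]
n_inputs    = 4
axis        = 1

new_shape = [n_inputs] + input_shape                  # [4, 2, 3]
rotated   = Array.new(axis + 1) { new_shape.shift }   # take the first axis + 1 entries
p rotated.rotate! + new_shape                         # => [2, 4, 3]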
@@ -56,8 +56,8 @@ module TensorStream

     def format_source(trace)
       grad_source = trace.select { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }.first
-      source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
-      [grad_source,
+      # source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
+      [grad_source, trace].compact.join("\n")
     end

     def shapes_fully_specified_and_equal(x, y)
data/lib/tensor_stream/images.rb
CHANGED
@@ -8,7 +8,7 @@ module TensorStream
   end

   def self.encode_png(contents, compression: -1, name: nil)
-    check_allowed_types(contents, [
+    check_allowed_types(contents, %i[uint8 uint16])
     contents = convert_to_tensor(contents, dtype: :uint16)
     _op(:encode_png, contents, compression: compression, name: name)
   end