tensor_stream 0.9.2 → 0.9.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/tensor_stream/evaluator/base_evaluator.rb +3 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +25 -0
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +24 -24
- data/lib/tensor_stream/evaluator/ruby/check_ops.rb +8 -0
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +16 -18
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +20 -4
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +9 -5
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +16 -61
- data/lib/tensor_stream/graph_builder.rb +1 -0
- data/lib/tensor_stream/graph_serializers/graphml.rb +1 -1
- data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -0
- data/lib/tensor_stream/helpers/infer_shape.rb +182 -0
- data/lib/tensor_stream/helpers/op_helper.rb +2 -2
- data/lib/tensor_stream/images.rb +1 -1
- data/lib/tensor_stream/math_gradients.rb +1 -1
- data/lib/tensor_stream/monkey_patches/array.rb +15 -0
- data/lib/tensor_stream/monkey_patches/float.rb +3 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +3 -0
- data/lib/tensor_stream/monkey_patches/patch.rb +70 -0
- data/lib/tensor_stream/nn/nn_ops.rb +43 -9
- data/lib/tensor_stream/operation.rb +2 -153
- data/lib/tensor_stream/ops.rb +71 -56
- data/lib/tensor_stream/profile/report_tool.rb +3 -3
- data/lib/tensor_stream/tensor_shape.rb +9 -6
- data/lib/tensor_stream/train/adadelta_optimizer.rb +1 -1
- data/lib/tensor_stream/train/adagrad_optimizer.rb +1 -1
- data/lib/tensor_stream/train/adam_optimizer.rb +2 -2
- data/lib/tensor_stream/train/learning_rate_decay.rb +29 -0
- data/lib/tensor_stream/train/optimizer.rb +7 -6
- data/lib/tensor_stream/train/saver.rb +1 -0
- data/lib/tensor_stream/train/slot_creator.rb +2 -2
- data/lib/tensor_stream/trainer.rb +3 -0
- data/lib/tensor_stream/utils.rb +2 -2
- data/lib/tensor_stream/version.rb +1 -1
- data/lib/tensor_stream.rb +5 -1
- data/samples/rnn.rb +108 -0
- metadata +8 -2
@@ -0,0 +1,15 @@
|
|
1
|
+
# Lets plain Ruby arrays act as constant tensors in TensorStream arithmetic.
#
# Including TensorStream::MonkeyPatch reroutes Array#+, #- and #* so they build graph ops
# when the right-hand operand is a TensorStream::Tensor and fall back to the stock Array
# behaviour otherwise. Array defines no /, % or ** of its own, so those are added here and
# always convert the receiver to a tensor before applying the operator.
class Array
  include TensorStream::MonkeyPatch

  # Element-wise division of the receiver (as a tensor) by +other+.
  def /(other)
    TensorStream.convert_to_tensor(self) / other
  end

  # Element-wise modulo of the receiver (as a tensor) by +other+.
  def %(other)
    TensorStream.convert_to_tensor(self) % other
  end

  # Element-wise power of the receiver (as a tensor) raised to +other+.
  def **(other)
    TensorStream.convert_to_tensor(self)**other
  end
end
|
@@ -0,0 +1,70 @@
|
|
1
|
+
require 'pry-byebug'
|
2
|
+
module TensorStream
  # Monkey patches that let core Ruby types (Integer, Float and, with a reduced operator set,
  # Array) take part in TensorStream graph arithmetic.
  module MonkeyPatch
    ##
    # Hook run by +include+: moves the including class's own arithmetic operators aside so the
    # wrappers defined in this module are found by method lookup instead.
    #
    # Each patched operator is aliased to +_tensor_stream_<name>_orig+ and then removed from
    # +klass+. Array only gets +, - and * patched because those are the arithmetic operators
    # it defines itself.
    def self.included(klass)
      ops = if klass == Array
              { :+ => 'add', :- => 'sub', :* => 'mul' }
            else
              { :+ => 'add', :- => 'sub', :/ => 'div', :% => 'mod', :* => 'mul', :** => 'pow' }
            end

      ops.each do |m, name|
        orig = :"_tensor_stream_#{name}_orig"
        # A repeated include (e.g. the patch file being loaded again) must be a no-op. The
        # operator was already moved, so re-aliasing would capture our wrapper as the
        # "original" (infinite recursion) and remove_method would raise NameError.
        next if klass.method_defined?(orig) || klass.private_method_defined?(orig)

        klass.send(:alias_method, orig, m)
        klass.send(:remove_method, m)
      end
    end

    ##
    # Converts the receiver into a tensor, optionally naming it.
    def t(name = nil)
      TensorStream.convert_to_tensor(self, name: name)
    end

    # Each operator below builds a graph op when +other+ is a TensorStream::Tensor
    # and otherwise calls the receiver's original operator.

    def +(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self) + other
      else
        _tensor_stream_add_orig(other)
      end
    end

    def -(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self) - other
      else
        _tensor_stream_sub_orig(other)
      end
    end

    def *(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self) * other
      else
        _tensor_stream_mul_orig(other)
      end
    end

    def /(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self) / other
      else
        _tensor_stream_div_orig(other)
      end
    end

    def %(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self) % other
      else
        _tensor_stream_mod_orig(other)
      end
    end

    def **(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self)**other
      else
        _tensor_stream_pow_orig(other)
      end
    end
  end
end
|
@@ -11,6 +11,42 @@ module TensorStream
|
|
11
11
|
TensorStream.max(features, 0, name: "relu_#{name}")
|
12
12
|
end
|
13
13
|
|
14
|
+
##
# Builds a :relu6 op over +features+ inside a "Relu6" name scope.
# +features+ is converted to a tensor named "features" first; +name+ is used both for the
# scope and for the op.
# NOTE(review): presumably computes min(max(features, 0), 6) as in TensorFlow — confirm against the evaluator.
def self.relu6(features, name: nil)
  TensorStream.name_scope(name, "Relu6", values: [features]) do
    features_tensor = TensorStream.convert_to_tensor(features, name: "features")
    _op(:relu6, features_tensor, name: name)
  end
end
|
20
|
+
|
21
|
+
##
# Computes dropout.
#
# With probability keep_prob, outputs the input element scaled up by 1 / keep_prob, otherwise outputs 0.
# The scaling is so that the expected sum is unchanged.
#
# +x+ floating point tensor (or value convertible to one) to apply dropout to
# +keep_prob+ probability of keeping each element: a Numeric in (0, 1] or a scalar tensor
# +noise_shape+ shape of the random keep/drop mask; defaults to TensorStream.shape(x)
# +seed+ forwarded to random_uniform
# +name+ optional name scope
#
# Returns +x+ unchanged when keep_prob is statically 1, otherwise x / keep_prob multiplied
# by the 0/1 mask floor(keep_prob + uniform_noise).
#
# Raises TensorStream::ValueError if +x+ is not floating point or a numeric +keep_prob+ is
# outside (0, 1].
def self.dropout(x, keep_prob, noise_shape: nil, seed: nil, name: nil)
  TensorStream.name_scope(name, "dropout", values: [x]) do
    x = TensorStream.convert_to_tensor(x, name: "x")
    raise TensorStream::ValueError, "x has to be a floating point tensor since it's going to be scaled. Got a #{x.data_type} tensor instead." unless fp_type?(x.data_type)

    # Only plain numbers can be range-checked here: comparing an Integer with a tensor
    # raises ArgumentError, which used to reject every tensor keep_prob.
    if keep_prob.is_a?(Numeric) && !(keep_prob > 0 && keep_prob <= 1)
      raise TensorStream::ValueError, "keep_prob must be a scalar tensor or a float in the range (0, 1], got #{keep_prob}"
    end

    # Keeping every element makes dropout the identity.
    return x if keep_prob.is_a?(Numeric) && keep_prob == 1

    keep_prob = TensorStream.convert_to_tensor(keep_prob, dtype: x.dtype, name: "keep_prob")
    return x if keep_prob.value == 1.0

    noise_shape ||= TensorStream.shape(x)

    # keep_prob + U[0, 1) floors to 1 with probability keep_prob and to 0 otherwise.
    random_tensor = keep_prob + TensorStream.random_uniform(noise_shape, seed: seed, dtype: x.dtype)
    binary_tensor = TensorStream.floor(random_tensor)

    TensorStream.div(x, keep_prob) * binary_tensor
  end
end
|
49
|
+
|
14
50
|
##
# Element-wise sigmoid of +input+. This is a thin alias that forwards +input+ and +name+
# to TensorStream.sigmoid.
def self.sigmoid(input, name: nil)
  forwarded = { name: name }
  TensorStream.sigmoid(input, **forwarded)
end
|
@@ -21,10 +57,10 @@ module TensorStream
|
|
21
57
|
|
22
58
|
def self.softmax_cross_entropy_with_logits_v2(labels: nil, logits: nil, name: nil)
|
23
59
|
TensorStream.name_scope(name, default: 'softmax_cross_entropy_with_logits', values: [logits, labels]) do
|
24
|
-
|
25
|
-
logits =
|
26
|
-
labels =
|
27
|
-
labels =
|
60
|
+
ts = TensorStream
|
61
|
+
logits = ts.convert_to_tensor(logits, name: 'logits')
|
62
|
+
labels = ts.convert_to_tensor(labels, name: 'labels')
|
63
|
+
labels = ts.cast(labels, logits.dtype)
|
28
64
|
|
29
65
|
output = _op(:softmax_cross_entropy_with_logits_v2, logits, labels)
|
30
66
|
output[0]
|
@@ -45,7 +81,6 @@ module TensorStream
|
|
45
81
|
raise TensorStream::ValueError, "Logits cannot be scalars - received shape #{logits.shape.shape}." if logits.shape.known? && logits.shape.scalar?
|
46
82
|
raise TensorStream::ValueError, "Rank mismatch: Rank of labels (received #{labels_static_shape.ndims}) " +
|
47
83
|
"should equal rank of logits minus 1 (received #{logits.shape.ndims})." if logits.shape.known? && (labels_static_shape.known? && labels_static_shape.ndims != logits.shape.ndims - 1)
|
48
|
-
|
49
84
|
if logits.shape.ndims == 2
|
50
85
|
cost = _op(:sparse_softmax_cross_entropy_with_logits,
|
51
86
|
precise_logits, labels, name: name)
|
@@ -57,14 +92,13 @@ module TensorStream
|
|
57
92
|
end
|
58
93
|
|
59
94
|
shape_checks = []
|
60
|
-
|
61
|
-
|
62
|
-
end
|
95
|
+
|
96
|
+
shape_checks << tf.assert_equal(tf.rank(labels), tf.rank(logits) - 1) unless static_shapes_fully_defined
|
63
97
|
|
64
98
|
tf.control_dependencies(shape_checks) do
|
65
99
|
num_classes = tf.shape(logits)[tf.rank(logits) - 1]
|
66
100
|
precise_logits = tf.reshape(precise_logits, [-1, num_classes])
|
67
|
-
labels =
|
101
|
+
labels = tf.reshape(labels, [-1])
|
68
102
|
cost = _op(:sparse_softmax_cross_entropy_with_logits, precise_logits, labels, name: name)
|
69
103
|
cost = tf.reshape(cost[0], labels_shape)
|
70
104
|
|
@@ -1,3 +1,4 @@
|
|
1
|
+
require 'tensor_stream/helpers/infer_shape'
|
1
2
|
module TensorStream
|
2
3
|
# TensorStream class that defines an operation
|
3
4
|
class Operation < Tensor
|
@@ -26,7 +27,7 @@ module TensorStream
|
|
26
27
|
@inputs = inputs.map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
|
27
28
|
@data_type = set_data_type(options[:data_type])
|
28
29
|
@is_const = infer_const
|
29
|
-
@shape = TensorShape.new(infer_shape)
|
30
|
+
@shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
|
30
31
|
@graph.add_node(self)
|
31
32
|
end
|
32
33
|
|
@@ -219,158 +220,6 @@ module TensorStream
|
|
219
220
|
|
220
221
|
private
|
221
222
|
|
222
|
-
def infer_shape
|
223
|
-
case operation
|
224
|
-
when :assign
|
225
|
-
possible_shape = if inputs[0] && inputs[0].shape.shape
|
226
|
-
inputs[0].shape.shape
|
227
|
-
else
|
228
|
-
inputs[1].shape.shape
|
229
|
-
end
|
230
|
-
|
231
|
-
possible_shape
|
232
|
-
when :index
|
233
|
-
return nil unless inputs[0].is_a?(Tensor)
|
234
|
-
return nil unless inputs[0].const_value
|
235
|
-
|
236
|
-
input_shape = inputs[0].shape
|
237
|
-
return nil unless input_shape.known?
|
238
|
-
|
239
|
-
s = input_shape.shape.dup
|
240
|
-
s.shift
|
241
|
-
s
|
242
|
-
when :mean, :prod, :sum
|
243
|
-
return [] if inputs[1].nil?
|
244
|
-
return nil if inputs[0].nil?
|
245
|
-
return nil unless inputs[0].shape.known?
|
246
|
-
|
247
|
-
input_shape = inputs[0].shape.shape
|
248
|
-
rank = input_shape.size
|
249
|
-
|
250
|
-
axis = inputs[1].const_value
|
251
|
-
return nil if axis.nil?
|
252
|
-
|
253
|
-
axis = [axis] unless axis.is_a?(Array)
|
254
|
-
axis = axis.map { |a| a < 0 ? rank - a.abs : a }
|
255
|
-
|
256
|
-
input_shape.each_with_index.map do |item, index|
|
257
|
-
if axis.include?(index)
|
258
|
-
next 1 if options[:keepdims]
|
259
|
-
next nil
|
260
|
-
end
|
261
|
-
item
|
262
|
-
end.compact
|
263
|
-
when :reshape
|
264
|
-
new_shape = inputs[1] && inputs[1].value ? inputs[1].value : nil
|
265
|
-
return nil if new_shape.nil?
|
266
|
-
return nil if inputs[0].shape.nil?
|
267
|
-
|
268
|
-
input_shape = inputs[0].shape.shape
|
269
|
-
return new_shape if input_shape.nil?
|
270
|
-
return nil if input_shape.include?(nil)
|
271
|
-
TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
|
272
|
-
when :flow_group
|
273
|
-
[]
|
274
|
-
when :zeros, :ones, :fill, :random_standard_normal, :random_uniform
|
275
|
-
a_shape = inputs[0] ? inputs[0].const_value : options[:shape]
|
276
|
-
return nil if a_shape.nil?
|
277
|
-
a_shape.is_a?(Array) ? a_shape : [a_shape]
|
278
|
-
when :zeros_like, :ones_like
|
279
|
-
inputs[0].shape.shape
|
280
|
-
when :shape
|
281
|
-
inputs[0].shape.shape ? [inputs[0].shape.shape.size] : nil
|
282
|
-
when :mat_mul
|
283
|
-
return nil if inputs[0].shape.shape.nil? || inputs[1].shape.shape.nil?
|
284
|
-
return [] if inputs[0].shape.shape.empty? || inputs[1].shape.shape.empty?
|
285
|
-
return nil if inputs[0].shape.shape.size != 2 || inputs[1].shape.shape.size != 2
|
286
|
-
|
287
|
-
shape1, m = if options[:transpose_a]
|
288
|
-
[inputs[0].shape.shape[0], inputs[0].shape.shape[1]]
|
289
|
-
else
|
290
|
-
[inputs[0].shape.shape[1], inputs[0].shape.shape[0]]
|
291
|
-
end
|
292
|
-
|
293
|
-
shape2, n = if options[:transpose_b]
|
294
|
-
[inputs[1].shape.shape[1], inputs[1].shape.shape[0]]
|
295
|
-
else
|
296
|
-
[inputs[1].shape.shape[0], inputs[1].shape.shape[1]]
|
297
|
-
end
|
298
|
-
|
299
|
-
return nil if shape1.nil? || shape2.nil? || shape1 < 0 || shape2 < 0
|
300
|
-
|
301
|
-
raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{shape1} != #{shape2}) #{inputs[0].shape.shape} vs #{inputs[1].shape.shape}" if shape1 != shape2
|
302
|
-
|
303
|
-
[m, n]
|
304
|
-
when :transpose
|
305
|
-
return nil unless shape_full_specified(inputs[0])
|
306
|
-
return nil if inputs[1].is_a?(Tensor)
|
307
|
-
|
308
|
-
rank = inputs[0].shape.shape.size
|
309
|
-
perm = inputs[1] || (0...rank).to_a.reverse
|
310
|
-
perm.map { |p| inputs[0].shape.shape[p] }
|
311
|
-
when :stack
|
312
|
-
return nil unless shape_full_specified(inputs[0])
|
313
|
-
|
314
|
-
axis = options[:axis] || 0
|
315
|
-
new_shape = [inputs.size]
|
316
|
-
inputs[0].shape.shape.inject(new_shape) { |ns, s| ns << s }
|
317
|
-
rank = inputs[0].shape.shape.size + 1
|
318
|
-
axis = rank + axis if axis < 0
|
319
|
-
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
320
|
-
rotated_shape.rotate! + new_shape
|
321
|
-
when :concat
|
322
|
-
return nil if inputs[0].value.nil?
|
323
|
-
|
324
|
-
axis = inputs[0].value # get axis
|
325
|
-
|
326
|
-
axis_size = 0
|
327
|
-
|
328
|
-
inputs[1..inputs.size].each do |input_item|
|
329
|
-
return nil if input_item.shape.shape.nil?
|
330
|
-
return nil if input_item.shape.shape[axis].nil?
|
331
|
-
|
332
|
-
axis_size += input_item.shape.shape[axis]
|
333
|
-
end
|
334
|
-
|
335
|
-
new_shape = inputs[1].shape.shape.dup
|
336
|
-
new_shape[axis] = axis_size
|
337
|
-
new_shape
|
338
|
-
when :slice, :squeeze
|
339
|
-
nil
|
340
|
-
when :tile
|
341
|
-
nil
|
342
|
-
when :expand_dims
|
343
|
-
nil
|
344
|
-
when :broadcast_gradient_args
|
345
|
-
nil
|
346
|
-
when :no_op
|
347
|
-
nil
|
348
|
-
when :softmax_cross_entropy_with_logits_v2, :sparse_softmax_cross_entropy_with_logits
|
349
|
-
nil
|
350
|
-
when :decode_png, :flow_dynamic_stitch, :dynamic_stitch, :gather
|
351
|
-
nil
|
352
|
-
when :eye
|
353
|
-
return [inputs[0].const_value, inputs[1].const_value] if inputs[0].const_value && inputs[1].const_value
|
354
|
-
|
355
|
-
nil
|
356
|
-
when :size
|
357
|
-
[]
|
358
|
-
when :unstack
|
359
|
-
return nil unless inputs[0].shape.known?
|
360
|
-
|
361
|
-
new_shape = inputs[0].shape.shape.dup
|
362
|
-
rank = new_shape.size - 1
|
363
|
-
axis = options[:axis] || 0
|
364
|
-
axis = rank + axis if axis < 0
|
365
|
-
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
366
|
-
rotated_shape.rotate!(-1) + new_shape
|
367
|
-
else
|
368
|
-
return nil if inputs[0].nil?
|
369
|
-
return inputs[0].shape.shape if inputs.size == 1
|
370
|
-
TensorShape.infer_shape(inputs[0].shape.shape, inputs[1].shape.shape) if inputs.size == 2 && inputs[0] && inputs[1] && inputs[0].shape.known? && inputs[1].shape.known?
|
371
|
-
end
|
372
|
-
end
|
373
|
-
|
374
223
|
def propagate_consumer(consumer)
|
375
224
|
super
|
376
225
|
@inputs.compact.each do |input|
|
data/lib/tensor_stream/ops.rb
CHANGED
@@ -19,7 +19,7 @@ module TensorStream
|
|
19
19
|
# +axis+ Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0
|
20
20
|
# +output_type+ Output data type defaults to int32
|
21
21
|
def argmax(input, axis = nil, name: nil, dimension: nil, output_type: :int32)
|
22
|
-
_op(:argmax, input,
|
22
|
+
_op(:argmax, input, axis, name: name, dimension: dimension, data_type: output_type)
|
23
23
|
end
|
24
24
|
|
25
25
|
##
|
@@ -31,7 +31,21 @@ module TensorStream
|
|
31
31
|
# +axis+ Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0
|
32
32
|
# +output_type+ Output data type defaults to int32
|
33
33
|
def argmin(input, axis = nil, name: nil, dimension: nil, output_type: :int32)
|
34
|
-
_op(:argmin, input,
|
34
|
+
_op(:argmin, input, axis, name: name, dimension: dimension, data_type: output_type)
|
35
|
+
end
|
36
|
+
|
37
|
+
##
# Assert the condition x == y holds element-wise.
#
# Arguments
#
# +x+ Numeric Tensor.
# +y+ Numeric Tensor, same dtype as and broadcastable to x.
# +data+, +summarize+, +message+ Forwarded unchanged as options of the :assert_equal op.
# +name+ Optional name for the op.
#
# Returns
# Op that raises InvalidArgumentError if x == y is false
def assert_equal(x, y, data: nil, summarize: nil, message: nil, name: nil)
  op_options = { data: data, summarize: summarize, message: message, name: name }
  _op(:assert_equal, x, y, **op_options)
end
|
36
50
|
|
37
51
|
##
|
@@ -67,15 +81,15 @@ module TensorStream
|
|
67
81
|
##
|
68
82
|
# Outputs random values from a uniform distribution.
|
69
83
|
def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
|
70
|
-
options = {
|
71
|
-
_op(:random_uniform,
|
84
|
+
options = { dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
|
85
|
+
_op(:random_uniform, shape, nil, options)
|
72
86
|
end
|
73
87
|
|
74
88
|
##
|
75
89
|
# Outputs random values from a normal distribution.
|
76
90
|
def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
|
77
|
-
options = {
|
78
|
-
_op(:random_standard_normal,
|
91
|
+
options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
|
92
|
+
_op(:random_standard_normal, shape, nil, options)
|
79
93
|
end
|
80
94
|
|
81
95
|
##
|
@@ -83,7 +97,7 @@ module TensorStream
|
|
83
97
|
#
|
84
98
|
# When executed in a graph, this op outputs its input tensor as-is.
|
85
99
|
def stop_gradient(tensor, options = {})
|
86
|
-
_op(:stop_gradient, tensor,
|
100
|
+
_op(:stop_gradient, tensor, options)
|
87
101
|
end
|
88
102
|
|
89
103
|
##
|
@@ -190,13 +204,13 @@ module TensorStream
|
|
190
204
|
##
|
191
205
|
# Creates a tensor with all elements set to zero
|
192
206
|
def zeros(shape, dtype: :float32, name: nil)
|
193
|
-
_op(:zeros, shape,
|
207
|
+
_op(:zeros, shape, data_type: dtype, name: name)
|
194
208
|
end
|
195
209
|
|
196
210
|
##
|
197
211
|
# Creates a tensor with all elements set to 1.
|
198
212
|
def ones(shape, dtype: :float32, name: nil)
|
199
|
-
_op(:ones, shape,
|
213
|
+
_op(:ones, shape, data_type: dtype, name: name)
|
200
214
|
end
|
201
215
|
|
202
216
|
##
|
@@ -302,37 +316,38 @@ module TensorStream
|
|
302
316
|
res = _op(:split, value, num_or_size_splits, axis, name: name)
|
303
317
|
|
304
318
|
pieces = if value.shape.known? && num_or_size_splits.is_const && num_or_size_splits.value && axis.is_const
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
elsif num_or_size_splits.shape.ndims == 1
|
316
|
-
raise TensorStream::ValueError, "Sum of splits do not match total dimen in axis #{value.shape.shape[axis.value]} != #{num_or_size_splits.value.reduce(:+)}" if value.shape.shape[axis.value] != num_or_size_splits.value.reduce(:+)
|
317
|
-
num_or_size_splits.value.collect do |v|
|
318
|
-
new_shape = value.shape.shape.dup
|
319
|
-
new_shape[axis.value] = v
|
320
|
-
new_shape
|
321
|
-
end
|
322
|
-
else
|
323
|
-
raise TensorStream::ValueError, "Scalar or 1D Tensor expected for num_or_size_splits"
|
319
|
+
if num_or_size_splits.shape.scalar?
|
320
|
+
raise TensorStream::ValueError, "num_or_size_splits must divide dimension #{value.shape.shape[axis.value]} evenly" unless (value.shape.shape[axis.value] % num_or_size_splits.value).zero?
|
321
|
+
|
322
|
+
div = num_or_size_splits.value
|
323
|
+
n = value.shape.shape[axis.value] / div
|
324
|
+
|
325
|
+
Array.new(div) do
|
326
|
+
new_shape = value.shape.shape.dup
|
327
|
+
new_shape[axis.value] = n
|
328
|
+
new_shape
|
324
329
|
end
|
325
|
-
|
326
|
-
raise TensorStream::ValueError, "
|
330
|
+
elsif num_or_size_splits.shape.ndims == 1
|
331
|
+
raise TensorStream::ValueError, "Sum of splits do not match total dimen in axis #{value.shape.shape[axis.value]} != #{num_or_size_splits.value.reduce(:+)}" if value.shape.shape[axis.value] != num_or_size_splits.value.reduce(:+)
|
327
332
|
|
328
|
-
|
333
|
+
num_or_size_splits.value.collect do |v|
|
334
|
+
new_shape = value.shape.shape.dup
|
335
|
+
new_shape[axis.value] = v
|
336
|
+
new_shape
|
337
|
+
end
|
338
|
+
else
|
339
|
+
raise TensorStream::ValueError, "Scalar or 1D Tensor expected for num_or_size_splits"
|
329
340
|
end
|
341
|
+
else
|
342
|
+
raise TensorStream::ValueError, "Cannot automatically determine num, please specify num: in options" if num.nil?
|
343
|
+
|
344
|
+
Array.new(num) { nil }
|
345
|
+
end
|
330
346
|
|
331
347
|
pieces.collect.with_index do |shape, i|
|
332
348
|
op = index(res, i, name: "split/index:#{i}")
|
333
|
-
if shape
|
334
|
-
|
335
|
-
end
|
349
|
+
op.shape = TensorShape.new(shape) if shape
|
350
|
+
|
336
351
|
op
|
337
352
|
end
|
338
353
|
end
|
@@ -354,20 +369,20 @@ module TensorStream
|
|
354
369
|
##
|
355
370
|
# Computes square of x element-wise.
|
356
371
|
def square(tensor, name: nil)
|
357
|
-
_op(:square, tensor,
|
372
|
+
_op(:square, tensor, name: name)
|
358
373
|
end
|
359
374
|
|
360
375
|
##
|
361
376
|
# Rounds the values of a tensor to the nearest integer, element-wise
|
362
377
|
def round(tensor, name: nil)
|
363
378
|
check_allowed_types(tensor, FLOATING_POINT_TYPES)
|
364
|
-
_op(:round, tensor,
|
379
|
+
_op(:round, tensor, name: name)
|
365
380
|
end
|
366
381
|
|
367
382
|
##
|
368
383
|
# Computes the reciprocal of x element-wise.
|
369
384
|
def reciprocal(tensor, name: nil)
|
370
|
-
_op(:reciprocal, tensor,
|
385
|
+
_op(:reciprocal, tensor, name: name)
|
371
386
|
end
|
372
387
|
|
373
388
|
##
|
@@ -495,7 +510,7 @@ module TensorStream
|
|
495
510
|
input = convert_to_tensor(input)
|
496
511
|
return input if input.data_type == dtype
|
497
512
|
|
498
|
-
_op(:cast, input,
|
513
|
+
_op(:cast, input, data_type: dtype, name: name)
|
499
514
|
end
|
500
515
|
|
501
516
|
##
|
@@ -509,7 +524,7 @@ module TensorStream
|
|
509
524
|
##
|
510
525
|
# Computes numerical negative value element-wise.
|
511
526
|
def negate(input, name: nil)
|
512
|
-
_op(:negate, input,
|
527
|
+
_op(:negate, input, name: name)
|
513
528
|
end
|
514
529
|
|
515
530
|
##
|
@@ -539,7 +554,7 @@ module TensorStream
|
|
539
554
|
# of the same type and shape as tensor with all elements set to zero.
|
540
555
|
# Optionally, you can use dtype to specify a new type for the returned tensor.
|
541
556
|
def zeros_like(tensor, dtype: nil, name: nil)
|
542
|
-
_op(:zeros_like, tensor,
|
557
|
+
_op(:zeros_like, tensor, data_type: dtype, name: name)
|
543
558
|
end
|
544
559
|
|
545
560
|
##
|
@@ -548,13 +563,13 @@ module TensorStream
|
|
548
563
|
# tensor of the same type and shape as tensor with all elements set to 1.
|
549
564
|
# Optionally, you can specify a new type (dtype) for the returned tensor.
|
550
565
|
def ones_like(tensor, dtype: nil, name: nil)
|
551
|
-
_op(:ones_like, tensor,
|
566
|
+
_op(:ones_like, tensor, data_type: dtype, name: name)
|
552
567
|
end
|
553
568
|
|
554
569
|
##
|
555
570
|
# Return a tensor with the same shape and contents as input.
|
556
571
|
def identity(input, name: nil)
|
557
|
-
_op(:identity, input,
|
572
|
+
_op(:identity, input, name: name)
|
558
573
|
end
|
559
574
|
|
560
575
|
##
|
@@ -591,7 +606,7 @@ module TensorStream
|
|
591
606
|
##
|
592
607
|
# Computes the absolute value of a tensor.
|
593
608
|
def abs(input, name: nil)
|
594
|
-
_op(:abs, input,
|
609
|
+
_op(:abs, input, name: name)
|
595
610
|
end
|
596
611
|
|
597
612
|
##
|
@@ -599,63 +614,63 @@ module TensorStream
|
|
599
614
|
# y = sign(x) = -1 if x < 0; 0 if x == 0 or tf.is_nan(x); 1 if x > 0.
|
600
615
|
# Zero is returned for NaN inputs.
|
601
616
|
def sign(input, name: nil)
|
602
|
-
_op(:sign, input,
|
617
|
+
_op(:sign, input, name: name)
|
603
618
|
end
|
604
619
|
|
605
620
|
##
|
606
621
|
# Computes sin of input element-wise.
|
607
622
|
def sin(input, name: nil)
|
608
623
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
609
|
-
_op(:sin, input,
|
624
|
+
_op(:sin, input, name: name)
|
610
625
|
end
|
611
626
|
|
612
627
|
##
|
613
628
|
# Computes cos of input element-wise.
|
614
629
|
def cos(input, name: nil)
|
615
630
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
616
|
-
_op(:cos, input,
|
631
|
+
_op(:cos, input, name: name)
|
617
632
|
end
|
618
633
|
|
619
634
|
##
|
620
635
|
# Computes tan of input element-wise.
|
621
636
|
def tan(input, name: nil)
|
622
637
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
623
|
-
_op(:tan, input,
|
638
|
+
_op(:tan, input, name: name)
|
624
639
|
end
|
625
640
|
|
626
641
|
##
|
627
642
|
# Computes tanh of input element-wise.
|
628
643
|
def tanh(input, name: nil)
|
629
644
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
630
|
-
_op(:tanh, input,
|
645
|
+
_op(:tanh, input, name: name)
|
631
646
|
end
|
632
647
|
|
633
648
|
##
|
634
649
|
# Computes sqrt of input element-wise.
|
635
650
|
def sqrt(input, name: nil)
|
636
651
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
637
|
-
_op(:sqrt, input,
|
652
|
+
_op(:sqrt, input, name: name)
|
638
653
|
end
|
639
654
|
|
640
655
|
##
|
641
656
|
# Computes natural logarithm of x element-wise.
|
642
657
|
def log(input, name: nil)
|
643
658
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
644
|
-
_op(:log, input,
|
659
|
+
_op(:log, input, name: name)
|
645
660
|
end
|
646
661
|
|
647
662
|
##
|
648
663
|
# Computes natural logarithm of (1 + x) element-wise.
|
649
664
|
def log1p(input, name: nil)
|
650
665
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
651
|
-
_op(:log1p, input,
|
666
|
+
_op(:log1p, input, name: name)
|
652
667
|
end
|
653
668
|
|
654
669
|
##
|
655
670
|
# Computes exponential of x element-wise.
|
656
671
|
def exp(input, name: nil)
|
657
672
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
658
|
-
_op(:exp, input,
|
673
|
+
_op(:exp, input, name: name)
|
659
674
|
end
|
660
675
|
|
661
676
|
##
|
@@ -675,7 +690,7 @@ module TensorStream
|
|
675
690
|
# Computes sigmoid of x element-wise.
|
676
691
|
def sigmoid(input, name: nil)
|
677
692
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
678
|
-
_op(:sigmoid, input,
|
693
|
+
_op(:sigmoid, input, name: name)
|
679
694
|
end
|
680
695
|
|
681
696
|
##
|
@@ -698,14 +713,14 @@ module TensorStream
|
|
698
713
|
# Pads a tensor.
|
699
714
|
# This operation pads a tensor according to the paddings you specify.
|
700
715
|
def pad(tensor, paddings, mode: 'CONSTANT', name: nil)
|
701
|
-
_op(:pad, tensor,
|
716
|
+
_op(:pad, tensor, paddings, mode: mode, name: name)
|
702
717
|
end
|
703
718
|
|
704
719
|
##
|
705
720
|
# Checks a tensor for NaN and Inf values.
|
706
721
|
# When run, reports an InvalidArgument error if tensor has any values that are not a number (NaN) or infinity (Inf). Otherwise, passes tensor as-is.
|
707
722
|
def check_numerics(tensor, message, name: nil)
|
708
|
-
_op(:check_numerics, tensor,
|
723
|
+
_op(:check_numerics, tensor, message: message, name: name)
|
709
724
|
end
|
710
725
|
|
711
726
|
def size(tensor, name: nil, out_type: :int32)
|
@@ -730,7 +745,6 @@ module TensorStream
|
|
730
745
|
_op(:gather, params, indices, validate_indices: validate_indices, name: name, axis: axis)
|
731
746
|
end
|
732
747
|
|
733
|
-
|
734
748
|
##
|
735
749
|
# Stacks a list of rank-R tensors into one rank-(R+1) tensor.
|
736
750
|
#
|
@@ -753,6 +767,7 @@ module TensorStream
|
|
753
767
|
new_shape[0]
|
754
768
|
else
|
755
769
|
raise TensorStream::ValueError, "num is unspecified and cannot be inferred." if num.nil?
|
770
|
+
|
756
771
|
num
|
757
772
|
end
|
758
773
|
|