tensor_stream 0.8.5 → 0.8.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +9 -7
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +17 -2
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +92 -10
- data/lib/tensor_stream/evaluator/ruby/check_ops.rb +9 -0
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +1 -1
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +38 -38
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +87 -12
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +16 -13
- data/lib/tensor_stream/graph.rb +2 -0
- data/lib/tensor_stream/helpers/op_helper.rb +1 -0
- data/lib/tensor_stream/math_gradients.rb +86 -5
- data/lib/tensor_stream/nn/nn_ops.rb +47 -0
- data/lib/tensor_stream/operation.rb +25 -4
- data/lib/tensor_stream/ops.rb +160 -6
- data/lib/tensor_stream/session.rb +1 -0
- data/lib/tensor_stream/tensor.rb +4 -7
- data/lib/tensor_stream/tensor_shape.rb +10 -1
- data/lib/tensor_stream/train/adagrad_optimizer.rb +46 -0
- data/lib/tensor_stream/train/optimizer.rb +12 -0
- data/lib/tensor_stream/train/rmsprop_optimizer.rb +84 -0
- data/lib/tensor_stream/train/slot_creator.rb +14 -9
- data/lib/tensor_stream/trainer.rb +2 -0
- data/lib/tensor_stream/utils.rb +6 -4
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +4 -3
- data/samples/linear_regression.rb +3 -0
- data/samples/rnn.rb +105 -0
- metadata +6 -2
@@ -7,6 +7,7 @@ require 'tensor_stream/evaluator/ruby/nn_ops'
|
|
7
7
|
require 'tensor_stream/evaluator/ruby/array_ops'
|
8
8
|
require 'tensor_stream/evaluator/ruby/random_ops'
|
9
9
|
require 'tensor_stream/evaluator/ruby/images_ops'
|
10
|
+
require 'tensor_stream/evaluator/ruby/check_ops'
|
10
11
|
|
11
12
|
module TensorStream
|
12
13
|
module Evaluator
|
@@ -39,6 +40,7 @@ module TensorStream
|
|
39
40
|
include TensorStream::ArrayOps
|
40
41
|
include TensorStream::RandomOps
|
41
42
|
include TensorStream::ImagesOps
|
43
|
+
include TensorStream::CheckOps
|
42
44
|
|
43
45
|
def run(tensor, execution_context)
|
44
46
|
return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array) && !tensor.empty? && tensor[0].is_a?(Tensor)
|
@@ -116,10 +118,10 @@ module TensorStream
|
|
116
118
|
end
|
117
119
|
|
118
120
|
register_op(:cast) do |context, tensor, inputs|
  # Converts the first input (fully evaluated by call_op) to the data type declared on the op.
  to_target_type = ->(value, _ignored) { Tensor.cast_dtype(value, tensor.data_type) }
  call_op(tensor, inputs[0], context, to_target_type)
end
|
121
123
|
|
122
|
-
register_op(:sign) do |context,
|
124
|
+
register_op(:sign) do |context, tensor, inputs|
|
123
125
|
func = lambda { |x, _b|
|
124
126
|
if x.zero? || (x.is_a?(Float) && x.nan?)
|
125
127
|
0
|
@@ -132,7 +134,7 @@ module TensorStream
|
|
132
134
|
end
|
133
135
|
}
|
134
136
|
|
135
|
-
call_op(
|
137
|
+
call_op(tensor, inputs[0], context, func)
|
136
138
|
end
|
137
139
|
|
138
140
|
register_op(:logical_and) do |context, tensor, inputs|
|
@@ -217,10 +219,6 @@ module TensorStream
|
|
217
219
|
call_vector_op(tensor, :greater_equal, a, b, context, ->(t, u) { t <= u })
|
218
220
|
end
|
219
221
|
|
220
|
-
register_op :shape do |_context, tensor, inputs|
|
221
|
-
shape_eval(inputs[0], tensor.options[:out_type])
|
222
|
-
end
|
223
|
-
|
224
222
|
register_op :broadcast_transform do |_context, _tensor, inputs|
  # Hands the op's first two operands to the broadcast helper.
  left = inputs[0]
  right = inputs[1]
  broadcast(left, right)
end
|
@@ -266,14 +264,13 @@ module TensorStream
|
|
266
264
|
raise TensorStream::InvalidArgumentError, "#{message} Invalid argument" if t.nan? || t.infinite?
|
267
265
|
t
|
268
266
|
}
|
269
|
-
call_op(
|
267
|
+
call_op(tensor, inputs[0], context, f)
|
270
268
|
end
|
271
269
|
|
272
270
|
def eval_operation(tensor, child_context)
|
273
271
|
return @context[tensor.name] if @context.key?(tensor.name)
|
274
|
-
# puts "ruby: #{tensor.name}"
|
275
272
|
invoke(tensor, child_context).tap do |result|
|
276
|
-
|
273
|
+
# puts "ruby: #{tensor.name}"
|
277
274
|
if tensor.breakpoint
|
278
275
|
a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
|
279
276
|
b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
|
@@ -412,8 +409,6 @@ module TensorStream
|
|
412
409
|
# Evaluates +a+ completely within +child_context+, then passes the result and
# +func+ to process_function_op. The first (op) argument is accepted for
# call-site compatibility but is not used.
def call_op(_operation, a, child_context, func)
  evaluated = complete_eval(a, child_context)
  process_function_op(evaluated, func)
end
|
418
413
|
|
419
414
|
def call_vector_op(tensor, op, a, b, child_context, func)
|
@@ -484,7 +479,15 @@ module TensorStream
|
|
484
479
|
|
485
480
|
var = if placeholder.is_a?(Placeholder)
|
486
481
|
@context[placeholder.name.to_sym].tap do |c|
|
487
|
-
raise "missing placeholder #{placeholder.name}" if c.nil?
|
482
|
+
raise TensorStream::ValueError, "missing placeholder #{placeholder.name}" if c.nil?
|
483
|
+
if placeholder.shape.shape
|
484
|
+
value_shape = shape_eval(c)
|
485
|
+
placeholder_shape = placeholder.shape.shape
|
486
|
+
placeholder_shape.zip(value_shape).each do |p_shape, v_shape|
|
487
|
+
next if p_shape.nil?
|
488
|
+
raise TensorStream::ValueError, "placeholder expects #{placeholder_shape}, got #{value_shape}" if p_shape != v_shape
|
489
|
+
end
|
490
|
+
end
|
488
491
|
end
|
489
492
|
else
|
490
493
|
placeholder
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -90,8 +90,10 @@ module TensorStream
|
|
90
90
|
@node_keys << node.name
|
91
91
|
@nodes[node.name] = node
|
92
92
|
@constants[node.name] = node if node.is_const
|
93
|
+
# puts "adding node"
|
93
94
|
node.send(:propagate_outputs)
|
94
95
|
node.send(:propagate_consumer, node)
|
96
|
+
# puts "#{node.name}"
|
95
97
|
node.value = node.eval if @eager_execution
|
96
98
|
end
|
97
99
|
|
@@ -10,7 +10,7 @@ module TensorStream
|
|
10
10
|
|
11
11
|
def self.derivative(tensor, wrt_dx, options = {})
|
12
12
|
return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
|
13
|
-
return i_op(:zeros_like,
|
13
|
+
return i_op(:zeros_like, wrt_dx) unless wrt_dx.consumers.include?(tensor.name)
|
14
14
|
|
15
15
|
nodes_to_compute = wrt_dx.consumers.select do |t|
|
16
16
|
node = tensor.graph.nodes[t]
|
@@ -30,12 +30,15 @@ module TensorStream
|
|
30
30
|
computed_op = _compute_derivative(tensor, grad)
|
31
31
|
|
32
32
|
if computed_op.is_a?(Array)
|
33
|
-
computed_op.each_with_index.collect do |op_grad, index|
|
33
|
+
grads = computed_op.each_with_index.collect do |op_grad, index|
|
34
34
|
next if op_grad.nil?
|
35
35
|
next unless nodes_to_compute.include?(tensor.inputs[index].name)
|
36
36
|
|
37
37
|
_propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
|
38
|
-
end.compact
|
38
|
+
end.compact
|
39
|
+
|
40
|
+
return nil if grads.empty?
|
41
|
+
grads.size > 1 ? ts.add_n(grads) : grads[0]
|
39
42
|
else
|
40
43
|
return nil if computed_op.nil?
|
41
44
|
_propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
|
@@ -260,7 +263,11 @@ module TensorStream
|
|
260
263
|
i_op(:softmax_grad, x, grad)
|
261
264
|
when :softmax_cross_entropy_with_logits_v2
|
262
265
|
output = node
|
263
|
-
|
266
|
+
logits = node.inputs[0]
|
267
|
+
[_broadcast_mul(grad, output[1]), -ts.nn.log_softmax(logits)]
|
268
|
+
when :sparse_softmax_cross_entropy_with_logits
|
269
|
+
output = node
|
270
|
+
[_broadcast_mul(grad, output[1]), nil]
|
264
271
|
when :floor, :ceil
|
265
272
|
# non differentiable
|
266
273
|
nil
|
@@ -273,13 +280,51 @@ module TensorStream
|
|
273
280
|
when :transpose
|
274
281
|
return [ts.transpose(grad, ts.invert_permutation(y)), nil]
|
275
282
|
when :index
|
276
|
-
|
283
|
+
#hack!! not sure how to fix this yet
|
284
|
+
return grad if %i[softmax_cross_entropy_with_logits_v2 sparse_softmax_cross_entropy_with_logits].include?(node.inputs[0].operation)
|
285
|
+
|
286
|
+
if node.inputs[0].shape.known? && node.inputs[1].value
|
287
|
+
multiplier = node.inputs[0].shape.shape[0]
|
288
|
+
filler = ts.zeros_like(grad)
|
289
|
+
|
290
|
+
res = Array.new(multiplier) { |index|
|
291
|
+
index == node.inputs[1].value ? grad : filler
|
292
|
+
}
|
293
|
+
[res]
|
294
|
+
end
|
295
|
+
when :squeeze
|
296
|
+
_reshape_to_input(node, grad)
|
297
|
+
when :expand_dims
|
298
|
+
[_reshape_to_input(node, grad), nil]
|
299
|
+
when :concat
|
300
|
+
_concat_grad_helper(node, grad, 1, node.inputs.size, 0)
|
301
|
+
when :reshape
|
302
|
+
[ts.reshape(grad, ts.shape(node.inputs[0])), nil]
|
303
|
+
when :stack
|
304
|
+
res = ts.unstack(grad, num: node.inputs.size, axis: node.options[:axis])
|
305
|
+
Array.new(node.inputs.size) { |i| res[i] }
|
306
|
+
when :unstack
|
307
|
+
ts.stack(grad, axis: node.options[:axis])
|
308
|
+
when :cast
|
309
|
+
t = %i[float16 float32 float64]
|
310
|
+
src_type = node.inputs[0].data_type
|
311
|
+
dst_type = grad.data_type
|
312
|
+
|
313
|
+
if t.key?(src_type) && t.key?(dst_type)
|
314
|
+
ts.cast(grad, src_type)
|
315
|
+
else
|
316
|
+
nil
|
317
|
+
end
|
277
318
|
else
|
278
319
|
raise "no derivative op for #{node.operation}"
|
279
320
|
end
|
280
321
|
end
|
281
322
|
end
|
282
323
|
|
324
|
+
# Reshapes +grad+ to the dynamic shape of the node's first input.
def self._reshape_to_input(node, grad)
  input_shape = ts.shape(node.inputs[0])
  ts.reshape(grad, input_shape)
end
|
327
|
+
|
283
328
|
def self._broadcast_gradient_args(input_a, input_b)
|
284
329
|
res = _op(:broadcast_gradient_args, input_a, input_b)
|
285
330
|
[res[0], res[1]]
|
@@ -335,5 +380,41 @@ module TensorStream
|
|
335
380
|
arr.each { |a| return true if a.equal?(obj) }
|
336
381
|
false
|
337
382
|
end
|
383
|
+
|
384
|
+
# Returns the shapes of +inputs+.
#
# If every ts.shape result is a constant, their values are returned as a plain
# Array. As soon as one is not constant, a single ts.shape_n over all inputs is
# returned instead.
def self._extract_input_shapes(inputs)
  static_sizes = []
  inputs.each do |x|
    input_shape = ts.shape(x)
    return ts.shape_n(inputs) unless input_shape.is_const

    static_sizes << input_shape.value
  end
  static_sizes
end
|
402
|
+
|
403
|
+
# Splits the incoming +grad+ of a concat op back into one gradient per value input.
#
# +start_value_index+..+end_value_index+ selects the value inputs of +op+.
# +dim_index+ is the position of the axis input, which always receives a nil
# gradient; the nil goes after the value gradients when the axis comes last.
def self._concat_grad_helper(op, grad, start_value_index, end_value_index, dim_index)
  axis_after_values = end_value_index <= dim_index

  # Degenerate concatenation (axis plus a single value): the gradient passes straight through.
  if op.inputs.size == 2
    return axis_after_values ? [grad, nil] : [nil, grad]
  end

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index..end_value_index]
  non_neg_concat_dim = concat_dim % ts.rank(input_values[0])
  shapes = _extract_input_shapes(input_values)

  # Pick out the size of every value along the concat dimension.
  stacked = ts.stack(shapes, axis: 1)
  sizes = ts.squeeze(ts.slice(stacked, [non_neg_concat_dim, 0], [1, -1]))

  out_grads = ts.split(grad, sizes, axis: non_neg_concat_dim, num: op.inputs.size - 1)
  axis_after_values ? out_grads + [nil] : [nil] + out_grads
end
|
338
419
|
end
|
339
420
|
end
|
@@ -25,11 +25,58 @@ module TensorStream
|
|
25
25
|
logits = tf.convert_to_tensor(logits, name: 'logits')
|
26
26
|
labels = tf.convert_to_tensor(labels, name: 'labels')
|
27
27
|
labels = tf.cast(labels, logits.dtype)
|
28
|
+
|
28
29
|
output = _op(:softmax_cross_entropy_with_logits_v2, logits, labels)
|
29
30
|
output[0]
|
30
31
|
end
|
31
32
|
end
|
32
33
|
|
34
|
+
##
# Computes sparse softmax cross entropy between +logits+ and +labels+.
#
# +labels+ holds class indices and must have a rank one less than +logits+.
# float16 logits are evaluated in float32, and the cost is cast back to float16.
# Rank-2 logits go straight to the fused op. Other ranks are flattened to
# [-1, num_classes], evaluated, and reshaped back to the shape of +labels+.
#
# Raises TensorStream::ValueError when logits are scalars or when the static
# ranks of labels and logits are inconsistent.
def self.sparse_softmax_cross_entropy_with_logits(labels: nil, logits: nil, name: nil)
  TensorStream.name_scope(name, default: "SparseSoftmaxCrossEntropyWithLogits", values: [logits, labels]) do
    tf = TensorStream
    labels = tf.convert_to_tensor(labels)
    logits = tf.convert_to_tensor(logits)
    precise_logits = logits.data_type == :float16 ? tf.cast(logits, :float32) : logits

    labels_static_shape = labels.shape
    labels_shape = tf.shape(labels)
    static_shapes_fully_defined = labels_static_shape.known? && logits.shape.known?

    raise TensorStream::ValueError, "Logits cannot be scalars - received shape #{logits.shape.shape}." if logits.shape.known? && logits.shape.scalar?
    if logits.shape.known? && labels_static_shape.known? && labels_static_shape.ndims != logits.shape.ndims - 1
      raise TensorStream::ValueError, "Rank mismatch: Rank of labels (received #{labels_static_shape.ndims}) " \
                                      "should equal rank of logits minus 1 (received #{logits.shape.ndims})."
    end

    if logits.shape.ndims == 2
      cost = _op(:sparse_softmax_cross_entropy_with_logits, precise_logits, labels, name: name)
      return logits.data_type == :float16 ? tf.cast(cost[0], :float16) : cost[0]
    end

    # Only add a runtime rank assertion when the static shapes cannot prove it.
    # NOTE(review): relies on TensorStream.assert_equal being defined - confirm.
    shape_checks = []
    shape_checks << tf.assert_equal(tf.rank(labels), tf.rank(logits) - 1) unless static_shapes_fully_defined

    tf.control_dependencies(shape_checks) do
      num_classes = tf.shape(logits)[tf.rank(logits) - 1]
      precise_logits = tf.reshape(precise_logits, [-1, num_classes])
      labels = tf.reshape(labels, [-1])
      cost = _op(:sparse_softmax_cross_entropy_with_logits, precise_logits, labels, name: name)
      cost = tf.reshape(cost[0], labels_shape)
      logits.data_type == :float16 ? tf.cast(cost, :float16) : cost
    end
  end
end
|
79
|
+
|
33
80
|
# Computes log softmax activations.
|
34
81
|
def self.log_softmax(logits, axis: -1, name: nil)
|
35
82
|
_op(:log_softmax, logits, axis: axis, name: name)
|
@@ -45,7 +45,7 @@ module TensorStream
|
|
45
45
|
def infer_const
|
46
46
|
return false if breakpoint
|
47
47
|
case operation
|
48
|
-
when :random_standard_normal, :random_uniform, :glorot_uniform, :print
|
48
|
+
when :random_standard_normal, :random_uniform, :glorot_uniform, :print, :check_numerics
|
49
49
|
false
|
50
50
|
else
|
51
51
|
non_const = @inputs.compact.find { |input| !input.is_const }
|
@@ -59,10 +59,12 @@ module TensorStream
|
|
59
59
|
@inputs[1].data_type
|
60
60
|
when :greater, :less, :equal, :not_equal, :greater_equal, :less_equal, :logical_and
|
61
61
|
:boolean
|
62
|
-
when :shape, :rank
|
62
|
+
when :shape, :rank, :shape_n
|
63
63
|
options[:out_type] || :int32
|
64
64
|
when :random_standard_normal, :random_uniform, :glorot_uniform
|
65
65
|
passed_data_type || :float32
|
66
|
+
when :concat
|
67
|
+
@inputs[1].data_type
|
66
68
|
when :index
|
67
69
|
if @inputs[0].is_a?(ControlFlow)
|
68
70
|
|
@@ -274,9 +276,28 @@ module TensorStream
|
|
274
276
|
rank = inputs[0].shape.shape.size + 1
|
275
277
|
axis = rank + axis if axis < 0
|
276
278
|
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
277
|
-
rotated_shape.rotate! + new_shape
|
279
|
+
return rotated_shape.rotate! + new_shape
|
280
|
+
when :concat
|
281
|
+
return nil if inputs[0].value.nil?
|
282
|
+
|
283
|
+
axis = inputs[0].value # get axis
|
284
|
+
|
285
|
+
axis_size = 0
|
286
|
+
|
287
|
+
inputs[1..inputs.size].each do |input_item|
|
288
|
+
return nil if input_item.shape.shape.nil?
|
289
|
+
return nil if input_item.shape.shape[axis].nil?
|
290
|
+
|
291
|
+
axis_size += input_item.shape.shape[axis]
|
292
|
+
end
|
293
|
+
|
294
|
+
new_shape = inputs[1].shape.shape.dup
|
295
|
+
new_shape[axis] = axis_size
|
296
|
+
return new_shape
|
297
|
+
when :slice, :squeeze
|
298
|
+
return nil
|
278
299
|
when :tile
|
279
|
-
nil
|
300
|
+
return nil
|
280
301
|
else
|
281
302
|
return nil if inputs[0].nil?
|
282
303
|
return inputs[0].shape.shape if inputs.size == 1
|
data/lib/tensor_stream/ops.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# Class that defines all available ops supported by TensorStream
|
3
3
|
module Ops
|
4
|
+
class OutputHolder
  # Keeps a reference to the given operation in @op; no other behavior is defined here.
  def initialize(operation)
    @op = operation
  end
end
|
4
9
|
FLOATING_POINT_TYPES = %i[float32 float64 float].freeze
|
5
10
|
INTEGER_TYPES = %i[uint8 int32 int int64].freeze
|
6
11
|
NUMERIC_TYPES = FLOATING_POINT_TYPES + INTEGER_TYPES
|
@@ -94,10 +99,29 @@ module TensorStream
|
|
94
99
|
##
# Returns a 1-D integer tensor (of +out_type+) describing the shape of +input+.
#
# Two kinds of input are answered with a constant: plain Ruby arrays (not arrays
# of tensors), and tensors whose shape is fully specified. Anything else adds a
# :shape op.
def shape(input, name: nil, out_type: :int32)
  plain_array = input.is_a?(Array) && !input[0].is_a?(Tensor)
  return constant(shape_eval(input, out_type), dtype: out_type, name: name) if plain_array

  if shape_full_specified(input)
    constant(input.shape.shape, dtype: out_type, name: "Shape/#{input.name}")
  else
    _op(:shape, input, name: name, out_type: out_type)
  end
end
|
107
|
+
|
108
|
+
##
# Returns the shapes of several tensors at once, one 1-D shape tensor per input.
#
# If every input has a known static shape, each shape is emitted as a constant
# of +out_type+. Otherwise a single :shape_n op is created and one indexed
# output per input is returned.
def shape_n(inputs, name: nil, out_type: :int32)
  if inputs.all? { |input| input.shape.known? }
    # Fixed: the shape and dtype were previously passed without a separating comma.
    inputs.map { |input| cons(input.shape.shape, dtype: out_type) }
  else
    res = _op(:shape_n, *inputs, out_type: out_type, name: name)
    Array.new(inputs.size) { |index| res[index] }
  end
end
|
102
126
|
|
103
127
|
##
|
@@ -113,15 +137,30 @@ module TensorStream
|
|
113
137
|
##
# Returns the rank of a tensor.
#
# A statically known rank is folded into a constant; otherwise a :rank op is added.
def rank(input, name: nil)
  tensor = convert_to_tensor(input)
  if tensor.shape.known?
    cons(tensor.shape.ndims)
  else
    _op(:rank, tensor, name: name)
  end
end
|
118
145
|
|
146
|
+
##
# Initializer wrapping a lambda that converts +value+ to a tensor (with +dtype+ when given).
# +verify_shape+ is accepted for API compatibility but is not used.
def constant_initializer(value, dtype: nil, verify_shape: false)
  generator = -> { convert_to_tensor(value, dtype: dtype) }
  TensorStream::Initializer.new(generator)
end
|
149
|
+
|
119
150
|
##
# Initializer whose lambda builds a :zeros op of the given +dtype+ (default :float32).
#
def zeros_initializer(dtype: :float32)
  TensorStream::Initializer.new(lambda { _op(:zeros, nil, nil, data_type: dtype) })
end
|
124
156
|
|
157
|
+
##
# Initializer whose lambda builds a :ones op of the given +dtype+ (default :float32).
#
def ones_initializer(dtype: :float32)
  ones_factory = -> { _op(:ones, nil, nil, data_type: dtype) }
  TensorStream::Initializer.new(ones_factory)
end
|
163
|
+
|
125
164
|
##
|
126
165
|
# The Glorot uniform initializer, also called Xavier uniform initializer.
|
127
166
|
#
|
@@ -246,7 +285,62 @@ module TensorStream
|
|
246
285
|
##
# Concatenates tensors along one dimension.
#
# +values+ may be an Array of tensors or a single tensor. The axis is passed as
# the first input of the :concat op, followed by the values.
def concat(values, axis, name: 'concat')
  operands = values.is_a?(Array) ? values : [values]
  _op(:concat, axis, *operands, name: name)
end
|
294
|
+
|
295
|
+
##
# Splits +value+ into sub tensors along +axis+.
#
# +num_or_size_splits+ is either an integer scalar (split into that many equal
# pieces) or a 1-D integer tensor listing the size of each piece along +axis+.
# Static piece shapes are computed and validated only when four things are all
# constant: the shape of +value+, the splits tensor, its value and the axis.
# Otherwise +num+ must say how many pieces to return.
#
# Raises TensorStream::ValueError in these cases:
# - the splits are not of an integer type
# - the splits are uneven, or their sizes do not sum to the dimension
# - the splits tensor has rank > 1
# - +num+ is missing when the count cannot be inferred
def split(value, num_or_size_splits, axis: 0, num: nil, name: 'split')
  value = convert_to_tensor(value)
  num_or_size_splits = convert_to_tensor(num_or_size_splits)
  axis = convert_to_tensor(axis)

  unless INTEGER_TYPES.include?(num_or_size_splits.data_type)
    raise TensorStream::ValueError, "num_or_size_splits must be integer dtype"
  end

  res = _op(:split, value, num_or_size_splits, axis, name: name)

  statically_known = value.shape.known? && num_or_size_splits.is_const && num_or_size_splits.value && axis.is_const
  piece_shapes =
    if !statically_known
      raise TensorStream::ValueError, "Cannot automatically determine num, please specify num: in options" if num.nil?

      Array.new(num)
    elsif num_or_size_splits.shape.scalar?
      unless value.shape.shape[axis.value] % num_or_size_splits.value == 0
        raise TensorStream::ValueError, "num_or_size_splits must divide dimension #{value.shape.shape[axis.value]} evenly"
      end

      count = num_or_size_splits.value
      piece = value.shape.shape[axis.value] / count
      Array.new(count) { value.shape.shape.dup.tap { |dims| dims[axis.value] = piece } }
    elsif num_or_size_splits.shape.ndims == 1
      if value.shape.shape[axis.value] != num_or_size_splits.value.reduce(:+)
        raise TensorStream::ValueError, "Sum of splits do not match total dimen in axis #{value.shape.shape[axis.value]} != #{ num_or_size_splits.value.reduce(:+)}"
      end

      num_or_size_splits.value.map { |size| value.shape.shape.dup.tap { |dims| dims[axis.value] = size } }
    else
      raise TensorStream::ValueError, "Scalar or 1D Tensor expected for num_or_size_splits"
    end

  # One indexed output per piece; attach the static shape when it was computed.
  piece_shapes.each_with_index.map do |piece_shape, i|
    op = index(res, i, name: "split/index:#{i}")
    op.shape = TensorShape.new(piece_shape) if piece_shape
    op
  end
end
|
339
|
+
|
340
|
+
##
# Selects element +selector+ from an array or from a set of tensor outputs
# by adding an :index op.
def index(source, selector, name: nil)
  _op(:index, source, selector, name: name)
end
|
251
345
|
|
252
346
|
##
|
@@ -338,8 +432,11 @@ module TensorStream
|
|
338
432
|
##
# Returns element-wise remainder of division.
#
# Both operands are converted to tensors and reconciled by check_data_types
# before the :mod op is created.
def mod(input_a, input_b, name: nil)
  lhs, rhs = check_data_types(convert_to_tensor(input_a), convert_to_tensor(input_b))
  _op(:mod, lhs, rhs, name: name)
end
|
344
441
|
|
345
442
|
##
|
@@ -393,8 +490,11 @@ module TensorStream
|
|
393
490
|
end
|
394
491
|
|
395
492
|
##
# Casts a tensor to a new type, if needed.
#
# Returns the converted input unchanged when it already has +dtype+;
# otherwise adds a :cast op.
def cast(input, dtype, name: nil)
  tensor = convert_to_tensor(input)
  tensor.data_type == dtype ? tensor : _op(:cast, tensor, nil, data_type: dtype, name: name)
end
|
400
500
|
|
@@ -630,14 +730,68 @@ module TensorStream
|
|
630
730
|
_op(:gather, params, indices, validate_indices: validate_indices, name: name, axis: axis)
|
631
731
|
end
|
632
732
|
|
733
|
+
|
734
|
+
##
# Stacks a list of rank-R tensors into one rank-(R+1) tensor along +axis+.
#
def stack(tensors, axis: 0, name: 'stack')
  _op(:stack, *tensors, axis: axis, name: name)
end
|
636
740
|
|
741
|
+
##
# Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
#
# The number of outputs depends on whether the shape of +value+ is statically known:
# - if it is, the count is the size of +value+ along +axis+;
# - otherwise +num+ must be given.
# +axis+ may be negative (-1 is the last dimension). Returns a single tensor
# when there is exactly one output, and an Array of tensors otherwise.
#
# Raises TensorStream::ValueError when the shape is unknown and +num+ is nil.
def unstack(value, num: nil, axis: 0, name: 'unstack')
  res = _op(:unstack, value, num: num, axis: axis, name: name)

  num_vars = if value.shape.known?
               # Array#[] counts negative indices from the end, so axis -1 selects the
               # last dimension (the previous `rank = size - 1` normalization was off by one).
               value.shape.shape[axis]
             else
               raise TensorStream::ValueError, "num is unspecified and cannot be inferred." if num.nil?

               num
             end

  return res[0] if num_vars == 1

  Array.new(num_vars) do |i|
    index(res, i, name: "unstack/index:#{i}")
  end
end
|
765
|
+
|
766
|
+
##
# Same as stack: builds the identical :stack op, defaulting the op name to 'pack'.
def pack(values, axis: 0, name: 'pack')
  stack(values, axis: axis, name: name)
end
|
771
|
+
|
772
|
+
##
# Same as unstack, defaulting the op name to 'unpack'.
#
def unpack(tensor, num: nil, axis: 0, name: 'unpack')
  unstack(tensor, num: num, axis: axis, name: name)
end
|
778
|
+
|
779
|
+
##
# Removes dimensions of size 1 from the shape of a tensor.
#
# Given a tensor input, this operation returns a tensor of the same type with all dimensions of size 1 removed.
# If you don't want to remove all size 1 dimensions, you can remove specific size 1 dimensions by specifying axis.
# +name+ is forwarded to the created :squeeze op.
def squeeze(value, axis: [], name: nil)
  _op(:squeeze, value, axis: axis, name: name)
end
|
640
787
|
|
788
|
+
##
|
789
|
+
# Computes the difference between two lists of numbers or strings.
|
790
|
+
# Given a list x and a list y, this operation returns a list out that represents all values
|
791
|
+
# that are in x but not in y. The returned list out is sorted in the same order that the numbers appear
|
792
|
+
# in x (duplicates are preserved). This operation also returns a list idx that represents the position of
|
793
|
+
# each out element in x. In other words:
|
794
|
+
#
|
641
795
|
def setdiff1d(x, y, index_dtype: :int32, name: nil)
|
642
796
|
result = _op(:setdiff1d, x, y, index_dtype: index_dtype, name: name)
|
643
797
|
[result[0], result[1]]
|