tensor_stream 0.8.5 → 0.8.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +9 -7
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +17 -2
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +92 -10
- data/lib/tensor_stream/evaluator/ruby/check_ops.rb +9 -0
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +1 -1
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +38 -38
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +87 -12
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +16 -13
- data/lib/tensor_stream/graph.rb +2 -0
- data/lib/tensor_stream/helpers/op_helper.rb +1 -0
- data/lib/tensor_stream/math_gradients.rb +86 -5
- data/lib/tensor_stream/nn/nn_ops.rb +47 -0
- data/lib/tensor_stream/operation.rb +25 -4
- data/lib/tensor_stream/ops.rb +160 -6
- data/lib/tensor_stream/session.rb +1 -0
- data/lib/tensor_stream/tensor.rb +4 -7
- data/lib/tensor_stream/tensor_shape.rb +10 -1
- data/lib/tensor_stream/train/adagrad_optimizer.rb +46 -0
- data/lib/tensor_stream/train/optimizer.rb +12 -0
- data/lib/tensor_stream/train/rmsprop_optimizer.rb +84 -0
- data/lib/tensor_stream/train/slot_creator.rb +14 -9
- data/lib/tensor_stream/trainer.rb +2 -0
- data/lib/tensor_stream/utils.rb +6 -4
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +4 -3
- data/samples/linear_regression.rb +3 -0
- data/samples/rnn.rb +105 -0
- metadata +6 -2
@@ -7,6 +7,7 @@ require 'tensor_stream/evaluator/ruby/nn_ops'
|
|
7
7
|
require 'tensor_stream/evaluator/ruby/array_ops'
|
8
8
|
require 'tensor_stream/evaluator/ruby/random_ops'
|
9
9
|
require 'tensor_stream/evaluator/ruby/images_ops'
|
10
|
+
require 'tensor_stream/evaluator/ruby/check_ops'
|
10
11
|
|
11
12
|
module TensorStream
|
12
13
|
module Evaluator
|
@@ -39,6 +40,7 @@ module TensorStream
|
|
39
40
|
include TensorStream::ArrayOps
|
40
41
|
include TensorStream::RandomOps
|
41
42
|
include TensorStream::ImagesOps
|
43
|
+
include TensorStream::CheckOps
|
42
44
|
|
43
45
|
def run(tensor, execution_context)
|
44
46
|
return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array) && !tensor.empty? && tensor[0].is_a?(Tensor)
|
@@ -116,10 +118,10 @@ module TensorStream
|
|
116
118
|
end
|
117
119
|
|
118
120
|
register_op(:cast) do |context, tensor, inputs|
|
119
|
-
call_op(
|
121
|
+
call_op(tensor, inputs[0], context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
|
120
122
|
end
|
121
123
|
|
122
|
-
register_op(:sign) do |context,
|
124
|
+
register_op(:sign) do |context, tensor, inputs|
|
123
125
|
func = lambda { |x, _b|
|
124
126
|
if x.zero? || (x.is_a?(Float) && x.nan?)
|
125
127
|
0
|
@@ -132,7 +134,7 @@ module TensorStream
|
|
132
134
|
end
|
133
135
|
}
|
134
136
|
|
135
|
-
call_op(
|
137
|
+
call_op(tensor, inputs[0], context, func)
|
136
138
|
end
|
137
139
|
|
138
140
|
register_op(:logical_and) do |context, tensor, inputs|
|
@@ -217,10 +219,6 @@ module TensorStream
|
|
217
219
|
call_vector_op(tensor, :greater_equal, a, b, context, ->(t, u) { t <= u })
|
218
220
|
end
|
219
221
|
|
220
|
-
register_op :shape do |_context, tensor, inputs|
|
221
|
-
shape_eval(inputs[0], tensor.options[:out_type])
|
222
|
-
end
|
223
|
-
|
224
222
|
register_op :broadcast_transform do |_context, _tensor, inputs|
|
225
223
|
broadcast(inputs[0], inputs[1])
|
226
224
|
end
|
@@ -266,14 +264,13 @@ module TensorStream
|
|
266
264
|
raise TensorStream::InvalidArgumentError, "#{message} Invalid argument" if t.nan? || t.infinite?
|
267
265
|
t
|
268
266
|
}
|
269
|
-
call_op(
|
267
|
+
call_op(tensor, inputs[0], context, f)
|
270
268
|
end
|
271
269
|
|
272
270
|
def eval_operation(tensor, child_context)
|
273
271
|
return @context[tensor.name] if @context.key?(tensor.name)
|
274
|
-
# puts "ruby: #{tensor.name}"
|
275
272
|
invoke(tensor, child_context).tap do |result|
|
276
|
-
|
273
|
+
# puts "ruby: #{tensor.name}"
|
277
274
|
if tensor.breakpoint
|
278
275
|
a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
|
279
276
|
b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
|
@@ -412,8 +409,6 @@ module TensorStream
|
|
412
409
|
def call_op(op, a, child_context, func)
|
413
410
|
a = complete_eval(a, child_context)
|
414
411
|
process_function_op(a, func)
|
415
|
-
rescue FullEvalNotPossible
|
416
|
-
TensorStream.send(op.to_sym, a)
|
417
412
|
end
|
418
413
|
|
419
414
|
def call_vector_op(tensor, op, a, b, child_context, func)
|
@@ -484,7 +479,15 @@ module TensorStream
|
|
484
479
|
|
485
480
|
var = if placeholder.is_a?(Placeholder)
|
486
481
|
@context[placeholder.name.to_sym].tap do |c|
|
487
|
-
raise "missing placeholder #{placeholder.name}" if c.nil?
|
482
|
+
raise TensorStream::ValueError, "missing placeholder #{placeholder.name}" if c.nil?
|
483
|
+
if placeholder.shape.shape
|
484
|
+
value_shape = shape_eval(c)
|
485
|
+
placeholder_shape = placeholder.shape.shape
|
486
|
+
placeholder_shape.zip(value_shape).each do |p_shape, v_shape|
|
487
|
+
next if p_shape.nil?
|
488
|
+
raise TensorStream::ValueError, "placeholder expects #{placeholder_shape}, got #{value_shape}" if p_shape != v_shape
|
489
|
+
end
|
490
|
+
end
|
488
491
|
end
|
489
492
|
else
|
490
493
|
placeholder
|
data/lib/tensor_stream/graph.rb
CHANGED
@@ -90,8 +90,10 @@ module TensorStream
|
|
90
90
|
@node_keys << node.name
|
91
91
|
@nodes[node.name] = node
|
92
92
|
@constants[node.name] = node if node.is_const
|
93
|
+
# puts "adding node"
|
93
94
|
node.send(:propagate_outputs)
|
94
95
|
node.send(:propagate_consumer, node)
|
96
|
+
# puts "#{node.name}"
|
95
97
|
node.value = node.eval if @eager_execution
|
96
98
|
end
|
97
99
|
|
@@ -10,7 +10,7 @@ module TensorStream
|
|
10
10
|
|
11
11
|
def self.derivative(tensor, wrt_dx, options = {})
|
12
12
|
return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)
|
13
|
-
return i_op(:zeros_like,
|
13
|
+
return i_op(:zeros_like, wrt_dx) unless wrt_dx.consumers.include?(tensor.name)
|
14
14
|
|
15
15
|
nodes_to_compute = wrt_dx.consumers.select do |t|
|
16
16
|
node = tensor.graph.nodes[t]
|
@@ -30,12 +30,15 @@ module TensorStream
|
|
30
30
|
computed_op = _compute_derivative(tensor, grad)
|
31
31
|
|
32
32
|
if computed_op.is_a?(Array)
|
33
|
-
computed_op.each_with_index.collect do |op_grad, index|
|
33
|
+
grads = computed_op.each_with_index.collect do |op_grad, index|
|
34
34
|
next if op_grad.nil?
|
35
35
|
next unless nodes_to_compute.include?(tensor.inputs[index].name)
|
36
36
|
|
37
37
|
_propagate(op_grad, tensor.inputs[index], stop_tensor, nodes_to_compute, stop_gradients)
|
38
|
-
end.compact
|
38
|
+
end.compact
|
39
|
+
|
40
|
+
return nil if grads.empty?
|
41
|
+
grads.size > 1 ? ts.add_n(grads) : grads[0]
|
39
42
|
else
|
40
43
|
return nil if computed_op.nil?
|
41
44
|
_propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
|
@@ -260,7 +263,11 @@ module TensorStream
|
|
260
263
|
i_op(:softmax_grad, x, grad)
|
261
264
|
when :softmax_cross_entropy_with_logits_v2
|
262
265
|
output = node
|
263
|
-
|
266
|
+
logits = node.inputs[0]
|
267
|
+
[_broadcast_mul(grad, output[1]), -ts.nn.log_softmax(logits)]
|
268
|
+
when :sparse_softmax_cross_entropy_with_logits
|
269
|
+
output = node
|
270
|
+
[_broadcast_mul(grad, output[1]), nil]
|
264
271
|
when :floor, :ceil
|
265
272
|
# non differentiable
|
266
273
|
nil
|
@@ -273,13 +280,51 @@ module TensorStream
|
|
273
280
|
when :transpose
|
274
281
|
return [ts.transpose(grad, ts.invert_permutation(y)), nil]
|
275
282
|
when :index
|
276
|
-
|
283
|
+
#hack!! not sure how to fix this yet
|
284
|
+
return grad if %i[softmax_cross_entropy_with_logits_v2 sparse_softmax_cross_entropy_with_logits].include?(node.inputs[0].operation)
|
285
|
+
|
286
|
+
if node.inputs[0].shape.known? && node.inputs[1].value
|
287
|
+
multiplier = node.inputs[0].shape.shape[0]
|
288
|
+
filler = ts.zeros_like(grad)
|
289
|
+
|
290
|
+
res = Array.new(multiplier) { |index|
|
291
|
+
index == node.inputs[1].value ? grad : filler
|
292
|
+
}
|
293
|
+
[res]
|
294
|
+
end
|
295
|
+
when :squeeze
|
296
|
+
_reshape_to_input(node, grad)
|
297
|
+
when :expand_dims
|
298
|
+
[_reshape_to_input(node, grad), nil]
|
299
|
+
when :concat
|
300
|
+
_concat_grad_helper(node, grad, 1, node.inputs.size, 0)
|
301
|
+
when :reshape
|
302
|
+
[ts.reshape(grad, ts.shape(node.inputs[0])), nil]
|
303
|
+
when :stack
|
304
|
+
res = ts.unstack(grad, num: node.inputs.size, axis: node.options[:axis])
|
305
|
+
Array.new(node.inputs.size) { |i| res[i] }
|
306
|
+
when :unstack
|
307
|
+
ts.stack(grad, axis: node.options[:axis])
|
308
|
+
when :cast
|
309
|
+
t = %i[float16 float32 float64]
|
310
|
+
src_type = node.inputs[0].data_type
|
311
|
+
dst_type = grad.data_type
|
312
|
+
|
313
|
+
if t.key?(src_type) && t.key?(dst_type)
|
314
|
+
ts.cast(grad, src_type)
|
315
|
+
else
|
316
|
+
nil
|
317
|
+
end
|
277
318
|
else
|
278
319
|
raise "no derivative op for #{node.operation}"
|
279
320
|
end
|
280
321
|
end
|
281
322
|
end
|
282
323
|
|
324
|
+
def self._reshape_to_input(node, grad)
|
325
|
+
ts.reshape(grad, ts.shape(node.inputs[0]))
|
326
|
+
end
|
327
|
+
|
283
328
|
def self._broadcast_gradient_args(input_a, input_b)
|
284
329
|
res = _op(:broadcast_gradient_args, input_a, input_b)
|
285
330
|
[res[0], res[1]]
|
@@ -335,5 +380,41 @@ module TensorStream
|
|
335
380
|
arr.each { |a| return true if a.equal?(obj) }
|
336
381
|
false
|
337
382
|
end
|
383
|
+
|
384
|
+
def self._extract_input_shapes(inputs)
|
385
|
+
sizes = []
|
386
|
+
fully_known = true
|
387
|
+
inputs.each do |x|
|
388
|
+
input_shape = ts.shape(x)
|
389
|
+
unless input_shape.is_const
|
390
|
+
fully_known = false
|
391
|
+
break
|
392
|
+
end
|
393
|
+
sizes << input_shape.value
|
394
|
+
end
|
395
|
+
|
396
|
+
if fully_known
|
397
|
+
sizes
|
398
|
+
else
|
399
|
+
ts.shape_n(inputs)
|
400
|
+
end
|
401
|
+
end
|
402
|
+
|
403
|
+
def self._concat_grad_helper(op, grad, start_value_index, end_value_index, dim_index)
|
404
|
+
# Degenerate concatenation, just return grad.
|
405
|
+
if op.inputs.size == 2
|
406
|
+
return end_value_index <= dim_index ? [grad] + [nil] : [nil] + [grad]
|
407
|
+
end
|
408
|
+
concat_dim = op.inputs[dim_index]
|
409
|
+
input_values = op.inputs[start_value_index..end_value_index]
|
410
|
+
non_neg_concat_dim = concat_dim % ts.rank(input_values[0])
|
411
|
+
sizes = _extract_input_shapes(input_values)
|
412
|
+
|
413
|
+
slicer = ts.slice(ts.stack(sizes, axis: 1), [non_neg_concat_dim, 0], [1, -1])
|
414
|
+
sizes = ts.squeeze(slicer)
|
415
|
+
|
416
|
+
out_grads = ts.split(grad, sizes, axis: non_neg_concat_dim, num: op.inputs.size - 1)
|
417
|
+
end_value_index <= dim_index ? out_grads + [nil] : [nil] + out_grads
|
418
|
+
end
|
338
419
|
end
|
339
420
|
end
|
@@ -25,11 +25,58 @@ module TensorStream
|
|
25
25
|
logits = tf.convert_to_tensor(logits, name: 'logits')
|
26
26
|
labels = tf.convert_to_tensor(labels, name: 'labels')
|
27
27
|
labels = tf.cast(labels, logits.dtype)
|
28
|
+
|
28
29
|
output = _op(:softmax_cross_entropy_with_logits_v2, logits, labels)
|
29
30
|
output[0]
|
30
31
|
end
|
31
32
|
end
|
32
33
|
|
34
|
+
def self.sparse_softmax_cross_entropy_with_logits(labels: nil, logits: nil, name: nil)
|
35
|
+
TensorStream.name_scope(name, default: "SparseSoftmaxCrossEntropyWithLogits", values: [logits, labels]) do
|
36
|
+
tf = TensorStream
|
37
|
+
labels = tf.convert_to_tensor(labels)
|
38
|
+
logits = tf.convert_to_tensor(logits)
|
39
|
+
precise_logits = logits.data_type == :float16 ? tf.cast(logits, :float32) : logits
|
40
|
+
|
41
|
+
labels_static_shape = labels.shape
|
42
|
+
labels_shape = tf.shape(labels)
|
43
|
+
static_shapes_fully_defined = labels_static_shape.known? && logits.shape.known?
|
44
|
+
|
45
|
+
raise TensorStream::ValueError, "Logits cannot be scalars - received shape #{logits.shape.shape}." if logits.shape.known? && logits.shape.scalar?
|
46
|
+
raise TensorStream::ValueError, "Rank mismatch: Rank of labels (received #{labels_static_shape.ndims}) " +
|
47
|
+
"should equal rank of logits minus 1 (received #{logits.shape.ndims})." if logits.shape.known? && (labels_static_shape.known? && labels_static_shape.ndims != logits.shape.ndims - 1)
|
48
|
+
|
49
|
+
if logits.shape.ndims == 2
|
50
|
+
cost = _op(:sparse_softmax_cross_entropy_with_logits,
|
51
|
+
precise_logits, labels, name: name)
|
52
|
+
if logits.data_type == :float16
|
53
|
+
return tf.cast(cost[0], :float16)
|
54
|
+
else
|
55
|
+
return cost[0]
|
56
|
+
end
|
57
|
+
end
|
58
|
+
|
59
|
+
shape_checks = []
|
60
|
+
if !static_shapes_fully_defined
|
61
|
+
shape_checks << tf.append(tf.assert_equal(tf.rank(labels), tf.rank(logits) - 1))
|
62
|
+
end
|
63
|
+
|
64
|
+
tf.control_dependencies(shape_checks) do
|
65
|
+
num_classes = tf.shape(logits)[tf.rank(logits) - 1]
|
66
|
+
precise_logits = tf.reshape(precise_logits, [-1, num_classes])
|
67
|
+
labels = array_ops.reshape(labels, [-1])
|
68
|
+
cost = _op(:sparse_softmax_cross_entropy_with_logits, precise_logits, labels, name: name)
|
69
|
+
cost = tf.reshape(cost[0], labels_shape)
|
70
|
+
|
71
|
+
if logits.data_type == :float16
|
72
|
+
tf.cast(cost, :float16)
|
73
|
+
else
|
74
|
+
cost
|
75
|
+
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
79
|
+
|
33
80
|
# Computes log softmax activations.
|
34
81
|
def self.log_softmax(logits, axis: -1, name: nil)
|
35
82
|
_op(:log_softmax, logits, axis: axis, name: name)
|
@@ -45,7 +45,7 @@ module TensorStream
|
|
45
45
|
def infer_const
|
46
46
|
return false if breakpoint
|
47
47
|
case operation
|
48
|
-
when :random_standard_normal, :random_uniform, :glorot_uniform, :print
|
48
|
+
when :random_standard_normal, :random_uniform, :glorot_uniform, :print, :check_numerics
|
49
49
|
false
|
50
50
|
else
|
51
51
|
non_const = @inputs.compact.find { |input| !input.is_const }
|
@@ -59,10 +59,12 @@ module TensorStream
|
|
59
59
|
@inputs[1].data_type
|
60
60
|
when :greater, :less, :equal, :not_equal, :greater_equal, :less_equal, :logical_and
|
61
61
|
:boolean
|
62
|
-
when :shape, :rank
|
62
|
+
when :shape, :rank, :shape_n
|
63
63
|
options[:out_type] || :int32
|
64
64
|
when :random_standard_normal, :random_uniform, :glorot_uniform
|
65
65
|
passed_data_type || :float32
|
66
|
+
when :concat
|
67
|
+
@inputs[1].data_type
|
66
68
|
when :index
|
67
69
|
if @inputs[0].is_a?(ControlFlow)
|
68
70
|
|
@@ -274,9 +276,28 @@ module TensorStream
|
|
274
276
|
rank = inputs[0].shape.shape.size + 1
|
275
277
|
axis = rank + axis if axis < 0
|
276
278
|
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
277
|
-
rotated_shape.rotate! + new_shape
|
279
|
+
return rotated_shape.rotate! + new_shape
|
280
|
+
when :concat
|
281
|
+
return nil if inputs[0].value.nil?
|
282
|
+
|
283
|
+
axis = inputs[0].value # get axis
|
284
|
+
|
285
|
+
axis_size = 0
|
286
|
+
|
287
|
+
inputs[1..inputs.size].each do |input_item|
|
288
|
+
return nil if input_item.shape.shape.nil?
|
289
|
+
return nil if input_item.shape.shape[axis].nil?
|
290
|
+
|
291
|
+
axis_size += input_item.shape.shape[axis]
|
292
|
+
end
|
293
|
+
|
294
|
+
new_shape = inputs[1].shape.shape.dup
|
295
|
+
new_shape[axis] = axis_size
|
296
|
+
return new_shape
|
297
|
+
when :slice, :squeeze
|
298
|
+
return nil
|
278
299
|
when :tile
|
279
|
-
nil
|
300
|
+
return nil
|
280
301
|
else
|
281
302
|
return nil if inputs[0].nil?
|
282
303
|
return inputs[0].shape.shape if inputs.size == 1
|
data/lib/tensor_stream/ops.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# Class that defines all available ops supported by TensorStream
|
3
3
|
module Ops
|
4
|
+
class OutputHolder
|
5
|
+
def initialize(op)
|
6
|
+
@op = op
|
7
|
+
end
|
8
|
+
end
|
4
9
|
FLOATING_POINT_TYPES = %i[float32 float64 float].freeze
|
5
10
|
INTEGER_TYPES = %i[uint8 int32 int int64].freeze
|
6
11
|
NUMERIC_TYPES = FLOATING_POINT_TYPES + INTEGER_TYPES
|
@@ -94,10 +99,29 @@ module TensorStream
|
|
94
99
|
##
|
95
100
|
# This operation returns a 1-D integer tensor representing the shape of input
|
96
101
|
def shape(input, name: nil, out_type: :int32)
|
97
|
-
return constant(shape_eval(input, out_type), dtype: out_type, name: name) if input.is_a?(Array)
|
102
|
+
return constant(shape_eval(input, out_type), dtype: out_type, name: name) if input.is_a?(Array) && !input[0].is_a?(Tensor)
|
98
103
|
return constant(input.shape.shape, dtype: out_type, name: "Shape/#{input.name}") if shape_full_specified(input)
|
99
104
|
|
100
|
-
_op(:shape, input,
|
105
|
+
_op(:shape, input, name: name, out_type: out_type)
|
106
|
+
end
|
107
|
+
|
108
|
+
def shape_n(inputs, name: nil, out_type: :int32)
|
109
|
+
shapes_known = true
|
110
|
+
inputs.each do |input|
|
111
|
+
unless input.shape.known?
|
112
|
+
shapes_known = false
|
113
|
+
break
|
114
|
+
end
|
115
|
+
end
|
116
|
+
|
117
|
+
if shapes_known
|
118
|
+
inputs.collect { |input| cons(input.shape.shape dtype: out_type) }
|
119
|
+
else
|
120
|
+
res = _op(:shape_n, *inputs, out_type: out_type, name: name)
|
121
|
+
Array.new(inputs.size) do |index|
|
122
|
+
res[index]
|
123
|
+
end
|
124
|
+
end
|
101
125
|
end
|
102
126
|
|
103
127
|
##
|
@@ -113,15 +137,30 @@ module TensorStream
|
|
113
137
|
##
|
114
138
|
# Returns the rank of a tensor.
|
115
139
|
def rank(input, name: nil)
|
140
|
+
input = convert_to_tensor(input)
|
141
|
+
return cons(input.shape.ndims) if input.shape.known?
|
142
|
+
|
116
143
|
_op(:rank, input, name: name)
|
117
144
|
end
|
118
145
|
|
146
|
+
def constant_initializer(value, dtype: nil, verify_shape: false)
|
147
|
+
TensorStream::Initializer.new(-> { convert_to_tensor(value, dtype: dtype) })
|
148
|
+
end
|
149
|
+
|
119
150
|
##
|
120
151
|
# initializer that generates tensors initialized to 0.
|
121
|
-
|
152
|
+
#
|
153
|
+
def zeros_initializer(dtype: :float32)
|
122
154
|
TensorStream::Initializer.new(-> { _op(:zeros, nil, nil, data_type: dtype) })
|
123
155
|
end
|
124
156
|
|
157
|
+
##
|
158
|
+
# initializer that generates tensors initialized to 1.
|
159
|
+
#
|
160
|
+
def ones_initializer(dtype: :float32)
|
161
|
+
TensorStream::Initializer.new(-> { _op(:ones, nil, nil, data_type: dtype) })
|
162
|
+
end
|
163
|
+
|
125
164
|
##
|
126
165
|
# The Glorot uniform initializer, also called Xavier uniform initializer.
|
127
166
|
#
|
@@ -246,7 +285,62 @@ module TensorStream
|
|
246
285
|
##
|
247
286
|
# Concatenates tensors along one dimension.
|
248
287
|
def concat(values, axis, name: 'concat')
|
249
|
-
|
288
|
+
if values.is_a?(Array)
|
289
|
+
_op(:concat, axis, *values, name: name)
|
290
|
+
else
|
291
|
+
_op(:concat, axis, values, name: name)
|
292
|
+
end
|
293
|
+
end
|
294
|
+
|
295
|
+
def split(value, num_or_size_splits, axis: 0, num: nil, name: 'split')
|
296
|
+
value = convert_to_tensor(value)
|
297
|
+
num_or_size_splits = convert_to_tensor(num_or_size_splits)
|
298
|
+
axis = convert_to_tensor(axis)
|
299
|
+
|
300
|
+
raise TensorStream::ValueError, "num_or_size_splits must be integer dtype" unless INTEGER_TYPES.include?(num_or_size_splits.data_type)
|
301
|
+
|
302
|
+
res = _op(:split, value, num_or_size_splits, axis, name: name)
|
303
|
+
|
304
|
+
pieces = if value.shape.known? && num_or_size_splits.is_const && num_or_size_splits.value && axis.is_const
|
305
|
+
if num_or_size_splits.shape.scalar?
|
306
|
+
raise TensorStream::ValueError, "num_or_size_splits must divide dimension #{value.shape.shape[axis.value]} evenly" unless value.shape.shape[axis.value] % num_or_size_splits.value == 0
|
307
|
+
div = num_or_size_splits.value
|
308
|
+
n = value.shape.shape[axis.value] / div
|
309
|
+
|
310
|
+
Array.new(div) { |i|
|
311
|
+
new_shape = value.shape.shape.dup
|
312
|
+
new_shape[axis.value] = n
|
313
|
+
new_shape
|
314
|
+
}
|
315
|
+
elsif num_or_size_splits.shape.ndims == 1
|
316
|
+
raise TensorStream::ValueError, "Sum of splits do not match total dimen in axis #{value.shape.shape[axis.value]} != #{ num_or_size_splits.value.reduce(:+)}" if value.shape.shape[axis.value] != num_or_size_splits.value.reduce(:+)
|
317
|
+
num_or_size_splits.value.collect do |v|
|
318
|
+
new_shape = value.shape.shape.dup
|
319
|
+
new_shape[axis.value] = v
|
320
|
+
new_shape
|
321
|
+
end
|
322
|
+
else
|
323
|
+
raise TensorStream::ValueError, "Scalar or 1D Tensor expected for num_or_size_splits"
|
324
|
+
end
|
325
|
+
else
|
326
|
+
raise TensorStream::ValueError, "Cannot automatically determine num, please specify num: in options" if num.nil?
|
327
|
+
|
328
|
+
Array.new(num) { nil }
|
329
|
+
end
|
330
|
+
|
331
|
+
pieces.collect.with_index do |shape, i|
|
332
|
+
op = index(res, i, name: "split/index:#{i}")
|
333
|
+
if shape
|
334
|
+
op.shape = TensorShape.new(shape)
|
335
|
+
end
|
336
|
+
op
|
337
|
+
end
|
338
|
+
end
|
339
|
+
|
340
|
+
##
|
341
|
+
# select an index in an array or a set of tensor outputs
|
342
|
+
def index(tensor, sel, name: nil)
|
343
|
+
_op(:index, tensor, sel, name: name)
|
250
344
|
end
|
251
345
|
|
252
346
|
##
|
@@ -338,8 +432,11 @@ module TensorStream
|
|
338
432
|
##
|
339
433
|
# Returns element-wise remainder of division.
|
340
434
|
def mod(input_a, input_b, name: nil)
|
435
|
+
input_a = convert_to_tensor(input_a)
|
436
|
+
input_b = convert_to_tensor(input_b)
|
437
|
+
|
341
438
|
input_a, input_b = check_data_types(input_a, input_b)
|
342
|
-
|
439
|
+
_op(:mod, input_a, input_b, name: name)
|
343
440
|
end
|
344
441
|
|
345
442
|
##
|
@@ -393,8 +490,11 @@ module TensorStream
|
|
393
490
|
end
|
394
491
|
|
395
492
|
##
|
396
|
-
# Casts a tensor to a new type
|
493
|
+
# Casts a tensor to a new type, if needed
|
397
494
|
def cast(input, dtype, name: nil)
|
495
|
+
input = convert_to_tensor(input)
|
496
|
+
return input if input.data_type == dtype
|
497
|
+
|
398
498
|
_op(:cast, input, nil, data_type: dtype, name: name)
|
399
499
|
end
|
400
500
|
|
@@ -630,14 +730,68 @@ module TensorStream
|
|
630
730
|
_op(:gather, params, indices, validate_indices: validate_indices, name: name, axis: axis)
|
631
731
|
end
|
632
732
|
|
733
|
+
|
734
|
+
##
|
735
|
+
# Stacks a list of rank-R tensors into one rank-(R+1) tensor.
|
736
|
+
#
|
633
737
|
def stack(values, axis: 0, name: 'stack')
|
634
738
|
_op(:stack, *values, axis: axis, name: name)
|
635
739
|
end
|
636
740
|
|
741
|
+
##
|
742
|
+
# Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
|
743
|
+
#
|
744
|
+
def unstack(value, num: nil, axis: 0, name: 'unstack')
|
745
|
+
res = _op(:unstack, value, num: num, axis: axis, name: name)
|
746
|
+
|
747
|
+
num_vars = if value.shape.known?
|
748
|
+
new_shape = value.shape.shape.dup
|
749
|
+
rank = new_shape.size - 1
|
750
|
+
axis = rank + axis if axis < 0
|
751
|
+
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
752
|
+
new_shape = rotated_shape.rotate!(-1) + new_shape
|
753
|
+
new_shape[0]
|
754
|
+
else
|
755
|
+
raise TensorStream::ValueError, "num is unspecified and cannot be inferred." if num.nil?
|
756
|
+
num
|
757
|
+
end
|
758
|
+
|
759
|
+
return res[0] if num_vars == 1
|
760
|
+
|
761
|
+
Array.new(num_vars) do |i|
|
762
|
+
index(res, i, name: "unstack/index:#{i}")
|
763
|
+
end
|
764
|
+
end
|
765
|
+
|
766
|
+
##
|
767
|
+
# Same as stack
|
768
|
+
def pack(values, axis: 0, name: 'pack')
|
769
|
+
_op(:stack, *values, axis: axis, name: name)
|
770
|
+
end
|
771
|
+
|
772
|
+
##
|
773
|
+
# Same as unstack
|
774
|
+
#
|
775
|
+
def unpack(value, num: nil, axis: 0, name: 'unpack')
|
776
|
+
unstack(value, num: num, axis: axis, name: name)
|
777
|
+
end
|
778
|
+
|
779
|
+
##
|
780
|
+
# Removes dimensions of size 1 from the shape of a tensor.
|
781
|
+
#
|
782
|
+
# Given a tensor input, this operation returns a tensor of the same type with all dimensions of size 1 removed.
|
783
|
+
# If you don't want to remove all size 1 dimensions, you can remove specific size 1 dimensions by specifying axis.
|
637
784
|
def squeeze(value, axis: [], name: nil)
|
638
785
|
_op(:squeeze, value, axis: axis, name: nil)
|
639
786
|
end
|
640
787
|
|
788
|
+
##
|
789
|
+
# Computes the difference between two lists of numbers or strings.
|
790
|
+
# Given a list x and a list y, this operation returns a list out that represents all values
|
791
|
+
# that are in x but not in y. The returned list out is sorted in the same order that the numbers appear
|
792
|
+
# in x (duplicates are preserved). This operation also returns a list idx that represents the position of
|
793
|
+
# each out element in x. In other words:
|
794
|
+
#
|
641
795
|
def setdiff1d(x, y, index_dtype: :int32, name: nil)
|
642
796
|
result = _op(:setdiff1d, x, y, index_dtype: index_dtype, name: name)
|
643
797
|
[result[0], result[1]]
|