tensor_stream 0.9.8 → 0.9.9
- checksums.yaml +4 -4
- data/README.md +31 -14
- data/lib/tensor_stream.rb +4 -0
- data/lib/tensor_stream/constant.rb +41 -0
- data/lib/tensor_stream/control_flow.rb +2 -1
- data/lib/tensor_stream/dynamic_stitch.rb +3 -1
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
- data/lib/tensor_stream/graph.rb +61 -12
- data/lib/tensor_stream/graph_builder.rb +3 -3
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
- data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
- data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
- data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
- data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
- data/lib/tensor_stream/helpers/op_helper.rb +17 -6
- data/lib/tensor_stream/helpers/string_helper.rb +32 -1
- data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
- data/lib/tensor_stream/math_gradients.rb +19 -12
- data/lib/tensor_stream/monkey_patches/float.rb +7 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
- data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
- data/lib/tensor_stream/nn/nn_ops.rb +1 -1
- data/lib/tensor_stream/operation.rb +98 -36
- data/lib/tensor_stream/ops.rb +65 -13
- data/lib/tensor_stream/placeholder.rb +2 -2
- data/lib/tensor_stream/session.rb +15 -3
- data/lib/tensor_stream/tensor.rb +15 -172
- data/lib/tensor_stream/tensor_shape.rb +3 -1
- data/lib/tensor_stream/train/saver.rb +12 -10
- data/lib/tensor_stream/trainer.rb +7 -2
- data/lib/tensor_stream/utils.rb +13 -11
- data/lib/tensor_stream/utils/freezer.rb +37 -0
- data/lib/tensor_stream/variable.rb +17 -11
- data/lib/tensor_stream/variable_scope.rb +3 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +3 -4
- data/samples/linear_regression.rb +9 -5
- data/samples/logistic_regression.rb +11 -9
- data/samples/mnist_data.rb +8 -10
- metadata +8 -4
data/lib/tensor_stream/math_gradients.rb
CHANGED

```diff
@@ -40,7 +40,10 @@ module TensorStream
         return nil if grads.empty?
         grads.size > 1 ? ts.add_n(grads) : grads[0]
       else
-
+
+        if computed_op.nil?
+          return nil
+        end
         _propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
       end
     end
```
```diff
@@ -50,6 +53,7 @@ module TensorStream
       node.graph.name_scope("#{node.name}_grad") do
         x = node.inputs[0] if node.inputs[0]
         y = node.inputs[1] if node.inputs[1]
+        z = node.inputs[2] if node.inputs[2]
 
         case node.operation
         when :add_n
```
```diff
@@ -135,7 +139,6 @@ module TensorStream
       reduced = ts.cast(reduction_indices, :int32)
       idx = ts.range(0, rank)
       other, = ts.setdiff1d(idx, reduced)
-
       [ts.concat([reduced, other], 0),
        ts.reduce_prod(ts.gather(input_shape, reduced)),
        ts.reduce_prod(ts.gather(input_shape, other))]
```
```diff
@@ -239,13 +242,17 @@ module TensorStream
           y = ts.constant(2.0, dtype: x.dtype)
           ts.multiply(grad, ts.multiply(x, y))
         when :where
-          x_mask = i_op(:where, i_op(:ones_like,
-          y_mask = i_op(:where, i_op(:zeros_like,
-          [x_mask * grad, y_mask * grad]
-        when :
-
-
-
+          x_mask = i_op(:where, x, i_op(:ones_like, y), i_op(:zeros_like, z))
+          y_mask = i_op(:where, x, i_op(:zeros_like, y), i_op(:ones_like, z))
+          [nil, x_mask * grad, y_mask * grad]
+        when :case
+          n_preds = node.inputs.size - 2
+
+          case_grads = Array.new(n_preds) do |index|
+            i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
+          end
+
+          [nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
         when :mean
           sum_grad = _sum_grad(x, y, grad)[0]
           input_shape = ts.shape(x)
```
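The rewritten `:where` gradient treats the op as three-input (condition, x, y): the condition is not differentiable, so its slot in the returned gradient list is `nil`, and the incoming gradient is split between x and y by complementary masks. A minimal plain-Ruby sketch of the masking idea (illustrative values only, no TensorStream calls):

```ruby
cond = [true, false, true]
grad = [0.1, 0.2, 0.3]

# x receives the gradient where cond is true, y where it is false
x_mask = cond.map { |c| c ? 1.0 : 0.0 }
y_mask = cond.map { |c| c ? 0.0 : 1.0 }

grad_x = x_mask.zip(grad).map { |m, g| m * g } # => [0.1, 0.0, 0.3]
grad_y = y_mask.zip(grad).map { |m, g| m * g } # => [0.0, 0.2, 0.0]
# the condition itself gets no gradient, hence [nil, grad_x, grad_y]
```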
```diff
@@ -283,12 +290,12 @@ module TensorStream
           # hack!! not sure how to fix this yet
           return grad if %i[softmax_cross_entropy_with_logits_v2 sparse_softmax_cross_entropy_with_logits].include?(node.inputs[0].operation)
 
-          if node.inputs[0].shape.known? && node.inputs[1].
+          if node.inputs[0].shape.known? && node.inputs[1].const_value
             multiplier = node.inputs[0].shape.shape[0]
             filler = ts.zeros_like(grad)
 
             res = Array.new(multiplier) do |index|
-              index == node.inputs[1].
+              index == node.inputs[1].const_value ? grad : filler
             end
             [res]
           end
```
```diff
@@ -346,7 +353,7 @@ module TensorStream
       tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
       new_grad = _op(:reshape, grad, output_shape_kept_dims)
 
-      grad = _op(:
+      grad = _op(:case, [_op(:rank, grad).zero?], _op(:tile, new_grad, tile_scaling), _op(:fill, input_shape, grad))
 
       [grad, nil]
     end
```
data/lib/tensor_stream/monkey_patches/float.rb
CHANGED

```diff
@@ -1,3 +1,10 @@
 class Float
   include TensorStream::MonkeyPatch
+
+  def self.placeholder(name: nil, width: 32, shape: nil)
+    raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
+
+    data_type = :"float#{width}"
+    TensorStream.placeholder(data_type, name: name, shape: shape)
+  end
 end
```
data/lib/tensor_stream/monkey_patches/integer.rb
CHANGED

```diff
@@ -1,3 +1,10 @@
 class Integer
   include TensorStream::MonkeyPatch
+
+  def self.placeholder(name: nil, width: 32, shape: nil)
+    raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
+
+    data_type = :"int#{width}"
+    TensorStream.placeholder(data_type, name: name, shape: shape)
+  end
 end
```
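These two monkey patches give `Float` and `Integer` class-level placeholder constructors that pick the dtype from the requested bit width. A usage sketch (hedged: the `feed_dict` session call follows the gem's README-style API, and the variable names are illustrative):

```ruby
require 'tensor_stream'

ts = TensorStream

a = Float.placeholder              # :float32 by default
b = Integer.placeholder(width: 64) # :int64

sess = ts.session
sess.run(a + 1.0, feed_dict: { a => 3.0 }) # => 4.0
```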
data/lib/tensor_stream/monkey_patches/patch.rb
CHANGED

```diff
@@ -19,13 +19,13 @@ module TensorStream
     TensorStream.shape_eval(self)
   end
 
-  def t(name = nil)
-    TensorStream.convert_to_tensor(self, name: name)
+  def t(name = nil, dtype: nil)
+    TensorStream.convert_to_tensor(self, name: name, dtype: dtype)
   end
 
   def +(other)
     if other.is_a?(TensorStream::Tensor)
-      TensorStream.convert_to_tensor(self) + other
+      TensorStream.convert_to_tensor(self, dtype: other.data_type) + other
     else
       _tensor_stream_add_orig(other)
     end
```
```diff
@@ -33,7 +33,7 @@ module TensorStream
 
   def -(other)
     if other.is_a?(TensorStream::Tensor)
-      TensorStream.convert_to_tensor(self) - other
+      TensorStream.convert_to_tensor(self, dtype: other.data_type) - other
     else
       _tensor_stream_sub_orig(other)
     end
```
```diff
@@ -41,7 +41,7 @@ module TensorStream
 
   def *(other)
     if other.is_a?(TensorStream::Tensor)
-      TensorStream.convert_to_tensor(self) * other
+      TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
     else
       _tensor_stream_mul_orig(other)
     end
```
```diff
@@ -49,7 +49,7 @@ module TensorStream
 
   def /(other)
     if other.is_a?(TensorStream::Tensor)
-      TensorStream.convert_to_tensor(self) * other
+      TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
     else
       _tensor_stream_div_orig(other)
     end
```
```diff
@@ -57,7 +57,7 @@ module TensorStream
 
   def %(other)
     if other.is_a?(TensorStream::Tensor)
-      TensorStream.convert_to_tensor(self) % other
+      TensorStream.convert_to_tensor(self, dtype: other.data_type) % other
     else
       _tensor_stream_mod_orig(other)
     end
```
```diff
@@ -65,7 +65,7 @@ module TensorStream
 
   def **(other)
     if other.is_a?(TensorStream::Tensor)
-      TensorStream.convert_to_tensor(self)**other
+      TensorStream.convert_to_tensor(self, dtype: other.data_type)**other
     else
       _tensor_stream_pow_orig(other)
     end
```
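The common thread in these operator patches is that a bare Ruby numeric now adopts the dtype of the tensor it is combined with, instead of the conversion default, which avoids dtype-mismatch errors in mixed expressions. A hedged sketch:

```ruby
ts = TensorStream

x = ts.constant(5, dtype: :int32)

# 2 is converted with dtype: x.data_type (:int32) by the patched Integer#*
y = 2 * x

# the new dtype: option on #t makes the conversion explicit
z = 2.t(dtype: :int32) * x

sess = ts.session
sess.run(y) # => 10
```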
data/lib/tensor_stream/operation.rb
CHANGED

```diff
@@ -2,33 +2,18 @@ require 'tensor_stream/helpers/infer_shape'
 module TensorStream
   # TensorStream class that defines an operation
   class Operation < Tensor
-
-    attr_reader :outputs
+    include OpHelper
 
-
-
-        args.pop
-      else
-        {}
-      end
-
-      inputs = args
-
-      setup_initial_state(options)
-
-      @operation = operation
-      @rank = options[:rank] || 0
-      @name = [@graph.get_name_scope, options[:name] || set_name].compact.reject(&:empty?).join('/')
-      @internal = options[:internal]
-      @given_name = @name
+    attr_accessor :name, :operation, :inputs, :rank, :device, :consumers, :breakpoint
+    attr_reader :outputs, :options, :is_const, :data_type, :shape
 
+    def initialize(graph, inputs:, options:)
+      @consumers = Set.new
+      @outputs = []
+      @op = self
+      @graph = graph
+      @inputs = inputs
       @options = options
-
-      @inputs = inputs.map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
-      @data_type = set_data_type(options[:data_type])
-      @is_const = infer_const
-      @shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
-      @graph.add_node(self)
     end
 
     def to_s
```
```diff
@@ -37,25 +22,74 @@ module TensorStream
 
     def to_h
       {
-        op: operation,
-        name: name,
-
+        op: operation.to_s,
+        name: name.to_s,
+        data_type: @data_type,
+        inputs: @inputs.map(&:name),
+        attrs: serialize_options
       }
     end
 
+    def const_value
+      @options ? @options[:value] : nil
+    end
+
+    def container_buffer
+      @options[:container] ? @options[:container].buffer : nil
+    end
+
+    def container
+      @options[:container].read_value
+    end
+
+    def container=(value)
+      @options[:container].value = value
+    end
+
+    def set_input(index, value)
+      @inputs[index] = value
+      @shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
+      @rank = @shape.rank
+      @is_const = infer_const
+      @data_type = set_data_type(@options[:data_type])
+    end
+
     def infer_const
       return false if breakpoint
+
       case operation
       when :random_standard_normal, :random_uniform, :truncated_normal, :glorot_uniform, :print, :check_numerics
         false
+      when :const
+        true
+      when :placeholder
+        false
+      when :variable_v2
+        false
       else
         non_const = @inputs.compact.find { |input| !input.is_const }
         non_const ? false : true
       end
     end
 
+    def set_name
+      @operation.to_s
+    end
+
     def set_data_type(passed_data_type)
       case operation
+      when :where
+        @inputs[1].data_type
+      when :case
+        if @inputs[2]
+          @inputs[2].data_type
+        else
+          @inputs[1].data_type
+        end
+      when :case_grad
+        @inputs[2].data_type
+      when :placeholder, :variable_v2, :const
+        options[:data_type]
       when :fill
         @inputs[1].data_type
       when :greater, :less, :equal, :not_equal, :greater_equal, :less_equal, :logical_and
```
```diff
@@ -70,9 +104,8 @@ module TensorStream
         @inputs[1].data_type
       when :index
         if @inputs[0].is_a?(ControlFlow)
-
           if @inputs[1].is_const
-            @inputs[0].inputs[@inputs[1].
+            @inputs[0].inputs[@inputs[1].const_value].data_type
           else
             :unknown
           end
```
```diff
@@ -163,8 +196,6 @@ module TensorStream
         "reshape(#{sub_input},#{sub_input2})"
       when :rank
         "#{sub_input}.rank"
-      when :cond
-        "(#{auto_math(options[:pred], name_only, max_depth - 1, cur_depth)} ? #{sub_input} : #{sub_input2})"
       when :less
         "#{sub_input} < #{sub_input2}"
       when :less_equal
```
```diff
@@ -222,12 +253,41 @@ module TensorStream
 
     private
 
+    def serialize_options
+      excludes = %i[internal_name source]
+
+      @options.reject { |k, v| excludes.include?(k) || v.nil? }.map do |k, v|
+        v = case v.class.to_s
+            when 'TensorStream::TensorShape'
+              v.shape
+            when 'Array'
+              v
+            when 'String', 'Integer', 'Float', 'Symbol', 'FalseClass', "TrueClass"
+              v
+            when 'TensorStream::Variable'
+              { name: v.name, options: v.options, shape: v.shape.shape.dup }
+            else
+              raise "unknown type #{v.class}"
+            end
+        [k.to_sym, v]
+      end.to_h
+    end
+
+    def add_consumer(consumer)
+      @consumers << consumer.name if consumer.name != name
+    end
+
+    def setup_output(consumer)
+      @outputs << consumer.name unless @outputs.include?(consumer.name)
+    end
+
     def propagate_consumer(consumer)
-
+      add_consumer(consumer)
       @inputs.compact.each do |input|
         if input.is_a?(Array)
-          input.flatten.compact.select { |t| t.is_a?(Tensor) }.each do |t|
+          input.flatten.compact.map(&:op).select { |t| t.is_a?(Tensor) }.each do |t|
             next if t.consumers.include?(consumer.name)
+
             t.send(:propagate_consumer, consumer)
           end
         elsif input.name != name && !input.consumers.include?(consumer.name)
```
```diff
@@ -239,7 +299,7 @@ module TensorStream
     def propagate_outputs
       @inputs.compact.each do |input|
         if input.is_a?(Array)
-          input.flatten.compact.each do |t|
+          input.flatten.compact.map(&:op).each do |t|
             t.send(:setup_output, self) if t.is_a?(Tensor)
           end
         elsif input.is_a?(Tensor) && (input.name != name)
```
```diff
@@ -248,8 +308,10 @@ module TensorStream
       end
     end
 
-    def
-
+    def setup_initial_state(options)
+      @outputs = []
+      @graph = options[:__graph] || TensorStream.get_default_graph
+      @source = format_source(caller_locations)
     end
   end
 end
```
data/lib/tensor_stream/ops.rb
CHANGED
```diff
@@ -58,7 +58,8 @@ module TensorStream
     # +wrt_xs+ : A Tensor or list of tensors to be used for differentiation.
     # +stop_gradients+ : Optional. A Tensor or list of tensors not to differentiate through
     def gradients(tensor_ys, wrt_xs, name: 'gradients', stop_gradients: nil)
-
+      tensor_ys = tensor_ys.op
+      gs = wrt_xs.map(&:op).collect do |x|
         stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
         gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym
         tensor_graph = tensor_ys.graph
```
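`gradients` now normalizes both `tensor_ys` and each `wrt_xs` entry to its underlying op before building the gradient program. Basic usage is unchanged; a small sketch (assuming the usual variable/session workflow, with illustrative names):

```ruby
ts = TensorStream

x = ts.variable(3.0, name: 'x_grad_example')
y = x * x

grad = ts.gradients(y, [x]).first

sess = ts.session
sess.run(ts.global_variables_initializer)
sess.run(grad) # => 6.0, since d(x^2)/dx = 2x at x = 3
```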
```diff
@@ -82,21 +83,21 @@ module TensorStream
     # Outputs random values from a uniform distribution.
     def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
       options = { dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
-      _op(:random_uniform, shape,
+      _op(:random_uniform, shape, options)
     end
 
     ##
     # Outputs random values from a normal distribution.
     def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
       options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
-      _op(:random_standard_normal, shape,
+      _op(:random_standard_normal, shape, options)
     end
 
     ##
     # Outputs random values from a truncated normal distribution.
     def truncated_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
       options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
-      _op(:truncated_normal, shape,
+      _op(:truncated_normal, shape, options)
     end
 
     ##
```
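With the options hash now passed straight through to `_op`, the distribution parameters and seed behave as documented. A usage sketch:

```ruby
ts = TensorStream

u = ts.random_uniform([2, 2], minval: 0, maxval: 10, seed: 42)
n = ts.random_normal([3], mean: 0.0, stddev: 2.0)

sess = ts.session
sess.run(u) # 2x2 matrix of floats drawn from [0, 10)
sess.run(n) # 3 samples from N(0, 2^2)
```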
```diff
@@ -172,14 +173,14 @@ module TensorStream
     # initializer that generates tensors initialized to 0.
     #
     def zeros_initializer(dtype: :float32)
-      TensorStream::Initializer.new(-> { _op(:zeros,
+      TensorStream::Initializer.new(-> { _op(:zeros, data_type: dtype) })
     end
 
     ##
     # initializer that generates tensors initialized to 1.
     #
     def ones_initializer(dtype: :float32)
-      TensorStream::Initializer.new(-> { _op(:ones,
+      TensorStream::Initializer.new(-> { _op(:ones, data_type: dtype) })
     end
 
     ##
```
```diff
@@ -189,13 +190,13 @@ module TensorStream
     # where limit is sqrt(6 / (fan_in + fan_out)) where fan_in is the number
     # of input units in the weight tensor and fan_out is the number of output units in the weight tensor.
     def glorot_uniform_initializer(seed: nil, dtype: nil)
-      TensorStream::Initializer.new(-> { _op(:glorot_uniform,
+      TensorStream::Initializer.new(-> { _op(:glorot_uniform, seed: seed, data_type: dtype) })
     end
 
     ##
     # Initializer that generates tensors with a uniform distribution.
     def random_uniform_initializer(minval: 0, maxval: 1, seed: nil, dtype: nil)
-      TensorStream::Initializer.new(-> { _op(:random_uniform,
+      TensorStream::Initializer.new(-> { _op(:random_uniform, minval: 0, maxval: 1, seed: seed, data_type: dtype) })
     end
 
     ##
```
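Initializers are lazy procs wrapped in `TensorStream::Initializer` and are typically handed to variable constructors. A hedged sketch (assuming `get_variable` accepts an `initializer:` option as in the bundled samples; names are illustrative):

```ruby
ts = TensorStream

w = ts.get_variable('w_example', shape: [3, 3], initializer: ts.zeros_initializer)
b = ts.get_variable('b_example', shape: [3], initializer: ts.ones_initializer)

sess = ts.session
sess.run(ts.global_variables_initializer)
sess.run(b) # => [1.0, 1.0, 1.0]
```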
```diff
@@ -276,7 +277,7 @@ module TensorStream
     ##
     # Computes the mean of elements across dimensions of a tensor.
     def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
-
+      reduce(:mean, input_tensor, axis, keepdims: keepdims, name: name)
     end
 
     ##
```
```diff
@@ -288,7 +289,7 @@ module TensorStream
     # If axis has no entries, all dimensions are reduced, and a tensor with a single element
     # is returned.
     def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
-
+      reduce(:sum, input_tensor, axis, keepdims: keepdims, name: name)
     end
 
     ##
```
```diff
@@ -300,7 +301,22 @@ module TensorStream
     #
     # If axis has no entries, all dimensions are reduced, and a tensor with a single element is returned.
     def reduce_prod(input, axis = nil, keepdims: false, name: nil)
-
+      reduce(:prod, input, axis, keepdims: keepdims, name: name)
+    end
+
+    def reduce(op, input, axis = nil, keepdims: false, name: nil)
+      input = TensorStream.convert_to_tensor(input)
+      axis = if !axis.nil?
+               axis
+             elsif input.shape.scalar?
+               op
+             elsif input.shape.known?
+               (0...input.shape.ndims).to_a
+             else
+               range(0, rank(input))
+             end
+
+      _op(op, input, axis, keepdims: keepdims, name: name)
     end
 
     ##
```
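All three reducers now delegate to the shared `reduce` helper, which resolves the axis argument: an explicit axis is used as-is, a fully known shape reduces over all dimensions, and an unknown shape falls back to `range(0, rank(input))`. Behaviorally:

```ruby
ts = TensorStream

m = ts.constant([[1.0, 2.0], [3.0, 4.0]])

sess = ts.session
sess.run(ts.reduce_sum(m))     # no axis: reduce everything => 10.0
sess.run(ts.reduce_sum(m, 0))  # down columns => [4.0, 6.0]
sess.run(ts.reduce_mean(m, 1)) # across rows  => [1.5, 3.5]
```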
```diff
@@ -395,13 +411,13 @@ module TensorStream
     ##
     # Return true_fn() if the predicate pred is true else false_fn().
     def cond(pred, true_fn, false_fn, name: nil)
-      _op(:
+      _op(:case, [pred], false_fn, true_fn, name: name)
     end
 
     ##
     # Return the elements, either from x or y, depending on the condition.
     def where(condition, true_t = nil, false_t = nil, name: nil)
-      _op(:where, true_t, false_t,
+      _op(:where, condition, true_t, false_t, name: name)
     end
 
     ##
```
|
@@ -486,6 +502,7 @@ module TensorStream
|
|
486
502
|
def max(input_a, input_b, name: nil)
|
487
503
|
check_allowed_types(input_a, NUMERIC_TYPES)
|
488
504
|
check_allowed_types(input_b, NUMERIC_TYPES)
|
505
|
+
|
489
506
|
input_a, input_b = check_data_types(input_a, input_b)
|
490
507
|
_op(:max, input_a, input_b, name: name)
|
491
508
|
end
|
```diff
@@ -652,6 +669,13 @@ module TensorStream
       _op(:tanh, input, name: name)
     end
 
+    ##
+    # Computes sec of input element-wise.
+    def sec(input, name: nil)
+      check_allowed_types(input, FLOATING_POINT_TYPES)
+      _op(:sec, input, name: name)
+    end
+
     ##
     # Computes sqrt of input element-wise.
     def sqrt(input, name: nil)
```
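`sec` is the secant, `1 / cos(x)`, restricted to floating-point inputs. A quick check:

```ruby
ts = TensorStream

x = ts.constant([0.0, Math::PI / 3])

sess = ts.session
sess.run(ts.sec(x)) # => roughly [1.0, 2.0], since cos(PI/3) = 0.5
```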
```diff
@@ -819,6 +843,34 @@ module TensorStream
       [result[0], result[1]]
     end
 
+    ##
+    # Create a case operation.
+    #
+    # The pred_fn_pairs parameter is a dict or list of pairs of size N.
+    # Each pair contains a boolean scalar tensor and a proc that creates the tensors to be returned if the boolean evaluates to true.
+    # default is a proc generating a list of tensors. All the proc in pred_fn_pairs as well as default (if provided) should return the
+    # same number and types of tensors.
+    #
+    def case(args = {})
+      args = args.dup
+      default = args.delete(:default)
+      exclusive = args.delete(:exclusive)
+      strict = args.delete(:strict)
+      name = args.delete(:name)
+
+      predicates = []
+      functions = []
+
+      args.each do |k, v|
+        raise "Invalid argment or option #{k}" unless k.is_a?(Tensor)
+
+        predicates << k
+        functions << (v.is_a?(Proc) ? v.call : v)
+      end
+
+      _op(:case, predicates, default, *functions, exclusive: exclusive, strict: strict, name: name)
+    end
+
     def cumprod(x, axis: 0, exclusive: false, reverse: false, name: nil)
       _op(:cumprod, x, axis: axis, exclusive: exclusive, reverse: reverse, name: name)
     end
```
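`case` takes a hash whose tensor keys are predicates and whose values are tensors or procs producing them (procs are called eagerly here), plus the reserved `:default`, `:exclusive`, `:strict` and `:name` options. A hedged usage sketch (assuming tensor comparison operators from the new tensor mixins; note that, as written above, `default` is passed through without being called, so a tensor is the safe choice):

```ruby
ts = TensorStream

x = ts.placeholder(:float32)

sign = ts.case(
  (x < 0) => ts.constant(-1.0),
  (x > 0) => ts.constant(1.0),
  default: ts.constant(0.0)
)

sess = ts.session
sess.run(sign, feed_dict: { x => -5.0 }) # => -1.0
```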