tensor_stream 0.9.8 → 0.9.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +31 -14
- data/lib/tensor_stream.rb +4 -0
- data/lib/tensor_stream/constant.rb +41 -0
- data/lib/tensor_stream/control_flow.rb +2 -1
- data/lib/tensor_stream/dynamic_stitch.rb +3 -1
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
- data/lib/tensor_stream/graph.rb +61 -12
- data/lib/tensor_stream/graph_builder.rb +3 -3
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
- data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
- data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
- data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
- data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
- data/lib/tensor_stream/helpers/op_helper.rb +17 -6
- data/lib/tensor_stream/helpers/string_helper.rb +32 -1
- data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
- data/lib/tensor_stream/math_gradients.rb +19 -12
- data/lib/tensor_stream/monkey_patches/float.rb +7 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
- data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
- data/lib/tensor_stream/nn/nn_ops.rb +1 -1
- data/lib/tensor_stream/operation.rb +98 -36
- data/lib/tensor_stream/ops.rb +65 -13
- data/lib/tensor_stream/placeholder.rb +2 -2
- data/lib/tensor_stream/session.rb +15 -3
- data/lib/tensor_stream/tensor.rb +15 -172
- data/lib/tensor_stream/tensor_shape.rb +3 -1
- data/lib/tensor_stream/train/saver.rb +12 -10
- data/lib/tensor_stream/trainer.rb +7 -2
- data/lib/tensor_stream/utils.rb +13 -11
- data/lib/tensor_stream/utils/freezer.rb +37 -0
- data/lib/tensor_stream/variable.rb +17 -11
- data/lib/tensor_stream/variable_scope.rb +3 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +3 -4
- data/samples/linear_regression.rb +9 -5
- data/samples/logistic_regression.rb +11 -9
- data/samples/mnist_data.rb +8 -10
- metadata +8 -4
@@ -40,7 +40,10 @@ module TensorStream
|
|
40
40
|
return nil if grads.empty?
|
41
41
|
grads.size > 1 ? ts.add_n(grads) : grads[0]
|
42
42
|
else
|
43
|
-
|
43
|
+
|
44
|
+
if computed_op.nil?
|
45
|
+
return nil
|
46
|
+
end
|
44
47
|
_propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
|
45
48
|
end
|
46
49
|
end
|
@@ -50,6 +53,7 @@ module TensorStream
|
|
50
53
|
node.graph.name_scope("#{node.name}_grad") do
|
51
54
|
x = node.inputs[0] if node.inputs[0]
|
52
55
|
y = node.inputs[1] if node.inputs[1]
|
56
|
+
z = node.inputs[2] if node.inputs[2]
|
53
57
|
|
54
58
|
case node.operation
|
55
59
|
when :add_n
|
@@ -135,7 +139,6 @@ module TensorStream
|
|
135
139
|
reduced = ts.cast(reduction_indices, :int32)
|
136
140
|
idx = ts.range(0, rank)
|
137
141
|
other, = ts.setdiff1d(idx, reduced)
|
138
|
-
|
139
142
|
[ts.concat([reduced, other], 0),
|
140
143
|
ts.reduce_prod(ts.gather(input_shape, reduced)),
|
141
144
|
ts.reduce_prod(ts.gather(input_shape, other))]
|
@@ -239,13 +242,17 @@ module TensorStream
|
|
239
242
|
y = ts.constant(2.0, dtype: x.dtype)
|
240
243
|
ts.multiply(grad, ts.multiply(x, y))
|
241
244
|
when :where
|
242
|
-
x_mask = i_op(:where, i_op(:ones_like,
|
243
|
-
y_mask = i_op(:where, i_op(:zeros_like,
|
244
|
-
[x_mask * grad, y_mask * grad]
|
245
|
-
when :
|
246
|
-
|
247
|
-
|
248
|
-
|
245
|
+
x_mask = i_op(:where, x, i_op(:ones_like, y), i_op(:zeros_like, z))
|
246
|
+
y_mask = i_op(:where, x, i_op(:zeros_like, y), i_op(:ones_like, z))
|
247
|
+
[nil, x_mask * grad, y_mask * grad]
|
248
|
+
when :case
|
249
|
+
n_preds = node.inputs.size - 2
|
250
|
+
|
251
|
+
case_grads = Array.new(n_preds) do |index|
|
252
|
+
i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
|
253
|
+
end
|
254
|
+
|
255
|
+
[nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
|
249
256
|
when :mean
|
250
257
|
sum_grad = _sum_grad(x, y, grad)[0]
|
251
258
|
input_shape = ts.shape(x)
|
@@ -283,12 +290,12 @@ module TensorStream
|
|
283
290
|
# hack!! not sure how to fix this yet
|
284
291
|
return grad if %i[softmax_cross_entropy_with_logits_v2 sparse_softmax_cross_entropy_with_logits].include?(node.inputs[0].operation)
|
285
292
|
|
286
|
-
if node.inputs[0].shape.known? && node.inputs[1].
|
293
|
+
if node.inputs[0].shape.known? && node.inputs[1].const_value
|
287
294
|
multiplier = node.inputs[0].shape.shape[0]
|
288
295
|
filler = ts.zeros_like(grad)
|
289
296
|
|
290
297
|
res = Array.new(multiplier) do |index|
|
291
|
-
index == node.inputs[1].
|
298
|
+
index == node.inputs[1].const_value ? grad : filler
|
292
299
|
end
|
293
300
|
[res]
|
294
301
|
end
|
@@ -346,7 +353,7 @@ module TensorStream
|
|
346
353
|
tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
|
347
354
|
new_grad = _op(:reshape, grad, output_shape_kept_dims)
|
348
355
|
|
349
|
-
grad = _op(:
|
356
|
+
grad = _op(:case, [_op(:rank, grad).zero?], _op(:tile, new_grad, tile_scaling), _op(:fill, input_shape, grad))
|
350
357
|
|
351
358
|
[grad, nil]
|
352
359
|
end
|
@@ -1,3 +1,10 @@
|
|
1
1
|
class Float
|
2
2
|
include TensorStream::MonkeyPatch
|
3
|
+
|
4
|
+
def self.placeholder(name: nil, width: 32, shape: nil)
|
5
|
+
raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
|
6
|
+
|
7
|
+
data_type = :"float#{width}"
|
8
|
+
TensorStream.placeholder(data_type, name: name, shape: shape)
|
9
|
+
end
|
3
10
|
end
|
@@ -1,3 +1,10 @@
|
|
1
1
|
class Integer
|
2
2
|
include TensorStream::MonkeyPatch
|
3
|
+
|
4
|
+
def self.placeholder(name: nil, width: 32, shape: nil)
|
5
|
+
raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
|
6
|
+
|
7
|
+
data_type = :"int#{width}"
|
8
|
+
TensorStream.placeholder(data_type, name: name, shape: shape)
|
9
|
+
end
|
3
10
|
end
|
@@ -19,13 +19,13 @@ module TensorStream
|
|
19
19
|
TensorStream.shape_eval(self)
|
20
20
|
end
|
21
21
|
|
22
|
-
def t(name = nil)
|
23
|
-
TensorStream.convert_to_tensor(self, name: name)
|
22
|
+
def t(name = nil, dtype: nil)
|
23
|
+
TensorStream.convert_to_tensor(self, name: name, dtype: dtype)
|
24
24
|
end
|
25
25
|
|
26
26
|
def +(other)
|
27
27
|
if other.is_a?(TensorStream::Tensor)
|
28
|
-
TensorStream.convert_to_tensor(self) + other
|
28
|
+
TensorStream.convert_to_tensor(self, dtype: other.data_type) + other
|
29
29
|
else
|
30
30
|
_tensor_stream_add_orig(other)
|
31
31
|
end
|
@@ -33,7 +33,7 @@ module TensorStream
|
|
33
33
|
|
34
34
|
def -(other)
|
35
35
|
if other.is_a?(TensorStream::Tensor)
|
36
|
-
TensorStream.convert_to_tensor(self) - other
|
36
|
+
TensorStream.convert_to_tensor(self, dtype: other.data_type) - other
|
37
37
|
else
|
38
38
|
_tensor_stream_sub_orig(other)
|
39
39
|
end
|
@@ -41,7 +41,7 @@ module TensorStream
|
|
41
41
|
|
42
42
|
def *(other)
|
43
43
|
if other.is_a?(TensorStream::Tensor)
|
44
|
-
TensorStream.convert_to_tensor(self) * other
|
44
|
+
TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
|
45
45
|
else
|
46
46
|
_tensor_stream_mul_orig(other)
|
47
47
|
end
|
@@ -49,7 +49,7 @@ module TensorStream
|
|
49
49
|
|
50
50
|
def /(other)
|
51
51
|
if other.is_a?(TensorStream::Tensor)
|
52
|
-
TensorStream.convert_to_tensor(self) * other
|
52
|
+
TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
|
53
53
|
else
|
54
54
|
_tensor_stream_div_orig(other)
|
55
55
|
end
|
@@ -57,7 +57,7 @@ module TensorStream
|
|
57
57
|
|
58
58
|
def %(other)
|
59
59
|
if other.is_a?(TensorStream::Tensor)
|
60
|
-
TensorStream.convert_to_tensor(self) % other
|
60
|
+
TensorStream.convert_to_tensor(self, dtype: other.data_type) % other
|
61
61
|
else
|
62
62
|
_tensor_stream_mod_orig(other)
|
63
63
|
end
|
@@ -65,7 +65,7 @@ module TensorStream
|
|
65
65
|
|
66
66
|
def **(other)
|
67
67
|
if other.is_a?(TensorStream::Tensor)
|
68
|
-
TensorStream.convert_to_tensor(self)**other
|
68
|
+
TensorStream.convert_to_tensor(self, dtype: other.data_type)**other
|
69
69
|
else
|
70
70
|
_tensor_stream_pow_orig(other)
|
71
71
|
end
|
@@ -2,33 +2,18 @@ require 'tensor_stream/helpers/infer_shape'
|
|
2
2
|
module TensorStream
|
3
3
|
# TensorStream class that defines an operation
|
4
4
|
class Operation < Tensor
|
5
|
-
|
6
|
-
attr_reader :outputs
|
5
|
+
include OpHelper
|
7
6
|
|
8
|
-
|
9
|
-
|
10
|
-
args.pop
|
11
|
-
else
|
12
|
-
{}
|
13
|
-
end
|
14
|
-
|
15
|
-
inputs = args
|
16
|
-
|
17
|
-
setup_initial_state(options)
|
18
|
-
|
19
|
-
@operation = operation
|
20
|
-
@rank = options[:rank] || 0
|
21
|
-
@name = [@graph.get_name_scope, options[:name] || set_name].compact.reject(&:empty?).join('/')
|
22
|
-
@internal = options[:internal]
|
23
|
-
@given_name = @name
|
7
|
+
attr_accessor :name, :operation, :inputs, :rank, :device, :consumers, :breakpoint
|
8
|
+
attr_reader :outputs, :options, :is_const, :data_type, :shape
|
24
9
|
|
10
|
+
def initialize(graph, inputs:, options:)
|
11
|
+
@consumers = Set.new
|
12
|
+
@outputs = []
|
13
|
+
@op = self
|
14
|
+
@graph = graph
|
15
|
+
@inputs = inputs
|
25
16
|
@options = options
|
26
|
-
|
27
|
-
@inputs = inputs.map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
|
28
|
-
@data_type = set_data_type(options[:data_type])
|
29
|
-
@is_const = infer_const
|
30
|
-
@shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
|
31
|
-
@graph.add_node(self)
|
32
17
|
end
|
33
18
|
|
34
19
|
def to_s
|
@@ -37,25 +22,74 @@ module TensorStream
|
|
37
22
|
|
38
23
|
def to_h
|
39
24
|
{
|
40
|
-
op: operation,
|
41
|
-
name: name,
|
42
|
-
|
25
|
+
op: operation.to_s,
|
26
|
+
name: name.to_s,
|
27
|
+
data_type: @data_type,
|
28
|
+
inputs: @inputs.map(&:name),
|
29
|
+
attrs: serialize_options
|
43
30
|
}
|
44
31
|
end
|
45
32
|
|
33
|
+
def const_value
|
34
|
+
@options ? @options[:value] : nil
|
35
|
+
end
|
36
|
+
|
37
|
+
def container_buffer
|
38
|
+
@options[:container] ? @options[:container].buffer : nil
|
39
|
+
end
|
40
|
+
|
41
|
+
def container
|
42
|
+
@options[:container].read_value
|
43
|
+
end
|
44
|
+
|
45
|
+
def container=(value)
|
46
|
+
@options[:container].value = value
|
47
|
+
end
|
48
|
+
|
49
|
+
def set_input(index, value)
|
50
|
+
@inputs[index] = value
|
51
|
+
@shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
|
52
|
+
@rank = @shape.rank
|
53
|
+
@is_const = infer_const
|
54
|
+
@data_type = set_data_type(@options[:data_type])
|
55
|
+
end
|
56
|
+
|
46
57
|
def infer_const
|
47
58
|
return false if breakpoint
|
59
|
+
|
48
60
|
case operation
|
49
61
|
when :random_standard_normal, :random_uniform, :truncated_normal, :glorot_uniform, :print, :check_numerics
|
50
62
|
false
|
63
|
+
when :const
|
64
|
+
true
|
65
|
+
when :placeholder
|
66
|
+
false
|
67
|
+
when :variable_v2
|
68
|
+
false
|
51
69
|
else
|
52
70
|
non_const = @inputs.compact.find { |input| !input.is_const }
|
53
71
|
non_const ? false : true
|
54
72
|
end
|
55
73
|
end
|
56
74
|
|
75
|
+
def set_name
|
76
|
+
@operation.to_s
|
77
|
+
end
|
78
|
+
|
57
79
|
def set_data_type(passed_data_type)
|
58
80
|
case operation
|
81
|
+
when :where
|
82
|
+
@inputs[1].data_type
|
83
|
+
when :case
|
84
|
+
if @inputs[2]
|
85
|
+
@inputs[2].data_type
|
86
|
+
else
|
87
|
+
@inputs[1].data_type
|
88
|
+
end
|
89
|
+
when :case_grad
|
90
|
+
@inputs[2].data_type
|
91
|
+
when :placeholder, :variable_v2, :const
|
92
|
+
options[:data_type]
|
59
93
|
when :fill
|
60
94
|
@inputs[1].data_type
|
61
95
|
when :greater, :less, :equal, :not_equal, :greater_equal, :less_equal, :logical_and
|
@@ -70,9 +104,8 @@ module TensorStream
|
|
70
104
|
@inputs[1].data_type
|
71
105
|
when :index
|
72
106
|
if @inputs[0].is_a?(ControlFlow)
|
73
|
-
|
74
107
|
if @inputs[1].is_const
|
75
|
-
@inputs[0].inputs[@inputs[1].
|
108
|
+
@inputs[0].inputs[@inputs[1].const_value].data_type
|
76
109
|
else
|
77
110
|
:unknown
|
78
111
|
end
|
@@ -163,8 +196,6 @@ module TensorStream
|
|
163
196
|
"reshape(#{sub_input},#{sub_input2})"
|
164
197
|
when :rank
|
165
198
|
"#{sub_input}.rank"
|
166
|
-
when :cond
|
167
|
-
"(#{auto_math(options[:pred], name_only, max_depth - 1, cur_depth)} ? #{sub_input} : #{sub_input2})"
|
168
199
|
when :less
|
169
200
|
"#{sub_input} < #{sub_input2}"
|
170
201
|
when :less_equal
|
@@ -222,12 +253,41 @@ module TensorStream
|
|
222
253
|
|
223
254
|
private
|
224
255
|
|
256
|
+
def serialize_options
|
257
|
+
excludes = %i[internal_name source]
|
258
|
+
|
259
|
+
@options.reject { |k, v| excludes.include?(k) || v.nil? }.map do |k, v|
|
260
|
+
v = case v.class.to_s
|
261
|
+
when 'TensorStream::TensorShape'
|
262
|
+
v.shape
|
263
|
+
when 'Array'
|
264
|
+
v
|
265
|
+
when 'String', 'Integer', 'Float', 'Symbol', 'FalseClass', "TrueClass"
|
266
|
+
v
|
267
|
+
when 'TensorStream::Variable'
|
268
|
+
{ name: v.name, options: v.options, shape: v.shape.shape.dup }
|
269
|
+
else
|
270
|
+
raise "unknown type #{v.class}"
|
271
|
+
end
|
272
|
+
[k.to_sym, v]
|
273
|
+
end.to_h
|
274
|
+
end
|
275
|
+
|
276
|
+
def add_consumer(consumer)
|
277
|
+
@consumers << consumer.name if consumer.name != name
|
278
|
+
end
|
279
|
+
|
280
|
+
def setup_output(consumer)
|
281
|
+
@outputs << consumer.name unless @outputs.include?(consumer.name)
|
282
|
+
end
|
283
|
+
|
225
284
|
def propagate_consumer(consumer)
|
226
|
-
|
285
|
+
add_consumer(consumer)
|
227
286
|
@inputs.compact.each do |input|
|
228
287
|
if input.is_a?(Array)
|
229
|
-
input.flatten.compact.select { |t| t.is_a?(Tensor) }.each do |t|
|
288
|
+
input.flatten.compact.map(&:op).select { |t| t.is_a?(Tensor) }.each do |t|
|
230
289
|
next if t.consumers.include?(consumer.name)
|
290
|
+
|
231
291
|
t.send(:propagate_consumer, consumer)
|
232
292
|
end
|
233
293
|
elsif input.name != name && !input.consumers.include?(consumer.name)
|
@@ -239,7 +299,7 @@ module TensorStream
|
|
239
299
|
def propagate_outputs
|
240
300
|
@inputs.compact.each do |input|
|
241
301
|
if input.is_a?(Array)
|
242
|
-
input.flatten.compact.each do |t|
|
302
|
+
input.flatten.compact.map(&:op).each do |t|
|
243
303
|
t.send(:setup_output, self) if t.is_a?(Tensor)
|
244
304
|
end
|
245
305
|
elsif input.is_a?(Tensor) && (input.name != name)
|
@@ -248,8 +308,10 @@ module TensorStream
|
|
248
308
|
end
|
249
309
|
end
|
250
310
|
|
251
|
-
def
|
252
|
-
|
311
|
+
def setup_initial_state(options)
|
312
|
+
@outputs = []
|
313
|
+
@graph = options[:__graph] || TensorStream.get_default_graph
|
314
|
+
@source = format_source(caller_locations)
|
253
315
|
end
|
254
316
|
end
|
255
317
|
end
|
data/lib/tensor_stream/ops.rb
CHANGED
@@ -58,7 +58,8 @@ module TensorStream
|
|
58
58
|
# +wrt_xs+ : A Tensor or list of tensors to be used for differentiation.
|
59
59
|
# +stop_gradients+ : Optional. A Tensor or list of tensors not to differentiate through
|
60
60
|
def gradients(tensor_ys, wrt_xs, name: 'gradients', stop_gradients: nil)
|
61
|
-
|
61
|
+
tensor_ys = tensor_ys.op
|
62
|
+
gs = wrt_xs.map(&:op).collect do |x|
|
62
63
|
stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
|
63
64
|
gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym
|
64
65
|
tensor_graph = tensor_ys.graph
|
@@ -82,21 +83,21 @@ module TensorStream
|
|
82
83
|
# Outputs random values from a uniform distribution.
|
83
84
|
def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
|
84
85
|
options = { dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
|
85
|
-
_op(:random_uniform, shape,
|
86
|
+
_op(:random_uniform, shape, options)
|
86
87
|
end
|
87
88
|
|
88
89
|
##
|
89
90
|
# Outputs random values from a normal distribution.
|
90
91
|
def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
|
91
92
|
options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
|
92
|
-
_op(:random_standard_normal, shape,
|
93
|
+
_op(:random_standard_normal, shape, options)
|
93
94
|
end
|
94
95
|
|
95
96
|
##
|
96
97
|
# Outputs random values from a truncated normal distribution.
|
97
98
|
def truncated_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
|
98
99
|
options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
|
99
|
-
_op(:truncated_normal, shape,
|
100
|
+
_op(:truncated_normal, shape, options)
|
100
101
|
end
|
101
102
|
|
102
103
|
##
|
@@ -172,14 +173,14 @@ module TensorStream
|
|
172
173
|
# initializer that generates tensors initialized to 0.
|
173
174
|
#
|
174
175
|
def zeros_initializer(dtype: :float32)
|
175
|
-
TensorStream::Initializer.new(-> { _op(:zeros,
|
176
|
+
TensorStream::Initializer.new(-> { _op(:zeros, data_type: dtype) })
|
176
177
|
end
|
177
178
|
|
178
179
|
##
|
179
180
|
# initializer that generates tensors initialized to 1.
|
180
181
|
#
|
181
182
|
def ones_initializer(dtype: :float32)
|
182
|
-
TensorStream::Initializer.new(-> { _op(:ones,
|
183
|
+
TensorStream::Initializer.new(-> { _op(:ones, data_type: dtype) })
|
183
184
|
end
|
184
185
|
|
185
186
|
##
|
@@ -189,13 +190,13 @@ module TensorStream
|
|
189
190
|
# where limit is sqrt(6 / (fan_in + fan_out)) where fan_in is the number
|
190
191
|
# of input units in the weight tensor and fan_out is the number of output units in the weight tensor.
|
191
192
|
def glorot_uniform_initializer(seed: nil, dtype: nil)
|
192
|
-
TensorStream::Initializer.new(-> { _op(:glorot_uniform,
|
193
|
+
TensorStream::Initializer.new(-> { _op(:glorot_uniform, seed: seed, data_type: dtype) })
|
193
194
|
end
|
194
195
|
|
195
196
|
##
|
196
197
|
# Initializer that generates tensors with a uniform distribution.
|
197
198
|
def random_uniform_initializer(minval: 0, maxval: 1, seed: nil, dtype: nil)
|
198
|
-
TensorStream::Initializer.new(-> { _op(:random_uniform,
|
199
|
+
TensorStream::Initializer.new(-> { _op(:random_uniform, minval: 0, maxval: 1, seed: seed, data_type: dtype) })
|
199
200
|
end
|
200
201
|
|
201
202
|
##
|
@@ -276,7 +277,7 @@ module TensorStream
|
|
276
277
|
##
|
277
278
|
# Computes the mean of elements across dimensions of a tensor.
|
278
279
|
def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
|
279
|
-
|
280
|
+
reduce(:mean, input_tensor, axis, keepdims: keepdims, name: name)
|
280
281
|
end
|
281
282
|
|
282
283
|
##
|
@@ -288,7 +289,7 @@ module TensorStream
|
|
288
289
|
# If axis has no entries, all dimensions are reduced, and a tensor with a single element
|
289
290
|
# is returned.
|
290
291
|
def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
|
291
|
-
|
292
|
+
reduce(:sum, input_tensor, axis, keepdims: keepdims, name: name)
|
292
293
|
end
|
293
294
|
|
294
295
|
##
|
@@ -300,7 +301,22 @@ module TensorStream
|
|
300
301
|
#
|
301
302
|
# If axis has no entries, all dimensions are reduced, and a tensor with a single element is returned.
|
302
303
|
def reduce_prod(input, axis = nil, keepdims: false, name: nil)
|
303
|
-
|
304
|
+
reduce(:prod, input, axis, keepdims: keepdims, name: name)
|
305
|
+
end
|
306
|
+
|
307
|
+
def reduce(op, input, axis = nil, keepdims: false, name: nil)
|
308
|
+
input = TensorStream.convert_to_tensor(input)
|
309
|
+
axis = if !axis.nil?
|
310
|
+
axis
|
311
|
+
elsif input.shape.scalar?
|
312
|
+
op
|
313
|
+
elsif input.shape.known?
|
314
|
+
(0...input.shape.ndims).to_a
|
315
|
+
else
|
316
|
+
range(0, rank(input))
|
317
|
+
end
|
318
|
+
|
319
|
+
_op(op, input, axis, keepdims: keepdims, name: name)
|
304
320
|
end
|
305
321
|
|
306
322
|
##
|
@@ -395,13 +411,13 @@ module TensorStream
|
|
395
411
|
##
|
396
412
|
# Return true_fn() if the predicate pred is true else false_fn().
|
397
413
|
def cond(pred, true_fn, false_fn, name: nil)
|
398
|
-
_op(:
|
414
|
+
_op(:case, [pred], false_fn, true_fn, name: name)
|
399
415
|
end
|
400
416
|
|
401
417
|
##
|
402
418
|
# Return the elements, either from x or y, depending on the condition.
|
403
419
|
def where(condition, true_t = nil, false_t = nil, name: nil)
|
404
|
-
_op(:where, true_t, false_t,
|
420
|
+
_op(:where, condition, true_t, false_t, name: name)
|
405
421
|
end
|
406
422
|
|
407
423
|
##
|
@@ -486,6 +502,7 @@ module TensorStream
|
|
486
502
|
def max(input_a, input_b, name: nil)
|
487
503
|
check_allowed_types(input_a, NUMERIC_TYPES)
|
488
504
|
check_allowed_types(input_b, NUMERIC_TYPES)
|
505
|
+
|
489
506
|
input_a, input_b = check_data_types(input_a, input_b)
|
490
507
|
_op(:max, input_a, input_b, name: name)
|
491
508
|
end
|
@@ -652,6 +669,13 @@ module TensorStream
|
|
652
669
|
_op(:tanh, input, name: name)
|
653
670
|
end
|
654
671
|
|
672
|
+
##
|
673
|
+
# Computes sec of input element-wise.
|
674
|
+
def sec(input, name: nil)
|
675
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
676
|
+
_op(:sec, input, name: name)
|
677
|
+
end
|
678
|
+
|
655
679
|
##
|
656
680
|
# Computes sqrt of input element-wise.
|
657
681
|
def sqrt(input, name: nil)
|
@@ -819,6 +843,34 @@ module TensorStream
|
|
819
843
|
[result[0], result[1]]
|
820
844
|
end
|
821
845
|
|
846
|
+
##
|
847
|
+
# Create a case operation.
|
848
|
+
#
|
849
|
+
# The pred_fn_pairs parameter is a dict or list of pairs of size N.
|
850
|
+
# Each pair contains a boolean scalar tensor and a proc that creates the tensors to be returned if the boolean evaluates to true.
|
851
|
+
# default is a proc generating a list of tensors. All the proc in pred_fn_pairs as well as default (if provided) should return the
|
852
|
+
# same number and types of tensors.
|
853
|
+
#
|
854
|
+
def case(args = {})
|
855
|
+
args = args.dup
|
856
|
+
default = args.delete(:default)
|
857
|
+
exclusive = args.delete(:exclusive)
|
858
|
+
strict = args.delete(:strict)
|
859
|
+
name = args.delete(:name)
|
860
|
+
|
861
|
+
predicates = []
|
862
|
+
functions = []
|
863
|
+
|
864
|
+
args.each do |k, v|
|
865
|
+
raise "Invalid argment or option #{k}" unless k.is_a?(Tensor)
|
866
|
+
|
867
|
+
predicates << k
|
868
|
+
functions << (v.is_a?(Proc) ? v.call : v)
|
869
|
+
end
|
870
|
+
|
871
|
+
_op(:case, predicates, default, *functions, exclusive: exclusive, strict: strict, name: name)
|
872
|
+
end
|
873
|
+
|
822
874
|
def cumprod(x, axis: 0, exclusive: false, reverse: false, name: nil)
|
823
875
|
_op(:cumprod, x, axis: axis, exclusive: exclusive, reverse: reverse, name: name)
|
824
876
|
end
|