tensor_stream 0.9.8 → 0.9.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. checksums.yaml +4 -4
  2. data/README.md +31 -14
  3. data/lib/tensor_stream.rb +4 -0
  4. data/lib/tensor_stream/constant.rb +41 -0
  5. data/lib/tensor_stream/control_flow.rb +2 -1
  6. data/lib/tensor_stream/dynamic_stitch.rb +3 -1
  7. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
  8. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
  13. data/lib/tensor_stream/graph.rb +61 -12
  14. data/lib/tensor_stream/graph_builder.rb +3 -3
  15. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
  16. data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
  17. data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
  18. data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
  19. data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
  20. data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
  21. data/lib/tensor_stream/helpers/op_helper.rb +17 -6
  22. data/lib/tensor_stream/helpers/string_helper.rb +32 -1
  23. data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
  24. data/lib/tensor_stream/math_gradients.rb +19 -12
  25. data/lib/tensor_stream/monkey_patches/float.rb +7 -0
  26. data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
  27. data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
  28. data/lib/tensor_stream/nn/nn_ops.rb +1 -1
  29. data/lib/tensor_stream/operation.rb +98 -36
  30. data/lib/tensor_stream/ops.rb +65 -13
  31. data/lib/tensor_stream/placeholder.rb +2 -2
  32. data/lib/tensor_stream/session.rb +15 -3
  33. data/lib/tensor_stream/tensor.rb +15 -172
  34. data/lib/tensor_stream/tensor_shape.rb +3 -1
  35. data/lib/tensor_stream/train/saver.rb +12 -10
  36. data/lib/tensor_stream/trainer.rb +7 -2
  37. data/lib/tensor_stream/utils.rb +13 -11
  38. data/lib/tensor_stream/utils/freezer.rb +37 -0
  39. data/lib/tensor_stream/variable.rb +17 -11
  40. data/lib/tensor_stream/variable_scope.rb +3 -1
  41. data/lib/tensor_stream/version.rb +1 -1
  42. data/samples/iris.rb +3 -4
  43. data/samples/linear_regression.rb +9 -5
  44. data/samples/logistic_regression.rb +11 -9
  45. data/samples/mnist_data.rb +8 -10
  46. metadata +8 -4
@@ -40,7 +40,10 @@ module TensorStream
40
40
  return nil if grads.empty?
41
41
  grads.size > 1 ? ts.add_n(grads) : grads[0]
42
42
  else
43
- return nil if computed_op.nil?
43
+
44
+ if computed_op.nil?
45
+ return nil
46
+ end
44
47
  _propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
45
48
  end
46
49
  end
@@ -50,6 +53,7 @@ module TensorStream
50
53
  node.graph.name_scope("#{node.name}_grad") do
51
54
  x = node.inputs[0] if node.inputs[0]
52
55
  y = node.inputs[1] if node.inputs[1]
56
+ z = node.inputs[2] if node.inputs[2]
53
57
 
54
58
  case node.operation
55
59
  when :add_n
@@ -135,7 +139,6 @@ module TensorStream
135
139
  reduced = ts.cast(reduction_indices, :int32)
136
140
  idx = ts.range(0, rank)
137
141
  other, = ts.setdiff1d(idx, reduced)
138
-
139
142
  [ts.concat([reduced, other], 0),
140
143
  ts.reduce_prod(ts.gather(input_shape, reduced)),
141
144
  ts.reduce_prod(ts.gather(input_shape, other))]
@@ -239,13 +242,17 @@ module TensorStream
239
242
  y = ts.constant(2.0, dtype: x.dtype)
240
243
  ts.multiply(grad, ts.multiply(x, y))
241
244
  when :where
242
- x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
243
- y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
244
- [x_mask * grad, y_mask * grad]
245
- when :cond
246
- x_cond = i_op(:cond, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
247
- y_cond = i_op(:cond, i_op(:zeros_like, x), i_op(:ones_like, x), pred: node.options[:pred])
248
- [x_cond * grad, y_cond * grad]
245
+ x_mask = i_op(:where, x, i_op(:ones_like, y), i_op(:zeros_like, z))
246
+ y_mask = i_op(:where, x, i_op(:zeros_like, y), i_op(:ones_like, z))
247
+ [nil, x_mask * grad, y_mask * grad]
248
+ when :case
249
+ n_preds = node.inputs.size - 2
250
+
251
+ case_grads = Array.new(n_preds) do |index|
252
+ i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
253
+ end
254
+
255
+ [nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
249
256
  when :mean
250
257
  sum_grad = _sum_grad(x, y, grad)[0]
251
258
  input_shape = ts.shape(x)
@@ -283,12 +290,12 @@ module TensorStream
283
290
  # hack!! not sure how to fix this yet
284
291
  return grad if %i[softmax_cross_entropy_with_logits_v2 sparse_softmax_cross_entropy_with_logits].include?(node.inputs[0].operation)
285
292
 
286
- if node.inputs[0].shape.known? && node.inputs[1].value
293
+ if node.inputs[0].shape.known? && node.inputs[1].const_value
287
294
  multiplier = node.inputs[0].shape.shape[0]
288
295
  filler = ts.zeros_like(grad)
289
296
 
290
297
  res = Array.new(multiplier) do |index|
291
- index == node.inputs[1].value ? grad : filler
298
+ index == node.inputs[1].const_value ? grad : filler
292
299
  end
293
300
  [res]
294
301
  end
@@ -346,7 +353,7 @@ module TensorStream
346
353
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
347
354
  new_grad = _op(:reshape, grad, output_shape_kept_dims)
348
355
 
349
- grad = _op(:cond, _op(:fill, input_shape, grad), _op(:tile, new_grad, tile_scaling), pred: _op(:rank, grad).zero?)
356
+ grad = _op(:case, [_op(:rank, grad).zero?], _op(:tile, new_grad, tile_scaling), _op(:fill, input_shape, grad))
350
357
 
351
358
  [grad, nil]
352
359
  end
@@ -1,3 +1,10 @@
1
1
  class Float
2
2
  include TensorStream::MonkeyPatch
3
+
4
+ def self.placeholder(name: nil, width: 32, shape: nil)
5
+ raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
6
+
7
+ data_type = :"float#{width}"
8
+ TensorStream.placeholder(data_type, name: name, shape: shape)
9
+ end
3
10
  end
@@ -1,3 +1,10 @@
1
1
  class Integer
2
2
  include TensorStream::MonkeyPatch
3
+
4
+ def self.placeholder(name: nil, width: 32, shape: nil)
5
+ raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
6
+
7
+ data_type = :"int#{width}"
8
+ TensorStream.placeholder(data_type, name: name, shape: shape)
9
+ end
3
10
  end
@@ -19,13 +19,13 @@ module TensorStream
19
19
  TensorStream.shape_eval(self)
20
20
  end
21
21
 
22
- def t(name = nil)
23
- TensorStream.convert_to_tensor(self, name: name)
22
+ def t(name = nil, dtype: nil)
23
+ TensorStream.convert_to_tensor(self, name: name, dtype: dtype)
24
24
  end
25
25
 
26
26
  def +(other)
27
27
  if other.is_a?(TensorStream::Tensor)
28
- TensorStream.convert_to_tensor(self) + other
28
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) + other
29
29
  else
30
30
  _tensor_stream_add_orig(other)
31
31
  end
@@ -33,7 +33,7 @@ module TensorStream
33
33
 
34
34
  def -(other)
35
35
  if other.is_a?(TensorStream::Tensor)
36
- TensorStream.convert_to_tensor(self) - other
36
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) - other
37
37
  else
38
38
  _tensor_stream_sub_orig(other)
39
39
  end
@@ -41,7 +41,7 @@ module TensorStream
41
41
 
42
42
  def *(other)
43
43
  if other.is_a?(TensorStream::Tensor)
44
- TensorStream.convert_to_tensor(self) * other
44
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
45
45
  else
46
46
  _tensor_stream_mul_orig(other)
47
47
  end
@@ -49,7 +49,7 @@ module TensorStream
49
49
 
50
50
  def /(other)
51
51
  if other.is_a?(TensorStream::Tensor)
52
- TensorStream.convert_to_tensor(self) * other
52
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
53
53
  else
54
54
  _tensor_stream_div_orig(other)
55
55
  end
@@ -57,7 +57,7 @@ module TensorStream
57
57
 
58
58
  def %(other)
59
59
  if other.is_a?(TensorStream::Tensor)
60
- TensorStream.convert_to_tensor(self) % other
60
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) % other
61
61
  else
62
62
  _tensor_stream_mod_orig(other)
63
63
  end
@@ -65,7 +65,7 @@ module TensorStream
65
65
 
66
66
  def **(other)
67
67
  if other.is_a?(TensorStream::Tensor)
68
- TensorStream.convert_to_tensor(self)**other
68
+ TensorStream.convert_to_tensor(self, dtype: other.data_type)**other
69
69
  else
70
70
  _tensor_stream_pow_orig(other)
71
71
  end
@@ -5,7 +5,7 @@ module TensorStream
5
5
 
6
6
  class << self
7
7
  def softmax(logits, axis: nil, name: nil)
8
- _op(:softmax, logits, nil, axis: axis, name: name)
8
+ _op(:softmax, logits, axis: axis, name: name)
9
9
  end
10
10
 
11
11
  def relu(features, name: nil)
@@ -2,33 +2,18 @@ require 'tensor_stream/helpers/infer_shape'
2
2
  module TensorStream
3
3
  # TensorStream class that defines an operation
4
4
  class Operation < Tensor
5
- attr_accessor :name, :operation, :inputs, :rank, :options
6
- attr_reader :outputs
5
+ include OpHelper
7
6
 
8
- def initialize(operation, *args)
9
- options = if args.last.is_a?(Hash)
10
- args.pop
11
- else
12
- {}
13
- end
14
-
15
- inputs = args
16
-
17
- setup_initial_state(options)
18
-
19
- @operation = operation
20
- @rank = options[:rank] || 0
21
- @name = [@graph.get_name_scope, options[:name] || set_name].compact.reject(&:empty?).join('/')
22
- @internal = options[:internal]
23
- @given_name = @name
7
+ attr_accessor :name, :operation, :inputs, :rank, :device, :consumers, :breakpoint
8
+ attr_reader :outputs, :options, :is_const, :data_type, :shape
24
9
 
10
+ def initialize(graph, inputs:, options:)
11
+ @consumers = Set.new
12
+ @outputs = []
13
+ @op = self
14
+ @graph = graph
15
+ @inputs = inputs
25
16
  @options = options
26
-
27
- @inputs = inputs.map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
28
- @data_type = set_data_type(options[:data_type])
29
- @is_const = infer_const
30
- @shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
31
- @graph.add_node(self)
32
17
  end
33
18
 
34
19
  def to_s
@@ -37,25 +22,74 @@ module TensorStream
37
22
 
38
23
  def to_h
39
24
  {
40
- op: operation,
41
- name: name,
42
- operands: hashify_tensor(inputs)
25
+ op: operation.to_s,
26
+ name: name.to_s,
27
+ data_type: @data_type,
28
+ inputs: @inputs.map(&:name),
29
+ attrs: serialize_options
43
30
  }
44
31
  end
45
32
 
33
+ def const_value
34
+ @options ? @options[:value] : nil
35
+ end
36
+
37
+ def container_buffer
38
+ @options[:container] ? @options[:container].buffer : nil
39
+ end
40
+
41
+ def container
42
+ @options[:container].read_value
43
+ end
44
+
45
+ def container=(value)
46
+ @options[:container].value = value
47
+ end
48
+
49
+ def set_input(index, value)
50
+ @inputs[index] = value
51
+ @shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
52
+ @rank = @shape.rank
53
+ @is_const = infer_const
54
+ @data_type = set_data_type(@options[:data_type])
55
+ end
56
+
46
57
  def infer_const
47
58
  return false if breakpoint
59
+
48
60
  case operation
49
61
  when :random_standard_normal, :random_uniform, :truncated_normal, :glorot_uniform, :print, :check_numerics
50
62
  false
63
+ when :const
64
+ true
65
+ when :placeholder
66
+ false
67
+ when :variable_v2
68
+ false
51
69
  else
52
70
  non_const = @inputs.compact.find { |input| !input.is_const }
53
71
  non_const ? false : true
54
72
  end
55
73
  end
56
74
 
75
+ def set_name
76
+ @operation.to_s
77
+ end
78
+
57
79
  def set_data_type(passed_data_type)
58
80
  case operation
81
+ when :where
82
+ @inputs[1].data_type
83
+ when :case
84
+ if @inputs[2]
85
+ @inputs[2].data_type
86
+ else
87
+ @inputs[1].data_type
88
+ end
89
+ when :case_grad
90
+ @inputs[2].data_type
91
+ when :placeholder, :variable_v2, :const
92
+ options[:data_type]
59
93
  when :fill
60
94
  @inputs[1].data_type
61
95
  when :greater, :less, :equal, :not_equal, :greater_equal, :less_equal, :logical_and
@@ -70,9 +104,8 @@ module TensorStream
70
104
  @inputs[1].data_type
71
105
  when :index
72
106
  if @inputs[0].is_a?(ControlFlow)
73
-
74
107
  if @inputs[1].is_const
75
- @inputs[0].inputs[@inputs[1].value].data_type
108
+ @inputs[0].inputs[@inputs[1].const_value].data_type
76
109
  else
77
110
  :unknown
78
111
  end
@@ -163,8 +196,6 @@ module TensorStream
163
196
  "reshape(#{sub_input},#{sub_input2})"
164
197
  when :rank
165
198
  "#{sub_input}.rank"
166
- when :cond
167
- "(#{auto_math(options[:pred], name_only, max_depth - 1, cur_depth)} ? #{sub_input} : #{sub_input2})"
168
199
  when :less
169
200
  "#{sub_input} < #{sub_input2}"
170
201
  when :less_equal
@@ -222,12 +253,41 @@ module TensorStream
222
253
 
223
254
  private
224
255
 
256
+ def serialize_options
257
+ excludes = %i[internal_name source]
258
+
259
+ @options.reject { |k, v| excludes.include?(k) || v.nil? }.map do |k, v|
260
+ v = case v.class.to_s
261
+ when 'TensorStream::TensorShape'
262
+ v.shape
263
+ when 'Array'
264
+ v
265
+ when 'String', 'Integer', 'Float', 'Symbol', 'FalseClass', "TrueClass"
266
+ v
267
+ when 'TensorStream::Variable'
268
+ { name: v.name, options: v.options, shape: v.shape.shape.dup }
269
+ else
270
+ raise "unknown type #{v.class}"
271
+ end
272
+ [k.to_sym, v]
273
+ end.to_h
274
+ end
275
+
276
+ def add_consumer(consumer)
277
+ @consumers << consumer.name if consumer.name != name
278
+ end
279
+
280
+ def setup_output(consumer)
281
+ @outputs << consumer.name unless @outputs.include?(consumer.name)
282
+ end
283
+
225
284
  def propagate_consumer(consumer)
226
- super
285
+ add_consumer(consumer)
227
286
  @inputs.compact.each do |input|
228
287
  if input.is_a?(Array)
229
- input.flatten.compact.select { |t| t.is_a?(Tensor) }.each do |t|
288
+ input.flatten.compact.map(&:op).select { |t| t.is_a?(Tensor) }.each do |t|
230
289
  next if t.consumers.include?(consumer.name)
290
+
231
291
  t.send(:propagate_consumer, consumer)
232
292
  end
233
293
  elsif input.name != name && !input.consumers.include?(consumer.name)
@@ -239,7 +299,7 @@ module TensorStream
239
299
  def propagate_outputs
240
300
  @inputs.compact.each do |input|
241
301
  if input.is_a?(Array)
242
- input.flatten.compact.each do |t|
302
+ input.flatten.compact.map(&:op).each do |t|
243
303
  t.send(:setup_output, self) if t.is_a?(Tensor)
244
304
  end
245
305
  elsif input.is_a?(Tensor) && (input.name != name)
@@ -248,8 +308,10 @@ module TensorStream
248
308
  end
249
309
  end
250
310
 
251
- def set_name
252
- "#{@operation}#{graph.get_operation_counter}:#{@rank}"
311
+ def setup_initial_state(options)
312
+ @outputs = []
313
+ @graph = options[:__graph] || TensorStream.get_default_graph
314
+ @source = format_source(caller_locations)
253
315
  end
254
316
  end
255
317
  end
@@ -58,7 +58,8 @@ module TensorStream
58
58
  # +wrt_xs+ : A Tensor or list of tensors to be used for differentiation.
59
59
  # +stop_gradients+ : Optional. A Tensor or list of tensors not to differentiate through
60
60
  def gradients(tensor_ys, wrt_xs, name: 'gradients', stop_gradients: nil)
61
- gs = wrt_xs.collect do |x|
61
+ tensor_ys = tensor_ys.op
62
+ gs = wrt_xs.map(&:op).collect do |x|
62
63
  stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
63
64
  gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym
64
65
  tensor_graph = tensor_ys.graph
@@ -82,21 +83,21 @@ module TensorStream
82
83
  # Outputs random values from a uniform distribution.
83
84
  def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
84
85
  options = { dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
85
- _op(:random_uniform, shape, nil, options)
86
+ _op(:random_uniform, shape, options)
86
87
  end
87
88
 
88
89
  ##
89
90
  # Outputs random values from a normal distribution.
90
91
  def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
91
92
  options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
92
- _op(:random_standard_normal, shape, nil, options)
93
+ _op(:random_standard_normal, shape, options)
93
94
  end
94
95
 
95
96
  ##
96
97
  # Outputs random values from a truncated normal distribution.
97
98
  def truncated_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
98
99
  options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
99
- _op(:truncated_normal, shape, nil, options)
100
+ _op(:truncated_normal, shape, options)
100
101
  end
101
102
 
102
103
  ##
@@ -172,14 +173,14 @@ module TensorStream
172
173
  # initializer that generates tensors initialized to 0.
173
174
  #
174
175
  def zeros_initializer(dtype: :float32)
175
- TensorStream::Initializer.new(-> { _op(:zeros, nil, nil, data_type: dtype) })
176
+ TensorStream::Initializer.new(-> { _op(:zeros, data_type: dtype) })
176
177
  end
177
178
 
178
179
  ##
179
180
  # initializer that generates tensors initialized to 1.
180
181
  #
181
182
  def ones_initializer(dtype: :float32)
182
- TensorStream::Initializer.new(-> { _op(:ones, nil, nil, data_type: dtype) })
183
+ TensorStream::Initializer.new(-> { _op(:ones, data_type: dtype) })
183
184
  end
184
185
 
185
186
  ##
@@ -189,13 +190,13 @@ module TensorStream
189
190
  # where limit is sqrt(6 / (fan_in + fan_out)) where fan_in is the number
190
191
  # of input units in the weight tensor and fan_out is the number of output units in the weight tensor.
191
192
  def glorot_uniform_initializer(seed: nil, dtype: nil)
192
- TensorStream::Initializer.new(-> { _op(:glorot_uniform, nil, nil, seed: seed, data_type: dtype) })
193
+ TensorStream::Initializer.new(-> { _op(:glorot_uniform, seed: seed, data_type: dtype) })
193
194
  end
194
195
 
195
196
  ##
196
197
  # Initializer that generates tensors with a uniform distribution.
197
198
  def random_uniform_initializer(minval: 0, maxval: 1, seed: nil, dtype: nil)
198
- TensorStream::Initializer.new(-> { _op(:random_uniform, nil, nil, minval: 0, maxval: 1, seed: seed, data_type: dtype) })
199
+ TensorStream::Initializer.new(-> { _op(:random_uniform, minval: 0, maxval: 1, seed: seed, data_type: dtype) })
199
200
  end
200
201
 
201
202
  ##
@@ -276,7 +277,7 @@ module TensorStream
276
277
  ##
277
278
  # Computes the mean of elements across dimensions of a tensor.
278
279
  def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
279
- _op(:mean, input_tensor, axis, keepdims: keepdims, name: name)
280
+ reduce(:mean, input_tensor, axis, keepdims: keepdims, name: name)
280
281
  end
281
282
 
282
283
  ##
@@ -288,7 +289,7 @@ module TensorStream
288
289
  # If axis has no entries, all dimensions are reduced, and a tensor with a single element
289
290
  # is returned.
290
291
  def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
291
- _op(:sum, input_tensor, axis, keepdims: keepdims, name: name)
292
+ reduce(:sum, input_tensor, axis, keepdims: keepdims, name: name)
292
293
  end
293
294
 
294
295
  ##
@@ -300,7 +301,22 @@ module TensorStream
300
301
  #
301
302
  # If axis has no entries, all dimensions are reduced, and a tensor with a single element is returned.
302
303
  def reduce_prod(input, axis = nil, keepdims: false, name: nil)
303
- _op(:prod, input, axis, keepdims: keepdims, name: name)
304
+ reduce(:prod, input, axis, keepdims: keepdims, name: name)
305
+ end
306
+
307
+ def reduce(op, input, axis = nil, keepdims: false, name: nil)
308
+ input = TensorStream.convert_to_tensor(input)
309
+ axis = if !axis.nil?
310
+ axis
311
+ elsif input.shape.scalar?
312
+ op
313
+ elsif input.shape.known?
314
+ (0...input.shape.ndims).to_a
315
+ else
316
+ range(0, rank(input))
317
+ end
318
+
319
+ _op(op, input, axis, keepdims: keepdims, name: name)
304
320
  end
305
321
 
306
322
  ##
@@ -395,13 +411,13 @@ module TensorStream
395
411
  ##
396
412
  # Return true_fn() if the predicate pred is true else false_fn().
397
413
  def cond(pred, true_fn, false_fn, name: nil)
398
- _op(:cond, true_fn, false_fn, pred: pred, name: name)
414
+ _op(:case, [pred], false_fn, true_fn, name: name)
399
415
  end
400
416
 
401
417
  ##
402
418
  # Return the elements, either from x or y, depending on the condition.
403
419
  def where(condition, true_t = nil, false_t = nil, name: nil)
404
- _op(:where, true_t, false_t, pred: condition, name: name)
420
+ _op(:where, condition, true_t, false_t, name: name)
405
421
  end
406
422
 
407
423
  ##
@@ -486,6 +502,7 @@ module TensorStream
486
502
  def max(input_a, input_b, name: nil)
487
503
  check_allowed_types(input_a, NUMERIC_TYPES)
488
504
  check_allowed_types(input_b, NUMERIC_TYPES)
505
+
489
506
  input_a, input_b = check_data_types(input_a, input_b)
490
507
  _op(:max, input_a, input_b, name: name)
491
508
  end
@@ -652,6 +669,13 @@ module TensorStream
652
669
  _op(:tanh, input, name: name)
653
670
  end
654
671
 
672
+ ##
673
+ # Computes sec of input element-wise.
674
+ def sec(input, name: nil)
675
+ check_allowed_types(input, FLOATING_POINT_TYPES)
676
+ _op(:sec, input, name: name)
677
+ end
678
+
655
679
  ##
656
680
  # Computes sqrt of input element-wise.
657
681
  def sqrt(input, name: nil)
@@ -819,6 +843,34 @@ module TensorStream
819
843
  [result[0], result[1]]
820
844
  end
821
845
 
846
+ ##
847
+ # Create a case operation.
848
+ #
849
+ # The pred_fn_pairs parameter is a dict or list of pairs of size N.
850
+ # Each pair contains a boolean scalar tensor and a proc that creates the tensors to be returned if the boolean evaluates to true.
851
+ # default is a proc generating a list of tensors. All the proc in pred_fn_pairs as well as default (if provided) should return the
852
+ # same number and types of tensors.
853
+ #
854
+ def case(args = {})
855
+ args = args.dup
856
+ default = args.delete(:default)
857
+ exclusive = args.delete(:exclusive)
858
+ strict = args.delete(:strict)
859
+ name = args.delete(:name)
860
+
861
+ predicates = []
862
+ functions = []
863
+
864
+ args.each do |k, v|
865
+ raise "Invalid argment or option #{k}" unless k.is_a?(Tensor)
866
+
867
+ predicates << k
868
+ functions << (v.is_a?(Proc) ? v.call : v)
869
+ end
870
+
871
+ _op(:case, predicates, default, *functions, exclusive: exclusive, strict: strict, name: name)
872
+ end
873
+
822
874
  def cumprod(x, axis: 0, exclusive: false, reverse: false, name: nil)
823
875
  _op(:cumprod, x, axis: axis, exclusive: exclusive, reverse: reverse, name: name)
824
876
  end