tensor_stream 0.9.8 → 0.9.9

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46)
  1. checksums.yaml +4 -4
  2. data/README.md +31 -14
  3. data/lib/tensor_stream.rb +4 -0
  4. data/lib/tensor_stream/constant.rb +41 -0
  5. data/lib/tensor_stream/control_flow.rb +2 -1
  6. data/lib/tensor_stream/dynamic_stitch.rb +3 -1
  7. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
  8. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
  13. data/lib/tensor_stream/graph.rb +61 -12
  14. data/lib/tensor_stream/graph_builder.rb +3 -3
  15. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
  16. data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
  17. data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
  18. data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
  19. data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
  20. data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
  21. data/lib/tensor_stream/helpers/op_helper.rb +17 -6
  22. data/lib/tensor_stream/helpers/string_helper.rb +32 -1
  23. data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
  24. data/lib/tensor_stream/math_gradients.rb +19 -12
  25. data/lib/tensor_stream/monkey_patches/float.rb +7 -0
  26. data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
  27. data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
  28. data/lib/tensor_stream/nn/nn_ops.rb +1 -1
  29. data/lib/tensor_stream/operation.rb +98 -36
  30. data/lib/tensor_stream/ops.rb +65 -13
  31. data/lib/tensor_stream/placeholder.rb +2 -2
  32. data/lib/tensor_stream/session.rb +15 -3
  33. data/lib/tensor_stream/tensor.rb +15 -172
  34. data/lib/tensor_stream/tensor_shape.rb +3 -1
  35. data/lib/tensor_stream/train/saver.rb +12 -10
  36. data/lib/tensor_stream/trainer.rb +7 -2
  37. data/lib/tensor_stream/utils.rb +13 -11
  38. data/lib/tensor_stream/utils/freezer.rb +37 -0
  39. data/lib/tensor_stream/variable.rb +17 -11
  40. data/lib/tensor_stream/variable_scope.rb +3 -1
  41. data/lib/tensor_stream/version.rb +1 -1
  42. data/samples/iris.rb +3 -4
  43. data/samples/linear_regression.rb +9 -5
  44. data/samples/logistic_regression.rb +11 -9
  45. data/samples/mnist_data.rb +8 -10
  46. metadata +8 -4
@@ -40,7 +40,10 @@ module TensorStream
40
40
  return nil if grads.empty?
41
41
  grads.size > 1 ? ts.add_n(grads) : grads[0]
42
42
  else
43
- return nil if computed_op.nil?
43
+
44
+ if computed_op.nil?
45
+ return nil
46
+ end
44
47
  _propagate(computed_op, tensor.inputs[0], stop_tensor, nodes_to_compute, stop_gradients)
45
48
  end
46
49
  end
@@ -50,6 +53,7 @@ module TensorStream
50
53
  node.graph.name_scope("#{node.name}_grad") do
51
54
  x = node.inputs[0] if node.inputs[0]
52
55
  y = node.inputs[1] if node.inputs[1]
56
+ z = node.inputs[2] if node.inputs[2]
53
57
 
54
58
  case node.operation
55
59
  when :add_n
@@ -135,7 +139,6 @@ module TensorStream
135
139
  reduced = ts.cast(reduction_indices, :int32)
136
140
  idx = ts.range(0, rank)
137
141
  other, = ts.setdiff1d(idx, reduced)
138
-
139
142
  [ts.concat([reduced, other], 0),
140
143
  ts.reduce_prod(ts.gather(input_shape, reduced)),
141
144
  ts.reduce_prod(ts.gather(input_shape, other))]
@@ -239,13 +242,17 @@ module TensorStream
239
242
  y = ts.constant(2.0, dtype: x.dtype)
240
243
  ts.multiply(grad, ts.multiply(x, y))
241
244
  when :where
242
- x_mask = i_op(:where, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
243
- y_mask = i_op(:where, i_op(:zeros_like, x), i_op(:ones_like, y), pred: node.options[:pred])
244
- [x_mask * grad, y_mask * grad]
245
- when :cond
246
- x_cond = i_op(:cond, i_op(:ones_like, x), i_op(:zeros_like, y), pred: node.options[:pred])
247
- y_cond = i_op(:cond, i_op(:zeros_like, x), i_op(:ones_like, x), pred: node.options[:pred])
248
- [x_cond * grad, y_cond * grad]
245
+ x_mask = i_op(:where, x, i_op(:ones_like, y), i_op(:zeros_like, z))
246
+ y_mask = i_op(:where, x, i_op(:zeros_like, y), i_op(:ones_like, z))
247
+ [nil, x_mask * grad, y_mask * grad]
248
+ when :case
249
+ n_preds = node.inputs.size - 2
250
+
251
+ case_grads = Array.new(n_preds) do |index|
252
+ i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
253
+ end
254
+
255
+ [nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
249
256
  when :mean
250
257
  sum_grad = _sum_grad(x, y, grad)[0]
251
258
  input_shape = ts.shape(x)
@@ -283,12 +290,12 @@ module TensorStream
283
290
  # hack!! not sure how to fix this yet
284
291
  return grad if %i[softmax_cross_entropy_with_logits_v2 sparse_softmax_cross_entropy_with_logits].include?(node.inputs[0].operation)
285
292
 
286
- if node.inputs[0].shape.known? && node.inputs[1].value
293
+ if node.inputs[0].shape.known? && node.inputs[1].const_value
287
294
  multiplier = node.inputs[0].shape.shape[0]
288
295
  filler = ts.zeros_like(grad)
289
296
 
290
297
  res = Array.new(multiplier) do |index|
291
- index == node.inputs[1].value ? grad : filler
298
+ index == node.inputs[1].const_value ? grad : filler
292
299
  end
293
300
  [res]
294
301
  end
@@ -346,7 +353,7 @@ module TensorStream
346
353
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
347
354
  new_grad = _op(:reshape, grad, output_shape_kept_dims)
348
355
 
349
- grad = _op(:cond, _op(:fill, input_shape, grad), _op(:tile, new_grad, tile_scaling), pred: _op(:rank, grad).zero?)
356
+ grad = _op(:case, [_op(:rank, grad).zero?], _op(:tile, new_grad, tile_scaling), _op(:fill, input_shape, grad))
350
357
 
351
358
  [grad, nil]
352
359
  end
@@ -1,3 +1,10 @@
1
1
  class Float
2
2
  include TensorStream::MonkeyPatch
3
+
4
+ def self.placeholder(name: nil, width: 32, shape: nil)
5
+ raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
6
+
7
+ data_type = :"float#{width}"
8
+ TensorStream.placeholder(data_type, name: name, shape: shape)
9
+ end
3
10
  end
@@ -1,3 +1,10 @@
1
1
  class Integer
2
2
  include TensorStream::MonkeyPatch
3
+
4
+ def self.placeholder(name: nil, width: 32, shape: nil)
5
+ raise "invalid width passed #{width}" unless [16, 32, 64].include?(width)
6
+
7
+ data_type = :"int#{width}"
8
+ TensorStream.placeholder(data_type, name: name, shape: shape)
9
+ end
3
10
  end
@@ -19,13 +19,13 @@ module TensorStream
19
19
  TensorStream.shape_eval(self)
20
20
  end
21
21
 
22
- def t(name = nil)
23
- TensorStream.convert_to_tensor(self, name: name)
22
+ def t(name = nil, dtype: nil)
23
+ TensorStream.convert_to_tensor(self, name: name, dtype: dtype)
24
24
  end
25
25
 
26
26
  def +(other)
27
27
  if other.is_a?(TensorStream::Tensor)
28
- TensorStream.convert_to_tensor(self) + other
28
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) + other
29
29
  else
30
30
  _tensor_stream_add_orig(other)
31
31
  end
@@ -33,7 +33,7 @@ module TensorStream
33
33
 
34
34
  def -(other)
35
35
  if other.is_a?(TensorStream::Tensor)
36
- TensorStream.convert_to_tensor(self) - other
36
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) - other
37
37
  else
38
38
  _tensor_stream_sub_orig(other)
39
39
  end
@@ -41,7 +41,7 @@ module TensorStream
41
41
 
42
42
  def *(other)
43
43
  if other.is_a?(TensorStream::Tensor)
44
- TensorStream.convert_to_tensor(self) * other
44
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
45
45
  else
46
46
  _tensor_stream_mul_orig(other)
47
47
  end
@@ -49,7 +49,7 @@ module TensorStream
49
49
 
50
50
  def /(other)
51
51
  if other.is_a?(TensorStream::Tensor)
52
- TensorStream.convert_to_tensor(self) * other
52
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
53
53
  else
54
54
  _tensor_stream_div_orig(other)
55
55
  end
@@ -57,7 +57,7 @@ module TensorStream
57
57
 
58
58
  def %(other)
59
59
  if other.is_a?(TensorStream::Tensor)
60
- TensorStream.convert_to_tensor(self) % other
60
+ TensorStream.convert_to_tensor(self, dtype: other.data_type) % other
61
61
  else
62
62
  _tensor_stream_mod_orig(other)
63
63
  end
@@ -65,7 +65,7 @@ module TensorStream
65
65
 
66
66
  def **(other)
67
67
  if other.is_a?(TensorStream::Tensor)
68
- TensorStream.convert_to_tensor(self)**other
68
+ TensorStream.convert_to_tensor(self, dtype: other.data_type)**other
69
69
  else
70
70
  _tensor_stream_pow_orig(other)
71
71
  end
@@ -5,7 +5,7 @@ module TensorStream
5
5
 
6
6
  class << self
7
7
  def softmax(logits, axis: nil, name: nil)
8
- _op(:softmax, logits, nil, axis: axis, name: name)
8
+ _op(:softmax, logits, axis: axis, name: name)
9
9
  end
10
10
 
11
11
  def relu(features, name: nil)
@@ -2,33 +2,18 @@ require 'tensor_stream/helpers/infer_shape'
2
2
  module TensorStream
3
3
  # TensorStream class that defines an operation
4
4
  class Operation < Tensor
5
- attr_accessor :name, :operation, :inputs, :rank, :options
6
- attr_reader :outputs
5
+ include OpHelper
7
6
 
8
- def initialize(operation, *args)
9
- options = if args.last.is_a?(Hash)
10
- args.pop
11
- else
12
- {}
13
- end
14
-
15
- inputs = args
16
-
17
- setup_initial_state(options)
18
-
19
- @operation = operation
20
- @rank = options[:rank] || 0
21
- @name = [@graph.get_name_scope, options[:name] || set_name].compact.reject(&:empty?).join('/')
22
- @internal = options[:internal]
23
- @given_name = @name
7
+ attr_accessor :name, :operation, :inputs, :rank, :device, :consumers, :breakpoint
8
+ attr_reader :outputs, :options, :is_const, :data_type, :shape
24
9
 
10
+ def initialize(graph, inputs:, options:)
11
+ @consumers = Set.new
12
+ @outputs = []
13
+ @op = self
14
+ @graph = graph
15
+ @inputs = inputs
25
16
  @options = options
26
-
27
- @inputs = inputs.map { |i| options[:preserve_params_type] ? i : TensorStream.convert_to_tensor(i) }
28
- @data_type = set_data_type(options[:data_type])
29
- @is_const = infer_const
30
- @shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
31
- @graph.add_node(self)
32
17
  end
33
18
 
34
19
  def to_s
@@ -37,25 +22,74 @@ module TensorStream
37
22
 
38
23
  def to_h
39
24
  {
40
- op: operation,
41
- name: name,
42
- operands: hashify_tensor(inputs)
25
+ op: operation.to_s,
26
+ name: name.to_s,
27
+ data_type: @data_type,
28
+ inputs: @inputs.map(&:name),
29
+ attrs: serialize_options
43
30
  }
44
31
  end
45
32
 
33
+ def const_value
34
+ @options ? @options[:value] : nil
35
+ end
36
+
37
+ def container_buffer
38
+ @options[:container] ? @options[:container].buffer : nil
39
+ end
40
+
41
+ def container
42
+ @options[:container].read_value
43
+ end
44
+
45
+ def container=(value)
46
+ @options[:container].value = value
47
+ end
48
+
49
+ def set_input(index, value)
50
+ @inputs[index] = value
51
+ @shape = TensorShape.new(TensorStream::InferShape.infer_shape(self))
52
+ @rank = @shape.rank
53
+ @is_const = infer_const
54
+ @data_type = set_data_type(@options[:data_type])
55
+ end
56
+
46
57
  def infer_const
47
58
  return false if breakpoint
59
+
48
60
  case operation
49
61
  when :random_standard_normal, :random_uniform, :truncated_normal, :glorot_uniform, :print, :check_numerics
50
62
  false
63
+ when :const
64
+ true
65
+ when :placeholder
66
+ false
67
+ when :variable_v2
68
+ false
51
69
  else
52
70
  non_const = @inputs.compact.find { |input| !input.is_const }
53
71
  non_const ? false : true
54
72
  end
55
73
  end
56
74
 
75
+ def set_name
76
+ @operation.to_s
77
+ end
78
+
57
79
  def set_data_type(passed_data_type)
58
80
  case operation
81
+ when :where
82
+ @inputs[1].data_type
83
+ when :case
84
+ if @inputs[2]
85
+ @inputs[2].data_type
86
+ else
87
+ @inputs[1].data_type
88
+ end
89
+ when :case_grad
90
+ @inputs[2].data_type
91
+ when :placeholder, :variable_v2, :const
92
+ options[:data_type]
59
93
  when :fill
60
94
  @inputs[1].data_type
61
95
  when :greater, :less, :equal, :not_equal, :greater_equal, :less_equal, :logical_and
@@ -70,9 +104,8 @@ module TensorStream
70
104
  @inputs[1].data_type
71
105
  when :index
72
106
  if @inputs[0].is_a?(ControlFlow)
73
-
74
107
  if @inputs[1].is_const
75
- @inputs[0].inputs[@inputs[1].value].data_type
108
+ @inputs[0].inputs[@inputs[1].const_value].data_type
76
109
  else
77
110
  :unknown
78
111
  end
@@ -163,8 +196,6 @@ module TensorStream
163
196
  "reshape(#{sub_input},#{sub_input2})"
164
197
  when :rank
165
198
  "#{sub_input}.rank"
166
- when :cond
167
- "(#{auto_math(options[:pred], name_only, max_depth - 1, cur_depth)} ? #{sub_input} : #{sub_input2})"
168
199
  when :less
169
200
  "#{sub_input} < #{sub_input2}"
170
201
  when :less_equal
@@ -222,12 +253,41 @@ module TensorStream
222
253
 
223
254
  private
224
255
 
256
+ def serialize_options
257
+ excludes = %i[internal_name source]
258
+
259
+ @options.reject { |k, v| excludes.include?(k) || v.nil? }.map do |k, v|
260
+ v = case v.class.to_s
261
+ when 'TensorStream::TensorShape'
262
+ v.shape
263
+ when 'Array'
264
+ v
265
+ when 'String', 'Integer', 'Float', 'Symbol', 'FalseClass', "TrueClass"
266
+ v
267
+ when 'TensorStream::Variable'
268
+ { name: v.name, options: v.options, shape: v.shape.shape.dup }
269
+ else
270
+ raise "unknown type #{v.class}"
271
+ end
272
+ [k.to_sym, v]
273
+ end.to_h
274
+ end
275
+
276
+ def add_consumer(consumer)
277
+ @consumers << consumer.name if consumer.name != name
278
+ end
279
+
280
+ def setup_output(consumer)
281
+ @outputs << consumer.name unless @outputs.include?(consumer.name)
282
+ end
283
+
225
284
  def propagate_consumer(consumer)
226
- super
285
+ add_consumer(consumer)
227
286
  @inputs.compact.each do |input|
228
287
  if input.is_a?(Array)
229
- input.flatten.compact.select { |t| t.is_a?(Tensor) }.each do |t|
288
+ input.flatten.compact.map(&:op).select { |t| t.is_a?(Tensor) }.each do |t|
230
289
  next if t.consumers.include?(consumer.name)
290
+
231
291
  t.send(:propagate_consumer, consumer)
232
292
  end
233
293
  elsif input.name != name && !input.consumers.include?(consumer.name)
@@ -239,7 +299,7 @@ module TensorStream
239
299
  def propagate_outputs
240
300
  @inputs.compact.each do |input|
241
301
  if input.is_a?(Array)
242
- input.flatten.compact.each do |t|
302
+ input.flatten.compact.map(&:op).each do |t|
243
303
  t.send(:setup_output, self) if t.is_a?(Tensor)
244
304
  end
245
305
  elsif input.is_a?(Tensor) && (input.name != name)
@@ -248,8 +308,10 @@ module TensorStream
248
308
  end
249
309
  end
250
310
 
251
- def set_name
252
- "#{@operation}#{graph.get_operation_counter}:#{@rank}"
311
+ def setup_initial_state(options)
312
+ @outputs = []
313
+ @graph = options[:__graph] || TensorStream.get_default_graph
314
+ @source = format_source(caller_locations)
253
315
  end
254
316
  end
255
317
  end
@@ -58,7 +58,8 @@ module TensorStream
58
58
  # +wrt_xs+ : A Tensor or list of tensors to be used for differentiation.
59
59
  # +stop_gradients+ : Optional. A Tensor or list of tensors not to differentiate through
60
60
  def gradients(tensor_ys, wrt_xs, name: 'gradients', stop_gradients: nil)
61
- gs = wrt_xs.collect do |x|
61
+ tensor_ys = tensor_ys.op
62
+ gs = wrt_xs.map(&:op).collect do |x|
62
63
  stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
63
64
  gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym
64
65
  tensor_graph = tensor_ys.graph
@@ -82,21 +83,21 @@ module TensorStream
82
83
  # Outputs random values from a uniform distribution.
83
84
  def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
84
85
  options = { dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
85
- _op(:random_uniform, shape, nil, options)
86
+ _op(:random_uniform, shape, options)
86
87
  end
87
88
 
88
89
  ##
89
90
  # Outputs random values from a normal distribution.
90
91
  def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
91
92
  options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
92
- _op(:random_standard_normal, shape, nil, options)
93
+ _op(:random_standard_normal, shape, options)
93
94
  end
94
95
 
95
96
  ##
96
97
  # Outputs random values from a truncated normal distribution.
97
98
  def truncated_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
98
99
  options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
99
- _op(:truncated_normal, shape, nil, options)
100
+ _op(:truncated_normal, shape, options)
100
101
  end
101
102
 
102
103
  ##
@@ -172,14 +173,14 @@ module TensorStream
172
173
  # initializer that generates tensors initialized to 0.
173
174
  #
174
175
  def zeros_initializer(dtype: :float32)
175
- TensorStream::Initializer.new(-> { _op(:zeros, nil, nil, data_type: dtype) })
176
+ TensorStream::Initializer.new(-> { _op(:zeros, data_type: dtype) })
176
177
  end
177
178
 
178
179
  ##
179
180
  # initializer that generates tensors initialized to 1.
180
181
  #
181
182
  def ones_initializer(dtype: :float32)
182
- TensorStream::Initializer.new(-> { _op(:ones, nil, nil, data_type: dtype) })
183
+ TensorStream::Initializer.new(-> { _op(:ones, data_type: dtype) })
183
184
  end
184
185
 
185
186
  ##
@@ -189,13 +190,13 @@ module TensorStream
189
190
  # where limit is sqrt(6 / (fan_in + fan_out)) where fan_in is the number
190
191
  # of input units in the weight tensor and fan_out is the number of output units in the weight tensor.
191
192
  def glorot_uniform_initializer(seed: nil, dtype: nil)
192
- TensorStream::Initializer.new(-> { _op(:glorot_uniform, nil, nil, seed: seed, data_type: dtype) })
193
+ TensorStream::Initializer.new(-> { _op(:glorot_uniform, seed: seed, data_type: dtype) })
193
194
  end
194
195
 
195
196
  ##
196
197
  # Initializer that generates tensors with a uniform distribution.
197
198
  def random_uniform_initializer(minval: 0, maxval: 1, seed: nil, dtype: nil)
198
- TensorStream::Initializer.new(-> { _op(:random_uniform, nil, nil, minval: 0, maxval: 1, seed: seed, data_type: dtype) })
199
+ TensorStream::Initializer.new(-> { _op(:random_uniform, minval: 0, maxval: 1, seed: seed, data_type: dtype) })
199
200
  end
200
201
 
201
202
  ##
@@ -276,7 +277,7 @@ module TensorStream
276
277
  ##
277
278
  # Computes the mean of elements across dimensions of a tensor.
278
279
  def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
279
- _op(:mean, input_tensor, axis, keepdims: keepdims, name: name)
280
+ reduce(:mean, input_tensor, axis, keepdims: keepdims, name: name)
280
281
  end
281
282
 
282
283
  ##
@@ -288,7 +289,7 @@ module TensorStream
288
289
  # If axis has no entries, all dimensions are reduced, and a tensor with a single element
289
290
  # is returned.
290
291
  def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
291
- _op(:sum, input_tensor, axis, keepdims: keepdims, name: name)
292
+ reduce(:sum, input_tensor, axis, keepdims: keepdims, name: name)
292
293
  end
293
294
 
294
295
  ##
@@ -300,7 +301,22 @@ module TensorStream
300
301
  #
301
302
  # If axis has no entries, all dimensions are reduced, and a tensor with a single element is returned.
302
303
  def reduce_prod(input, axis = nil, keepdims: false, name: nil)
303
- _op(:prod, input, axis, keepdims: keepdims, name: name)
304
+ reduce(:prod, input, axis, keepdims: keepdims, name: name)
305
+ end
306
+
307
+ def reduce(op, input, axis = nil, keepdims: false, name: nil)
308
+ input = TensorStream.convert_to_tensor(input)
309
+ axis = if !axis.nil?
310
+ axis
311
+ elsif input.shape.scalar?
312
+ op
313
+ elsif input.shape.known?
314
+ (0...input.shape.ndims).to_a
315
+ else
316
+ range(0, rank(input))
317
+ end
318
+
319
+ _op(op, input, axis, keepdims: keepdims, name: name)
304
320
  end
305
321
 
306
322
  ##
@@ -395,13 +411,13 @@ module TensorStream
395
411
  ##
396
412
  # Return true_fn() if the predicate pred is true else false_fn().
397
413
  def cond(pred, true_fn, false_fn, name: nil)
398
- _op(:cond, true_fn, false_fn, pred: pred, name: name)
414
+ _op(:case, [pred], false_fn, true_fn, name: name)
399
415
  end
400
416
 
401
417
  ##
402
418
  # Return the elements, either from x or y, depending on the condition.
403
419
  def where(condition, true_t = nil, false_t = nil, name: nil)
404
- _op(:where, true_t, false_t, pred: condition, name: name)
420
+ _op(:where, condition, true_t, false_t, name: name)
405
421
  end
406
422
 
407
423
  ##
@@ -486,6 +502,7 @@ module TensorStream
486
502
  def max(input_a, input_b, name: nil)
487
503
  check_allowed_types(input_a, NUMERIC_TYPES)
488
504
  check_allowed_types(input_b, NUMERIC_TYPES)
505
+
489
506
  input_a, input_b = check_data_types(input_a, input_b)
490
507
  _op(:max, input_a, input_b, name: name)
491
508
  end
@@ -652,6 +669,13 @@ module TensorStream
652
669
  _op(:tanh, input, name: name)
653
670
  end
654
671
 
672
+ ##
673
+ # Computes sec of input element-wise.
674
+ def sec(input, name: nil)
675
+ check_allowed_types(input, FLOATING_POINT_TYPES)
676
+ _op(:sec, input, name: name)
677
+ end
678
+
655
679
  ##
656
680
  # Computes sqrt of input element-wise.
657
681
  def sqrt(input, name: nil)
@@ -819,6 +843,34 @@ module TensorStream
819
843
  [result[0], result[1]]
820
844
  end
821
845
 
846
+ ##
847
+ # Create a case operation.
848
+ #
849
+ # The pred_fn_pairs parameter is a dict or list of pairs of size N.
850
+ # Each pair contains a boolean scalar tensor and a proc that creates the tensors to be returned if the boolean evaluates to true.
851
+ # default is a proc generating a list of tensors. All the proc in pred_fn_pairs as well as default (if provided) should return the
852
+ # same number and types of tensors.
853
+ #
854
+ def case(args = {})
855
+ args = args.dup
856
+ default = args.delete(:default)
857
+ exclusive = args.delete(:exclusive)
858
+ strict = args.delete(:strict)
859
+ name = args.delete(:name)
860
+
861
+ predicates = []
862
+ functions = []
863
+
864
+ args.each do |k, v|
865
+ raise "Invalid argment or option #{k}" unless k.is_a?(Tensor)
866
+
867
+ predicates << k
868
+ functions << (v.is_a?(Proc) ? v.call : v)
869
+ end
870
+
871
+ _op(:case, predicates, default, *functions, exclusive: exclusive, strict: strict, name: name)
872
+ end
873
+
822
874
  def cumprod(x, axis: 0, exclusive: false, reverse: false, name: nil)
823
875
  _op(:cumprod, x, axis: axis, exclusive: exclusive, reverse: reverse, name: name)
824
876
  end