tensor_stream 0.9.8 → 0.9.9

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +31 -14
  3. data/lib/tensor_stream.rb +4 -0
  4. data/lib/tensor_stream/constant.rb +41 -0
  5. data/lib/tensor_stream/control_flow.rb +2 -1
  6. data/lib/tensor_stream/dynamic_stitch.rb +3 -1
  7. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
  8. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
  13. data/lib/tensor_stream/graph.rb +61 -12
  14. data/lib/tensor_stream/graph_builder.rb +3 -3
  15. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
  16. data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
  17. data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
  18. data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
  19. data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
  20. data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
  21. data/lib/tensor_stream/helpers/op_helper.rb +17 -6
  22. data/lib/tensor_stream/helpers/string_helper.rb +32 -1
  23. data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
  24. data/lib/tensor_stream/math_gradients.rb +19 -12
  25. data/lib/tensor_stream/monkey_patches/float.rb +7 -0
  26. data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
  27. data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
  28. data/lib/tensor_stream/nn/nn_ops.rb +1 -1
  29. data/lib/tensor_stream/operation.rb +98 -36
  30. data/lib/tensor_stream/ops.rb +65 -13
  31. data/lib/tensor_stream/placeholder.rb +2 -2
  32. data/lib/tensor_stream/session.rb +15 -3
  33. data/lib/tensor_stream/tensor.rb +15 -172
  34. data/lib/tensor_stream/tensor_shape.rb +3 -1
  35. data/lib/tensor_stream/train/saver.rb +12 -10
  36. data/lib/tensor_stream/trainer.rb +7 -2
  37. data/lib/tensor_stream/utils.rb +13 -11
  38. data/lib/tensor_stream/utils/freezer.rb +37 -0
  39. data/lib/tensor_stream/variable.rb +17 -11
  40. data/lib/tensor_stream/variable_scope.rb +3 -1
  41. data/lib/tensor_stream/version.rb +1 -1
  42. data/samples/iris.rb +3 -4
  43. data/samples/linear_regression.rb +9 -5
  44. data/samples/logistic_regression.rb +11 -9
  45. data/samples/mnist_data.rb +8 -10
  46. metadata +8 -4
@@ -27,11 +27,11 @@ module TensorStream
27
27
  rank = dimension.size
28
28
  options[:value] = value
29
29
  options[:const] = true
30
- TensorStream::Tensor.new(options[:dtype] || options[:T], rank, dimension, options)
30
+ TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
31
31
  when 'VariableV2'
32
32
  # evaluate options
33
33
  shape = options[:shape]
34
- TensorStream::Variable.new(options[:dtype] || options[:T], nil, shape, nil, options)
34
+ i_var(options[:dtype] || options[:T], nil, shape, nil, options)
35
35
  when 'Placeholder'
36
36
  shape = options[:shape]
37
37
  TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
@@ -57,7 +57,7 @@ module TensorStream
57
57
  end
58
58
 
59
59
  options[:data_type] = options.delete(:T)
60
- TensorStream::Operation.new(op, *inputs, options)
60
+ Graph.get_default_graph.add_op!(op, *inputs, options)
61
61
  end
62
62
  end
63
63
 
@@ -0,0 +1,38 @@
1
+ module TensorStream
2
+ class YamlLoader
3
+ def initialize(graph = nil)
4
+ @graph = graph || TensorStream.get_default_graph
5
+ end
6
+
7
+ def load_from_string(buffer)
8
+ serialized_ops = YAML.safe_load(buffer, [Symbol])
9
+ serialized_ops.each do |op_def|
10
+ inputs = op_def[:inputs].map { |i| @graph.get_tensor_by_name(i) }
11
+ options = {}
12
+
13
+ if op_def.dig(:attrs, :container)
14
+ new_var = Variable.new(op_def.dig(:attrs, :data_type))
15
+ var_shape = op_def.dig(:attrs, :container, :shape)
16
+ var_options = op_def.dig(:attrs, :container, :options)
17
+ var_options[:name] = op_def[:name]
18
+
19
+ new_var.prepare(var_shape.size, var_shape, TensorStream.get_variable_scope, var_options)
20
+ options[:container] = new_var
21
+
22
+ @graph.add_variable(new_var, var_options)
23
+ end
24
+
25
+ new_op = Operation.new(@graph, inputs: inputs, options: op_def[:attrs].merge(options))
26
+ new_op.operation = op_def[:op].to_sym
27
+ new_op.name = op_def[:name]
28
+ new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
29
+ new_op.rank = new_op.shape.rank
30
+ new_op.data_type = new_op.set_data_type(op_def.dig(:attrs, :data_type))
31
+ new_op.is_const = new_op.infer_const
32
+ new_op.given_name = new_op.name
33
+
34
+ @graph.add_node(new_op)
35
+ end
36
+ end
37
+ end
38
+ end
@@ -22,6 +22,14 @@ module TensorStream
22
22
  value.pack('C*')
23
23
  when :boolean
24
24
  value.map { |v| v ? 1 : 0 }.pack('C*')
25
+ when :string
26
+ if value.is_a?(Array)
27
+ value.to_yaml
28
+ else
29
+ value
30
+ end
31
+ else
32
+ raise "unknown type #{data_type}"
25
33
  end
26
34
 
27
35
  byte_value
@@ -4,11 +4,19 @@ module TensorStream
4
4
  include TensorStream::StringHelper
5
5
  include TensorStream::OpHelper
6
6
 
7
- def get_string(tensor_or_graph, session = nil)
7
+ def get_string(tensor_or_graph, session = nil, graph_keys = nil)
8
8
  graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
9
9
  @lines = []
10
- graph.node_keys.each do |k|
11
- node = graph.get_tensor_by_name(k)
10
+
11
+ node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }
12
+
13
+ node_keys.each do |k|
14
+ node = if block_given?
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
+
12
20
  @lines << "node {"
13
21
  @lines << " name: #{node.name.to_json}"
14
22
  if node.is_a?(TensorStream::Operation)
@@ -20,16 +28,13 @@ module TensorStream
20
28
  end
21
29
  # type
22
30
  pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
23
- process_options(node)
24
- elsif node.is_a?(TensorStream::Tensor) && node.is_const
25
- @lines << " op: \"Const\""
26
- # type
27
- pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
28
- pb_attr('value', tensor_value(node))
29
- elsif node.is_a?(TensorStream::Variable)
30
- @lines << " op: \"VariableV2\""
31
- pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
32
- pb_attr('shape', shape_buf(node, 'shape'))
31
+
32
+ case node.operation.to_s
33
+ when 'const'
34
+ pb_attr('value', tensor_value(node))
35
+ when 'variable_v2'
36
+ pb_attr('shape', shape_buf(node, 'shape'))
37
+ end
33
38
  process_options(node)
34
39
  end
35
40
  @lines << "}"
@@ -44,23 +49,53 @@ module TensorStream
44
49
 
45
50
  def process_options(node)
46
51
  return if node.options.nil?
47
- node.options.each do |k, v|
48
- next if %w[name].include?(k.to_s) || k.to_s.start_with?('__')
52
+ node.options.reject { |_k, v| v.nil? }.each do |k, v|
53
+ next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?('__')
49
54
  @lines << " attr {"
50
55
  @lines << " key: \"#{k}\""
51
56
  @lines << " value {"
52
- if v.is_a?(TrueClass) || v.is_a?(FalseClass)
53
- @lines << " b: #{v}"
54
- elsif v.is_a?(Integer)
55
- @lines << " int_val: #{v}"
56
- elsif v.is_a?(Float)
57
- @lines << " float_val: #{v}"
58
- end
57
+ attr_value(v, 6)
59
58
  @lines << " }"
60
59
  @lines << " }"
61
60
  end
62
61
  end
63
62
 
63
+ def attr_value(val, indent = 0)
64
+ spaces = " " * indent
65
+ case val.class.to_s
66
+ when 'TrueClass', 'FalseClass'
67
+ @lines << "#{spaces}b: #{val}"
68
+ when 'Integer'
69
+ @lines << "#{spaces}i: #{val}"
70
+ when 'String',
71
+ @lines << "#{spaces}s: #{val}"
72
+ when 'Float'
73
+ @lines << "#{spaces}f: #{val}"
74
+ when 'Symbol'
75
+ @lines << "#{spaces}sym: #{val}"
76
+ when 'Array'
77
+ @lines << "#{spaces}list {"
78
+ val.each do |v_item|
79
+ attr_value(v_item, indent + 2)
80
+ end
81
+ @lines << "#{spaces}}"
82
+ when 'TensorStream::TensorShape'
83
+ @lines << "#{spaces}shape {"
84
+ if val.shape
85
+ val.shape.each do |dim|
86
+ @lines << "#{spaces} dim {"
87
+ @lines << "#{spaces} size: #{dim}"
88
+ @lines << "#{spaces} }"
89
+ end
90
+ end
91
+ @lines << "#{spaces}}"
92
+ when 'TensorStream::Variable'
93
+ else
94
+ binding.pry
95
+ raise "unknown type #{val.class}"
96
+ end
97
+ end
98
+
64
99
  def pack_arr_float(float_arr)
65
100
  float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
66
101
  end
@@ -92,17 +127,17 @@ module TensorStream
92
127
 
93
128
  if tensor.rank > 0
94
129
  if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
95
- packed = pack_arr_float(tensor.value)
130
+ packed = pack_arr_float(tensor.const_value)
96
131
  arr << " tensor_content: \"#{packed}\""
97
132
  elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
98
- packed = pack_arr_int(tensor.value)
133
+ packed = pack_arr_int(tensor.const_value)
99
134
  arr << " tensor_content: \"#{packed}\""
100
135
  elsif tensor.data_type == :string
101
- tensor.value.each do |v|
136
+ tensor.const_value.each do |v|
102
137
  arr << " string_val: #{v.to_json}"
103
138
  end
104
139
  else
105
- arr << " tensor_content: #{tensor.value.flatten}"
140
+ arr << " tensor_content: #{tensor.const_value.flatten}"
106
141
  end
107
142
  else
108
143
  val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
@@ -114,7 +149,7 @@ module TensorStream
114
149
  else
115
150
  "val"
116
151
  end
117
- arr << " #{val_type}: #{tensor.value.to_json}"
152
+ arr << " #{val_type}: #{tensor.const_value.to_json}"
118
153
  end
119
154
  arr << "}"
120
155
  arr
@@ -1,7 +1,7 @@
1
1
  module TensorStream
2
2
  class Serializer
3
- def serialize(filename, tensor, session = nil)
4
- File.write(filename, get_string(tensor, session))
3
+ def serialize(filename, tensor, session = nil, graph_keys = nil)
4
+ File.write(filename, get_string(tensor, session, graph_keys = nil))
5
5
  end
6
6
 
7
7
  def get_string(tensor, session = nil); end
@@ -0,0 +1,27 @@
1
+ module TensorStream
2
+ # Parses pbtext files and loads it as a graph
3
+ class Yaml < TensorStream::Serializer
4
+ include TensorStream::StringHelper
5
+ include TensorStream::OpHelper
6
+
7
+ def get_string(tensor_or_graph, session = nil, graph_keys = nil)
8
+ graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
9
+ serialized_arr = []
10
+
11
+ node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }
12
+
13
+ node_keys.each do |k|
14
+ node = if block_given?
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
+ next unless node.is_a?(Operation)
20
+
21
+ serialized_arr << node.to_h
22
+ end
23
+
24
+ serialized_arr.to_yaml
25
+ end
26
+ end
27
+ end
@@ -9,6 +9,12 @@ module TensorStream
9
9
 
10
10
  def self.infer_shape(tensor)
11
11
  case tensor.operation
12
+ when :case, :case_grad
13
+ tensor.inputs[2].shape.shape if tensor.inputs[2]
14
+ when :const
15
+ shape_eval(tensor.options[:value])
16
+ when :variable_v2
17
+ tensor.shape ? tensor.shape.shape : nil
12
18
  when :assign
13
19
  possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
14
20
  tensor.inputs[0].shape.shape
@@ -29,9 +35,9 @@ module TensorStream
29
35
  s
30
36
  when :arg_min, :argmax, :argmin
31
37
  return nil unless tensor.inputs[0].shape.known?
32
- return nil if tensor.inputs[1] && tensor.inputs[1].value.nil?
38
+ return nil if tensor.inputs[1] && tensor.inputs[1].const_value.nil?
33
39
 
34
- axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].value
40
+ axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value
35
41
  new_shape = tensor.inputs[0].shape.shape
36
42
  new_shape.each_with_index.collect do |shape, index|
37
43
  next nil if index == axis
@@ -61,7 +67,7 @@ module TensorStream
61
67
  item
62
68
  end.compact
63
69
  when :reshape
64
- new_shape = tensor.inputs[1] && tensor.inputs[1].value ? tensor.inputs[1].value : nil
70
+ new_shape = tensor.inputs[1] && tensor.inputs[1].const_value ? tensor.inputs[1].const_value : nil
65
71
  return nil if new_shape.nil?
66
72
  return nil if tensor.inputs[0].shape.nil?
67
73
 
@@ -83,11 +89,11 @@ module TensorStream
83
89
  tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
84
90
  when :pad
85
91
  return nil unless tensor.inputs[0].shape.known?
86
- return nil unless tensor.inputs[1].value
92
+ return nil unless tensor.inputs[1].const_value
87
93
 
88
94
  size = tensor.inputs[0].shape.shape.reduce(:*) || 1
89
95
  dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
90
- shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].value))
96
+ shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value))
91
97
  when :mat_mul
92
98
  return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
93
99
  return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
@@ -128,9 +134,9 @@ module TensorStream
128
134
  rotated_shape = Array.new(axis + 1) { new_shape.shift }
129
135
  rotated_shape.rotate! + new_shape
130
136
  when :concat
131
- return nil if tensor.inputs[0].value.nil?
137
+ return nil if tensor.inputs[0].const_value.nil?
132
138
 
133
- axis = tensor.inputs[0].value # get axis
139
+ axis = tensor.inputs[0].const_value # get axis
134
140
 
135
141
  axis_size = 0
136
142
 
@@ -196,9 +202,9 @@ module TensorStream
196
202
 
197
203
  new_shape
198
204
  when :conv2d_backprop_input
199
- return nil unless tensor.inputs[0].value
205
+ return nil unless tensor.inputs[0].const_value
200
206
 
201
- tensor.inputs[0].value
207
+ tensor.inputs[0].const_value
202
208
  else
203
209
  return nil if tensor.inputs[0].nil?
204
210
  return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
@@ -3,9 +3,11 @@ module TensorStream
3
3
  # module that contains helper functions useful for ops
4
4
  module OpHelper
5
5
  def _op(code, *args)
6
- op = Operation.new(code.to_sym, *args)
7
- if !TensorStream.get_default_graph.get_dependency_scope.nil?
8
- i_op(:identity, op, TensorStream.get_default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
6
+ default_graph = Graph.get_default_graph
7
+
8
+ op = default_graph.add_op!(code.to_sym, *args)
9
+ if !default_graph.get_dependency_scope.nil?
10
+ i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
9
11
  else
10
12
  op
11
13
  end
@@ -20,7 +22,15 @@ module TensorStream
20
22
  end
21
23
 
22
24
  args << options.merge(internal: true)
23
- Operation.new(code.to_sym, *args)
25
+ Graph.get_default_graph.add_op!(code.to_sym, *args)
26
+ end
27
+
28
+ def i_var(data_type, rank, shape, variable_scope, options = {})
29
+ new_var = Variable.new(data_type)
30
+ new_var.prepare(rank, shape, variable_scope, options)
31
+ new_var.op = new_var.graph.add_variable!(new_var, options.merge(shape: @shape, data_type: data_type))
32
+
33
+ new_var
24
34
  end
25
35
 
26
36
  def cons(value, options = {})
@@ -55,8 +65,8 @@ module TensorStream
55
65
  end
56
66
 
57
67
  def format_source(trace)
58
- grad_source = trace.select { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }.first
59
- # source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
68
+ grad_source = trace.detect { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }
69
+ source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
60
70
  [grad_source, trace].compact.join("\n")
61
71
  end
62
72
 
@@ -82,6 +92,7 @@ module TensorStream
82
92
  axes = TensorStream.range(0, input_rank) if axes.nil?
83
93
  axes = (axes + input_rank) % input_rank
84
94
  axes_shape = i_op(:shape, axes)
95
+
85
96
  TensorStream.dynamic_stitch([TensorStream.range(0, input_rank), axes],
86
97
  [input_shape, i_op(:fill, axes_shape, 1)])
87
98
  end
@@ -1,6 +1,6 @@
1
1
  module TensorStream
2
2
  # helper string methods usually found in ActiveSupport but
3
- # need to replicate here
3
+ # need to replicate here since we don't want to use ActiveSupport
4
4
  module StringHelper
5
5
  def camelize(string, uppercase_first_letter = true)
6
6
  string = if uppercase_first_letter
@@ -23,5 +23,36 @@ module TensorStream
23
23
  [k.to_sym, v]
24
24
  end.to_h
25
25
  end
26
+
27
+ def constantize(camel_cased_word)
28
+ names = camel_cased_word.split('::')
29
+
30
+ # Trigger a built-in NameError exception including the ill-formed constant in the message.
31
+ Object.const_get(camel_cased_word) if names.empty?
32
+
33
+ # Remove the first blank element in case of '::ClassName' notation.
34
+ names.shift if names.size > 1 && names.first.empty?
35
+
36
+ names.inject(Object) do |constant, name|
37
+ if constant == Object
38
+ constant.const_get(name)
39
+ else
40
+ candidate = constant.const_get(name)
41
+ next candidate if constant.const_defined?(name, false)
42
+ next candidate unless Object.const_defined?(name)
43
+
44
+ # Go down the ancestors to check if it is owned directly. The check
45
+ # stops when we reach Object or the end of ancestors tree.
46
+ constant = constant.ancestors.inject do |const, ancestor|
47
+ break const if ancestor == Object
48
+ break ancestor if ancestor.const_defined?(name, false)
49
+ const
50
+ end
51
+
52
+ # owner is in Object, so raise
53
+ constant.const_get(name, false)
54
+ end
55
+ end
56
+ end
26
57
  end
27
58
  end
@@ -0,0 +1,135 @@
1
+ module TensorStream
2
+ module TensorMixins
3
+ def +(other)
4
+ _a, other = TensorStream.check_data_types(self, other)
5
+ _op(:add, self, other)
6
+ end
7
+
8
+ def [](index)
9
+ _op(:index, self, index)
10
+ end
11
+
12
+ def *(other)
13
+ _a, other = TensorStream.check_data_types(self, other)
14
+ _op(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
15
+ end
16
+
17
+ def **(other)
18
+ _a, other = TensorStream.check_data_types(self, other)
19
+ _op(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
20
+ end
21
+
22
+ def /(other)
23
+ _a, other = TensorStream.check_data_types(self, other)
24
+ _op(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
25
+ end
26
+
27
+ def -(other)
28
+ _a, other = TensorStream.check_data_types(self, other)
29
+ _op(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
30
+ end
31
+
32
+ def -@
33
+ _op(:negate, self)
34
+ end
35
+
36
+ def %(other)
37
+ TensorStream.mod(self, other)
38
+ end
39
+
40
+ def floor(name: nil)
41
+ TensorStream.floor(self, name: name)
42
+ end
43
+
44
+ def ceil(name: nil)
45
+ TensorStream.ceil(self, name: name)
46
+ end
47
+
48
+ def round(name: nil)
49
+ TensorStream.round(self, name: name)
50
+ end
51
+
52
+ def log(name: nil)
53
+ TensorStream.log(self, name: name)
54
+ end
55
+
56
+ def reshape(shape, name: nil)
57
+ TensorStream.reshape(self, shape, name: name)
58
+ end
59
+
60
+ def zero?
61
+ _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
62
+ end
63
+
64
+ def ==(other)
65
+ _a, other = TensorStream.check_data_types(self, other)
66
+ _op(:equal, self, other)
67
+ end
68
+
69
+ def <(other)
70
+ _a, other = TensorStream.check_data_types(self, other)
71
+ _op(:less, self, other)
72
+ end
73
+
74
+ def !=(other)
75
+ _a, other = TensorStream.check_data_types(self, other)
76
+ _op(:not_equal, self, other)
77
+ end
78
+
79
+ def >(other)
80
+ _a, other = TensorStream.check_data_types(self, other)
81
+ _op(:greater, self, other)
82
+ end
83
+
84
+ def >=(other)
85
+ _a, other = TensorStream.check_data_types(self, other)
86
+ _op(:greater_equal, self, other)
87
+ end
88
+
89
+ def <=(other)
90
+ _a, other = TensorStream.check_data_types(self, other)
91
+ _op(:less_equal, self, other)
92
+ end
93
+
94
+ def and(other)
95
+ _a, other = TensorStream.check_data_types(self, other)
96
+ _op(:logical_and, self, other)
97
+ end
98
+
99
+ def matmul(other)
100
+ _a, other = TensorStream.check_data_types(self, other)
101
+ _op(:mat_mul, self, other)
102
+ end
103
+
104
+ def dot(other)
105
+ _a, other = TensorStream.check_data_types(self, other)
106
+ _op(:mat_mul, self, other)
107
+ end
108
+
109
+ def cast(data_type = :float32, name: nil)
110
+ TensorStream.cast(self, data_type, name: name)
111
+ end
112
+
113
+ def var(name: nil)
114
+ TensorStream.variable(self, name: name)
115
+ end
116
+
117
+ ##
118
+ # Apply a reduction to tensor
119
+ def reduce(op_type = :+, axis: nil, keepdims: false, name: nil)
120
+ reduce_op = case op_type.to_sym
121
+ when :+
122
+ :sum
123
+ when :*
124
+ :prod
125
+ when :mean
126
+ :mean
127
+ else
128
+ raise "unsupported reduce op type #{op_type} valid values are :+, :*, :prod, :mean"
129
+ end
130
+ raise "blocks are not supported for tensors" if block_given?
131
+
132
+ TensorStream.reduce(reduce_op, self, axis, keepdims: keepdims, name: name)
133
+ end
134
+ end
135
+ end