tensor_stream 0.9.8 → 0.9.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. checksums.yaml +4 -4
  2. data/README.md +31 -14
  3. data/lib/tensor_stream.rb +4 -0
  4. data/lib/tensor_stream/constant.rb +41 -0
  5. data/lib/tensor_stream/control_flow.rb +2 -1
  6. data/lib/tensor_stream/dynamic_stitch.rb +3 -1
  7. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
  8. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
  13. data/lib/tensor_stream/graph.rb +61 -12
  14. data/lib/tensor_stream/graph_builder.rb +3 -3
  15. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
  16. data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
  17. data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
  18. data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
  19. data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
  20. data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
  21. data/lib/tensor_stream/helpers/op_helper.rb +17 -6
  22. data/lib/tensor_stream/helpers/string_helper.rb +32 -1
  23. data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
  24. data/lib/tensor_stream/math_gradients.rb +19 -12
  25. data/lib/tensor_stream/monkey_patches/float.rb +7 -0
  26. data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
  27. data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
  28. data/lib/tensor_stream/nn/nn_ops.rb +1 -1
  29. data/lib/tensor_stream/operation.rb +98 -36
  30. data/lib/tensor_stream/ops.rb +65 -13
  31. data/lib/tensor_stream/placeholder.rb +2 -2
  32. data/lib/tensor_stream/session.rb +15 -3
  33. data/lib/tensor_stream/tensor.rb +15 -172
  34. data/lib/tensor_stream/tensor_shape.rb +3 -1
  35. data/lib/tensor_stream/train/saver.rb +12 -10
  36. data/lib/tensor_stream/trainer.rb +7 -2
  37. data/lib/tensor_stream/utils.rb +13 -11
  38. data/lib/tensor_stream/utils/freezer.rb +37 -0
  39. data/lib/tensor_stream/variable.rb +17 -11
  40. data/lib/tensor_stream/variable_scope.rb +3 -1
  41. data/lib/tensor_stream/version.rb +1 -1
  42. data/samples/iris.rb +3 -4
  43. data/samples/linear_regression.rb +9 -5
  44. data/samples/logistic_regression.rb +11 -9
  45. data/samples/mnist_data.rb +8 -10
  46. metadata +8 -4
@@ -27,11 +27,11 @@ module TensorStream
27
27
  rank = dimension.size
28
28
  options[:value] = value
29
29
  options[:const] = true
30
- TensorStream::Tensor.new(options[:dtype] || options[:T], rank, dimension, options)
30
+ TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
31
31
  when 'VariableV2'
32
32
  # evaluate options
33
33
  shape = options[:shape]
34
- TensorStream::Variable.new(options[:dtype] || options[:T], nil, shape, nil, options)
34
+ i_var(options[:dtype] || options[:T], nil, shape, nil, options)
35
35
  when 'Placeholder'
36
36
  shape = options[:shape]
37
37
  TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
@@ -57,7 +57,7 @@ module TensorStream
57
57
  end
58
58
 
59
59
  options[:data_type] = options.delete(:T)
60
- TensorStream::Operation.new(op, *inputs, options)
60
+ Graph.get_default_graph.add_op!(op, *inputs, options)
61
61
  end
62
62
  end
63
63
 
@@ -0,0 +1,38 @@
1
module TensorStream
  # Rebuilds graph nodes from YAML produced by the Yaml serializer: a list of
  # per-operation hashes with symbol keys (:op, :name, :inputs, :attrs).
  class YamlLoader
    # @param graph [TensorStream::Graph, nil] graph to populate; when nil the
    #   current default graph is used
    def initialize(graph = nil)
      @graph = graph || TensorStream.get_default_graph
    end

    # Deserializes +buffer+ and adds one node per serialized op to the graph.
    # Ops whose attrs carry a :container also get a Variable created and
    # registered with the graph before the op itself is built.
    #
    # Inputs are resolved by name via get_tensor_by_name, so presumably an
    # op's inputs must already be in the graph (TODO confirm the serializer
    # emits ops in dependency order).
    #
    # @param buffer [String] YAML text
    # @return [Array<Hash>] the deserialized op definitions ([] for an empty buffer)
    def load_from_string(buffer)
      serialized_ops = safe_load_ops(buffer) || []
      serialized_ops.each do |op_def|
        attrs = op_def[:attrs] || {}
        inputs = (op_def[:inputs] || []).map { |i| @graph.get_tensor_by_name(i) }
        options = {}
        options[:container] = load_variable(op_def[:name], attrs) if attrs[:container]

        new_op = Operation.new(@graph, inputs: inputs, options: attrs.merge(options))
        new_op.operation = op_def[:op].to_sym
        new_op.name = op_def[:name]
        new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
        new_op.rank = new_op.shape.rank
        new_op.data_type = new_op.set_data_type(attrs[:data_type])
        new_op.is_const = new_op.infer_const
        new_op.given_name = new_op.name

        @graph.add_node(new_op)
      end
    end

    private

    # Psych 3.1 introduced the permitted_classes: keyword and Psych 4 removed
    # the positional form, so use whichever the installed Psych understands.
    # Symbol must be permitted because serialized keys/values are symbols.
    def safe_load_ops(buffer)
      if YAML.method(:safe_load).parameters.include?([:key, :permitted_classes])
        YAML.safe_load(buffer, permitted_classes: [Symbol])
      else
        YAML.safe_load(buffer, [Symbol])
      end
    end

    # Creates the Variable described by attrs[:container], registers it with
    # the graph under +name+, and returns it.
    def load_variable(name, attrs)
      container = attrs[:container]
      var_shape = container[:shape]
      # copy instead of mutating the deserialized hash; tolerate missing options
      var_options = (container[:options] || {}).merge(name: name)

      new_var = Variable.new(attrs[:data_type])
      # an unknown shape (nil) yields a nil rank instead of crashing on nil.size
      new_var.prepare(var_shape&.size, var_shape, TensorStream.get_variable_scope, var_options)
      @graph.add_variable(new_var, var_options)
      new_var
    end
  end
end
@@ -22,6 +22,14 @@ module TensorStream
22
22
  value.pack('C*')
23
23
  when :boolean
24
24
  value.map { |v| v ? 1 : 0 }.pack('C*')
25
+ when :string
26
+ if value.is_a?(Array)
27
+ value.to_yaml
28
+ else
29
+ value
30
+ end
31
+ else
32
+ raise "unknown type #{data_type}"
25
33
  end
26
34
 
27
35
  byte_value
@@ -4,11 +4,19 @@ module TensorStream
4
4
  include TensorStream::StringHelper
5
5
  include TensorStream::OpHelper
6
6
 
7
- def get_string(tensor_or_graph, session = nil)
7
+ def get_string(tensor_or_graph, session = nil, graph_keys = nil)
8
8
  graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
9
9
  @lines = []
10
- graph.node_keys.each do |k|
11
- node = graph.get_tensor_by_name(k)
10
+
11
+ node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }
12
+
13
+ node_keys.each do |k|
14
+ node = if block_given?
15
+ yield graph, k
16
+ else
17
+ graph.get_tensor_by_name(k)
18
+ end
19
+
12
20
  @lines << "node {"
13
21
  @lines << " name: #{node.name.to_json}"
14
22
  if node.is_a?(TensorStream::Operation)
@@ -20,16 +28,13 @@ module TensorStream
20
28
  end
21
29
  # type
22
30
  pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
23
- process_options(node)
24
- elsif node.is_a?(TensorStream::Tensor) && node.is_const
25
- @lines << " op: \"Const\""
26
- # type
27
- pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
28
- pb_attr('value', tensor_value(node))
29
- elsif node.is_a?(TensorStream::Variable)
30
- @lines << " op: \"VariableV2\""
31
- pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
32
- pb_attr('shape', shape_buf(node, 'shape'))
31
+
32
+ case node.operation.to_s
33
+ when 'const'
34
+ pb_attr('value', tensor_value(node))
35
+ when 'variable_v2'
36
+ pb_attr('shape', shape_buf(node, 'shape'))
37
+ end
33
38
  process_options(node)
34
39
  end
35
40
  @lines << "}"
@@ -44,23 +49,53 @@ module TensorStream
44
49
 
45
50
  def process_options(node)
46
51
  return if node.options.nil?
47
- node.options.each do |k, v|
48
- next if %w[name].include?(k.to_s) || k.to_s.start_with?('__')
52
+ node.options.reject { |_k, v| v.nil? }.each do |k, v|
53
+ next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?('__')
49
54
  @lines << " attr {"
50
55
  @lines << " key: \"#{k}\""
51
56
  @lines << " value {"
52
- if v.is_a?(TrueClass) || v.is_a?(FalseClass)
53
- @lines << " b: #{v}"
54
- elsif v.is_a?(Integer)
55
- @lines << " int_val: #{v}"
56
- elsif v.is_a?(Float)
57
- @lines << " float_val: #{v}"
58
- end
57
+ attr_value(v, 6)
59
58
  @lines << " }"
60
59
  @lines << " }"
61
60
  end
62
61
  end
63
62
 
63
# Appends the protobuf-text encoding of one attribute value to @lines,
# prefixed by +indent+ spaces.
#   - booleans -> "b: ...", Integer -> "i: ...", String -> "s: ...",
#     Float -> "f: ...", Symbol -> "sym: ..."
#   - Array -> "list { ... }" with each item rendered two spaces deeper
#   - TensorStream::TensorShape -> "shape { dim { size: N } ... }"
#   - TensorStream::Variable -> nothing is emitted
#
# @param val [Object] attribute value to encode
# @param indent [Integer] number of leading spaces
# @raise [RuntimeError] when +val+ has no known encoding
def attr_value(val, indent = 0)
  spaces = " " * indent
  case val.class.to_s
  when 'TrueClass', 'FalseClass'
    @lines << "#{spaces}b: #{val}"
  when 'Integer'
    @lines << "#{spaces}i: #{val}"
  when 'String'
    # No trailing comma after 'String': with one, the push below became a
    # `when` condition, so strings emitted nothing and other types that
    # reached this clause got a stray "s:" line.
    @lines << "#{spaces}s: #{val}"
  when 'Float'
    @lines << "#{spaces}f: #{val}"
  when 'Symbol'
    @lines << "#{spaces}sym: #{val}"
  when 'Array'
    @lines << "#{spaces}list {"
    val.each { |v_item| attr_value(v_item, indent + 2) }
    @lines << "#{spaces}}"
  when 'TensorStream::TensorShape'
    @lines << "#{spaces}shape {"
    if val.shape
      val.shape.each do |dim|
        @lines << "#{spaces} dim {"
        @lines << "#{spaces} size: #{dim}"
        @lines << "#{spaces} }"
      end
    end
    @lines << "#{spaces}}"
  when 'TensorStream::Variable'
    # intentionally emits nothing for variable containers
  else
    # (a leftover `binding.pry` debugger call used to sit here)
    raise "unknown type #{val.class}"
  end
end
98
+
64
99
  def pack_arr_float(float_arr)
65
100
  float_arr.flatten.pack('f*').bytes.map { |b| b.chr =~ /[^[:print:]]/ ? "\\#{sprintf("%o", b).rjust(3, '0')}" : b.chr }.join
66
101
  end
@@ -92,17 +127,17 @@ module TensorStream
92
127
 
93
128
  if tensor.rank > 0
94
129
  if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
95
- packed = pack_arr_float(tensor.value)
130
+ packed = pack_arr_float(tensor.const_value)
96
131
  arr << " tensor_content: \"#{packed}\""
97
132
  elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
98
- packed = pack_arr_int(tensor.value)
133
+ packed = pack_arr_int(tensor.const_value)
99
134
  arr << " tensor_content: \"#{packed}\""
100
135
  elsif tensor.data_type == :string
101
- tensor.value.each do |v|
136
+ tensor.const_value.each do |v|
102
137
  arr << " string_val: #{v.to_json}"
103
138
  end
104
139
  else
105
- arr << " tensor_content: #{tensor.value.flatten}"
140
+ arr << " tensor_content: #{tensor.const_value.flatten}"
106
141
  end
107
142
  else
108
143
  val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
@@ -114,7 +149,7 @@ module TensorStream
114
149
  else
115
150
  "val"
116
151
  end
117
- arr << " #{val_type}: #{tensor.value.to_json}"
152
+ arr << " #{val_type}: #{tensor.const_value.to_json}"
118
153
  end
119
154
  arr << "}"
120
155
  arr
@@ -1,7 +1,7 @@
1
1
  module TensorStream
2
2
  class Serializer
3
# Writes the serialized form of +tensor+ (or a whole graph) to +filename+.
#
# @param filename [String] destination path
# @param tensor [Tensor, Graph] tensor whose graph is serialized, or the graph itself
# @param session [Session, nil] forwarded to get_string
# @param graph_keys [Array<String>, nil] restrict output to these node names; nil means all nodes
def serialize(filename, tensor, session = nil, graph_keys = nil)
  # forward graph_keys as given (it was previously reassigned to nil at the call site)
  File.write(filename, get_string(tensor, session, graph_keys))
end

# Serialization hook implemented by subclasses; accepts the same arguments
# that serialize passes so a base instance does not raise ArgumentError.
def get_string(tensor, session = nil, graph_keys = nil); end
@@ -0,0 +1,27 @@
1
module TensorStream
  # Serializes a graph to YAML: an array with one hash per Operation node,
  # built with Operation#to_h. Non-Operation nodes are skipped.
  class Yaml < TensorStream::Serializer
    include TensorStream::StringHelper
    include TensorStream::OpHelper

    # @param tensor_or_graph [Tensor, Graph] a tensor (its graph is used) or a graph
    # @param session [Session, nil] accepted for interface parity; not used here
    # @param graph_keys [Array<String>, nil] only serialize these node names; nil means all
    # @yield [graph, key] optional resolver that returns the node for +key+;
    #   defaults to graph.get_tensor_by_name(key)
    # @return [String] YAML document
    def get_string(tensor_or_graph, session = nil, graph_keys = nil)
      graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
      serialized_arr = []

      node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }

      node_keys.each do |k|
        node = if block_given?
                 yield graph, k
               else
                 graph.get_tensor_by_name(k)
               end
        next unless node.is_a?(Operation)

        serialized_arr << node.to_h
      end

      serialized_arr.to_yaml
    end
  end
end
@@ -9,6 +9,12 @@ module TensorStream
9
9
 
10
10
  def self.infer_shape(tensor)
11
11
  case tensor.operation
12
+ when :case, :case_grad
13
+ tensor.inputs[2].shape.shape if tensor.inputs[2]
14
+ when :const
15
+ shape_eval(tensor.options[:value])
16
+ when :variable_v2
17
+ tensor.shape ? tensor.shape.shape : nil
12
18
  when :assign
13
19
  possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
14
20
  tensor.inputs[0].shape.shape
@@ -29,9 +35,9 @@ module TensorStream
29
35
  s
30
36
  when :arg_min, :argmax, :argmin
31
37
  return nil unless tensor.inputs[0].shape.known?
32
- return nil if tensor.inputs[1] && tensor.inputs[1].value.nil?
38
+ return nil if tensor.inputs[1] && tensor.inputs[1].const_value.nil?
33
39
 
34
- axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].value
40
+ axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value
35
41
  new_shape = tensor.inputs[0].shape.shape
36
42
  new_shape.each_with_index.collect do |shape, index|
37
43
  next nil if index == axis
@@ -61,7 +67,7 @@ module TensorStream
61
67
  item
62
68
  end.compact
63
69
  when :reshape
64
- new_shape = tensor.inputs[1] && tensor.inputs[1].value ? tensor.inputs[1].value : nil
70
+ new_shape = tensor.inputs[1] && tensor.inputs[1].const_value ? tensor.inputs[1].const_value : nil
65
71
  return nil if new_shape.nil?
66
72
  return nil if tensor.inputs[0].shape.nil?
67
73
 
@@ -83,11 +89,11 @@ module TensorStream
83
89
  tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
84
90
  when :pad
85
91
  return nil unless tensor.inputs[0].shape.known?
86
- return nil unless tensor.inputs[1].value
92
+ return nil unless tensor.inputs[1].const_value
87
93
 
88
94
  size = tensor.inputs[0].shape.shape.reduce(:*) || 1
89
95
  dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
90
- shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].value))
96
+ shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value))
91
97
  when :mat_mul
92
98
  return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
93
99
  return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
@@ -128,9 +134,9 @@ module TensorStream
128
134
  rotated_shape = Array.new(axis + 1) { new_shape.shift }
129
135
  rotated_shape.rotate! + new_shape
130
136
  when :concat
131
- return nil if tensor.inputs[0].value.nil?
137
+ return nil if tensor.inputs[0].const_value.nil?
132
138
 
133
- axis = tensor.inputs[0].value # get axis
139
+ axis = tensor.inputs[0].const_value # get axis
134
140
 
135
141
  axis_size = 0
136
142
 
@@ -196,9 +202,9 @@ module TensorStream
196
202
 
197
203
  new_shape
198
204
  when :conv2d_backprop_input
199
- return nil unless tensor.inputs[0].value
205
+ return nil unless tensor.inputs[0].const_value
200
206
 
201
- tensor.inputs[0].value
207
+ tensor.inputs[0].const_value
202
208
  else
203
209
  return nil if tensor.inputs[0].nil?
204
210
  return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
@@ -3,9 +3,11 @@ module TensorStream
3
3
  # module that contains helper functions useful for ops
4
4
  module OpHelper
5
5
  def _op(code, *args)
6
- op = Operation.new(code.to_sym, *args)
7
- if !TensorStream.get_default_graph.get_dependency_scope.nil?
8
- i_op(:identity, op, TensorStream.get_default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
6
+ default_graph = Graph.get_default_graph
7
+
8
+ op = default_graph.add_op!(code.to_sym, *args)
9
+ if !default_graph.get_dependency_scope.nil?
10
+ i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
9
11
  else
10
12
  op
11
13
  end
@@ -20,7 +22,15 @@ module TensorStream
20
22
  end
21
23
 
22
24
  args << options.merge(internal: true)
23
- Operation.new(code.to_sym, *args)
25
+ Graph.get_default_graph.add_op!(code.to_sym, *args)
26
+ end
27
+
28
##
# Creates a Variable of +data_type+, prepares it, registers it with its graph
# via add_variable!, and stores the returned op on the variable.
#
# @param data_type [Symbol] element type
# @param rank [Integer, nil] rank passed to Variable#prepare
# @param shape [Object, nil] declared shape, passed to prepare and to the op options
# @param variable_scope [Object, nil] scope passed to Variable#prepare
# @param options [Hash] extra options for prepare and add_variable!
# @return [Variable] the new variable
def i_var(data_type, rank, shape, variable_scope, options = {})
  new_var = Variable.new(data_type)
  new_var.prepare(rank, shape, variable_scope, options)
  # Use the +shape+ argument: the former `@shape` read an instance variable of
  # whatever object includes this helper, not the shape being declared.
  new_var.op = new_var.graph.add_variable!(new_var, options.merge(shape: shape, data_type: data_type))

  new_var
end
25
35
 
26
36
  def cons(value, options = {})
@@ -55,8 +65,8 @@ module TensorStream
55
65
  end
56
66
 
57
67
  def format_source(trace)
58
- grad_source = trace.select { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }.first
59
- # source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
68
+ grad_source = trace.detect { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }
69
+ source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
60
70
  [grad_source, trace].compact.join("\n")
61
71
  end
62
72
 
@@ -82,6 +92,7 @@ module TensorStream
82
92
  axes = TensorStream.range(0, input_rank) if axes.nil?
83
93
  axes = (axes + input_rank) % input_rank
84
94
  axes_shape = i_op(:shape, axes)
95
+
85
96
  TensorStream.dynamic_stitch([TensorStream.range(0, input_rank), axes],
86
97
  [input_shape, i_op(:fill, axes_shape, 1)])
87
98
  end
@@ -1,6 +1,6 @@
1
1
  module TensorStream
2
2
  # helper string methods usually found in ActiveSupport but
3
- # need to replicate here
3
+ # need to replicate here since we don't want to use ActiveSupport
4
4
  module StringHelper
5
5
  def camelize(string, uppercase_first_letter = true)
6
6
  string = if uppercase_first_letter
@@ -23,5 +23,36 @@ module TensorStream
23
23
  [k.to_sym, v]
24
24
  end.to_h
25
25
  end
26
+
27
# Resolves a "Foo::Bar" style constant path to the constant it names,
# mirroring ActiveSupport's String#constantize without depending on it.
# A segment that is visible on an intermediate scope only because it is a
# top-level constant (e.g. "Foo::String") raises NameError rather than
# silently returning the top-level constant.
#
# @param camel_cased_word [String] constant path, optionally starting with "::"
# @return [Object] the resolved constant
# @raise [NameError] when the path is empty/ill-formed or a segment cannot be resolved
def constantize(camel_cased_word)
  segments = camel_cased_word.split('::')

  # Nothing to walk (e.g. ""): let const_get raise a NameError naming the input.
  Object.const_get(camel_cased_word) if segments.empty?

  # "::Foo" splits into ["", "Foo"]; discard the leading blank.
  segments.shift if segments.size > 1 && segments.first.empty?

  segments.inject(Object) do |scope, segment|
    next scope.const_get(segment) if scope == Object

    candidate = scope.const_get(segment)
    # Accept when defined directly on scope, or when no top-level constant of
    # that name exists that lookup could have fallen back to.
    next candidate if scope.const_defined?(segment, false) || !Object.const_defined?(segment)

    # Walk the ancestors (stopping before Object) for the first one that owns
    # the constant directly.
    owner = scope.ancestors.inject do |memo, ancestor|
      break memo if ancestor == Object
      break ancestor if ancestor.const_defined?(segment, false)

      memo
    end

    # When only Object owns it, this non-inheriting lookup raises NameError.
    owner.const_get(segment, false)
  end
end
26
57
  end
27
58
  end
@@ -0,0 +1,135 @@
1
module TensorStream
  # Operator overloads and convenience methods for tensor-like objects. Each
  # method builds a new op (through _op or a TensorStream module function)
  # instead of computing a value immediately.
  #
  # NOTE(review): the including class is expected to provide _op and
  # data_type — confirm against the classes that include this module.
  module TensorMixins
    def +(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:add, self, other)
    end

    def [](index)
      _op(:index, self, index)
    end

    def *(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    def **(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    def /(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    def -(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    def -@
      _op(:negate, self)
    end

    def %(other)
      TensorStream.mod(self, other)
    end

    def floor(name: nil)
      TensorStream.floor(self, name: name)
    end

    def ceil(name: nil)
      TensorStream.ceil(self, name: name)
    end

    def round(name: nil)
      TensorStream.round(self, name: name)
    end

    def log(name: nil)
      TensorStream.log(self, name: name)
    end

    def reshape(shape, name: nil)
      TensorStream.reshape(self, shape, name: name)
    end

    # Builds an :equal op comparing this tensor against a zero constant of the same dtype.
    def zero?
      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
    end

    def ==(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:equal, self, other)
    end

    def <(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:less, self, other)
    end

    def !=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:not_equal, self, other)
    end

    def >(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:greater, self, other)
    end

    def >=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:greater_equal, self, other)
    end

    def <=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:less_equal, self, other)
    end

    def and(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:logical_and, self, other)
    end

    def matmul(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mat_mul, self, other)
    end

    # Same as matmul.
    def dot(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mat_mul, self, other)
    end

    def cast(data_type = :float32, name: nil)
      TensorStream.cast(self, data_type, name: name)
    end

    def var(name: nil)
      TensorStream.variable(self, name: name)
    end

    ##
    # Apply a reduction to tensor
    #
    # @param op_type [Symbol, String] :+ (sum), :* or :prod (product), or :mean
    # @param axis [Object, nil] forwarded to TensorStream.reduce
    # @param keepdims [Boolean] forwarded to TensorStream.reduce
    # @param name [String, nil] forwarded to TensorStream.reduce
    # @raise [RuntimeError] for an unsupported +op_type+ or when a block is given
    def reduce(op_type = :+, axis: nil, keepdims: false, name: nil)
      reduce_op = case op_type.to_sym
                  when :+
                    :sum
                  when :*, :prod # :prod is advertised in the error message below
                    :prod
                  when :mean
                    :mean
                  else
                    raise "unsupported reduce op type #{op_type} valid values are :+, :*, :prod, :mean"
                  end
      raise "blocks are not supported for tensors" if block_given?

      TensorStream.reduce(reduce_op, self, axis, keepdims: keepdims, name: name)
    end
  end
end