tensor_stream 0.9.8 → 0.9.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +31 -14
- data/lib/tensor_stream.rb +4 -0
- data/lib/tensor_stream/constant.rb +41 -0
- data/lib/tensor_stream/control_flow.rb +2 -1
- data/lib/tensor_stream/dynamic_stitch.rb +3 -1
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
- data/lib/tensor_stream/graph.rb +61 -12
- data/lib/tensor_stream/graph_builder.rb +3 -3
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
- data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
- data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
- data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
- data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
- data/lib/tensor_stream/helpers/op_helper.rb +17 -6
- data/lib/tensor_stream/helpers/string_helper.rb +32 -1
- data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
- data/lib/tensor_stream/math_gradients.rb +19 -12
- data/lib/tensor_stream/monkey_patches/float.rb +7 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
- data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
- data/lib/tensor_stream/nn/nn_ops.rb +1 -1
- data/lib/tensor_stream/operation.rb +98 -36
- data/lib/tensor_stream/ops.rb +65 -13
- data/lib/tensor_stream/placeholder.rb +2 -2
- data/lib/tensor_stream/session.rb +15 -3
- data/lib/tensor_stream/tensor.rb +15 -172
- data/lib/tensor_stream/tensor_shape.rb +3 -1
- data/lib/tensor_stream/train/saver.rb +12 -10
- data/lib/tensor_stream/trainer.rb +7 -2
- data/lib/tensor_stream/utils.rb +13 -11
- data/lib/tensor_stream/utils/freezer.rb +37 -0
- data/lib/tensor_stream/variable.rb +17 -11
- data/lib/tensor_stream/variable_scope.rb +3 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +3 -4
- data/samples/linear_regression.rb +9 -5
- data/samples/logistic_regression.rb +11 -9
- data/samples/mnist_data.rb +8 -10
- metadata +8 -4
@@ -27,11 +27,11 @@ module TensorStream
|
|
27
27
|
rank = dimension.size
|
28
28
|
options[:value] = value
|
29
29
|
options[:const] = true
|
30
|
-
TensorStream::
|
30
|
+
TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
|
31
31
|
when 'VariableV2'
|
32
32
|
# evaluate options
|
33
33
|
shape = options[:shape]
|
34
|
-
|
34
|
+
i_var(options[:dtype] || options[:T], nil, shape, nil, options)
|
35
35
|
when 'Placeholder'
|
36
36
|
shape = options[:shape]
|
37
37
|
TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
|
@@ -57,7 +57,7 @@ module TensorStream
|
|
57
57
|
end
|
58
58
|
|
59
59
|
options[:data_type] = options.delete(:T)
|
60
|
-
|
60
|
+
Graph.get_default_graph.add_op!(op, *inputs, options)
|
61
61
|
end
|
62
62
|
end
|
63
63
|
|
@@ -0,0 +1,38 @@
|
|
1
|
+
module TensorStream
  # Rebuilds graph operations from a YAML document: an array of op-definition
  # hashes with symbol keys (:name, :op, :inputs, :attrs).
  class YamlLoader
    # @param graph [Graph, nil] graph that receives the ops; when nil the
    #   current default graph is used
    def initialize(graph = nil)
      @graph = graph || TensorStream.get_default_graph
    end

    ##
    # Parses +buffer+ and adds each op definition to the graph in document order.
    #
    # Inputs are resolved by name with +get_tensor_by_name+, so they must already
    # exist in the graph (or appear earlier in +buffer+). When an op's attrs carry
    # a :container entry, a Variable is built from its :shape/:options, registered
    # on the graph, and passed to the op as its :container option.
    #
    # @param buffer [String] YAML text
    # @return [Array<Hash>] the parsed op definitions ([] for an empty document)
    def load_from_string(buffer)
      serialized_ops = safe_load_yaml(buffer) || []
      serialized_ops.each do |op_def|
        attrs = op_def[:attrs] || {}
        inputs = (op_def[:inputs] || []).map { |i| @graph.get_tensor_by_name(i) }
        options = {}
        options[:container] = load_variable(op_def, attrs) if attrs[:container]

        new_op = Operation.new(@graph, inputs: inputs, options: attrs.merge(options))
        new_op.operation = op_def[:op].to_sym
        new_op.name = op_def[:name]
        new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
        new_op.rank = new_op.shape.rank
        new_op.data_type = new_op.set_data_type(attrs[:data_type])
        new_op.is_const = new_op.infer_const
        new_op.given_name = new_op.name

        @graph.add_node(new_op)
      end
    end

    private

    # Creates and registers the Variable described by attrs[:container].
    # A missing :options is treated as empty and a missing :shape gives a nil rank.
    def load_variable(op_def, attrs)
      new_var = Variable.new(attrs[:data_type])
      var_shape = attrs.dig(:container, :shape)
      # Build a fresh hash instead of mutating the parsed document.
      var_options = (attrs.dig(:container, :options) || {}).merge(name: op_def[:name])

      new_var.prepare(var_shape&.size, var_shape, TensorStream.get_variable_scope, var_options)
      @graph.add_variable(new_var, var_options)
      new_var
    end

    # Psych >= 3.1 takes permitted classes as a keyword and Psych 4 removed the
    # legacy positional form, so use whichever the loaded Psych supports.
    # Only Symbol is allowed beyond the basic YAML types.
    def safe_load_yaml(buffer)
      if YAML.method(:safe_load).parameters.include?([:key, :permitted_classes])
        YAML.safe_load(buffer, permitted_classes: [Symbol])
      else
        YAML.safe_load(buffer, [Symbol])
      end
    end
  end
end
|
@@ -22,6 +22,14 @@ module TensorStream
|
|
22
22
|
value.pack('C*')
|
23
23
|
when :boolean
|
24
24
|
value.map { |v| v ? 1 : 0 }.pack('C*')
|
25
|
+
when :string
|
26
|
+
if value.is_a?(Array)
|
27
|
+
value.to_yaml
|
28
|
+
else
|
29
|
+
value
|
30
|
+
end
|
31
|
+
else
|
32
|
+
raise "unknown type #{data_type}"
|
25
33
|
end
|
26
34
|
|
27
35
|
byte_value
|
@@ -4,11 +4,19 @@ module TensorStream
|
|
4
4
|
include TensorStream::StringHelper
|
5
5
|
include TensorStream::OpHelper
|
6
6
|
|
7
|
-
def get_string(tensor_or_graph, session = nil)
|
7
|
+
def get_string(tensor_or_graph, session = nil, graph_keys = nil)
|
8
8
|
graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
|
9
9
|
@lines = []
|
10
|
-
|
11
|
-
|
10
|
+
|
11
|
+
node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }
|
12
|
+
|
13
|
+
node_keys.each do |k|
|
14
|
+
node = if block_given?
|
15
|
+
yield graph, k
|
16
|
+
else
|
17
|
+
graph.get_tensor_by_name(k)
|
18
|
+
end
|
19
|
+
|
12
20
|
@lines << "node {"
|
13
21
|
@lines << " name: #{node.name.to_json}"
|
14
22
|
if node.is_a?(TensorStream::Operation)
|
@@ -20,16 +28,13 @@ module TensorStream
|
|
20
28
|
end
|
21
29
|
# type
|
22
30
|
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
@lines << " op: \"VariableV2\""
|
31
|
-
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
32
|
-
pb_attr('shape', shape_buf(node, 'shape'))
|
31
|
+
|
32
|
+
case node.operation.to_s
|
33
|
+
when 'const'
|
34
|
+
pb_attr('value', tensor_value(node))
|
35
|
+
when 'variable_v2'
|
36
|
+
pb_attr('shape', shape_buf(node, 'shape'))
|
37
|
+
end
|
33
38
|
process_options(node)
|
34
39
|
end
|
35
40
|
@lines << "}"
|
@@ -44,23 +49,53 @@ module TensorStream
|
|
44
49
|
|
45
50
|
# Appends one pbtext `attr { key ... value { ... } }` entry to @lines for each
# option of +node+. Options with nil values are dropped, as are the keys
# name / internal_name / data_type and any key starting with "__".
# The value itself is rendered by #attr_value at indent 6.
def process_options(node)
  opts = node.options
  return if opts.nil?

  present = opts.reject { |_key, value| value.nil? }
  present.each do |key, value|
    key_name = key.to_s
    skipped = %w[name internal_name data_type].include?(key_name) || key_name.start_with?('__')
    next if skipped

    @lines << " attr {"
    @lines << " key: \"#{key}\""
    @lines << " value {"
    attr_value(value, 6)
    @lines << " }"
    @lines << " }"
  end
end
|
63
62
|
|
63
|
+
# Appends the pbtext rendering of an attribute value to @lines, indented by
# +indent+ spaces. Booleans become `b:`, integers `i:`, strings `s:` (quoted),
# floats `f:`, symbols `sym:`. Arrays become a `list { ... }` with each element
# rendered two spaces deeper, and TensorShape becomes `shape { dim { size: N } ... }`.
# Variable containers produce no output.
#
# @raise [RuntimeError] for any other value type
def attr_value(val, indent = 0)
  spaces = " " * indent
  case val
  when true, false
    @lines << "#{spaces}b: #{val}"
  when Integer # also matches Fixnum/Bignum on Ruby < 2.4
    @lines << "#{spaces}i: #{val}"
  when String
    # Quote like the other string fields this serializer emits (names, string_val).
    @lines << "#{spaces}s: #{val.to_json}"
  when Float
    @lines << "#{spaces}f: #{val}"
  when Symbol
    @lines << "#{spaces}sym: #{val}"
  when Array
    @lines << "#{spaces}list {"
    val.each { |v_item| attr_value(v_item, indent + 2) }
    @lines << "#{spaces}}"
  else
    # Project classes are matched by name so this method does not force them to load.
    case val.class.to_s
    when 'TensorStream::TensorShape'
      @lines << "#{spaces}shape {"
      (val.shape || []).each do |dim|
        @lines << "#{spaces} dim {"
        @lines << "#{spaces} size: #{dim}"
        @lines << "#{spaces} }"
      end
      @lines << "#{spaces}}"
    when 'TensorStream::Variable'
      nil # variable containers are intentionally not serialized as attributes
    else
      raise "unknown type #{val.class}"
    end
  end
end
|
98
|
+
|
64
99
|
# Packs a (possibly nested) array of floats as native-endian 32-bit floats and
# renders the bytes as the body of a pbtext double-quoted string literal.
# Bytes outside printable ASCII (0x20..0x7E), plus `"` and `\` (which would
# otherwise terminate or corrupt the literal), become three-digit octal
# escapes such as `\000`.
def pack_arr_float(float_arr)
  float_arr.flatten.pack('f*').bytes.map do |b|
    if b < 0x20 || b > 0x7e || b == 0x22 || b == 0x5c
      format('\\%03o', b)
    else
      b.chr
    end
  end.join
end
|
@@ -92,17 +127,17 @@ module TensorStream
|
|
92
127
|
|
93
128
|
if tensor.rank > 0
|
94
129
|
if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
|
95
|
-
packed = pack_arr_float(tensor.
|
130
|
+
packed = pack_arr_float(tensor.const_value)
|
96
131
|
arr << " tensor_content: \"#{packed}\""
|
97
132
|
elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
|
98
|
-
packed = pack_arr_int(tensor.
|
133
|
+
packed = pack_arr_int(tensor.const_value)
|
99
134
|
arr << " tensor_content: \"#{packed}\""
|
100
135
|
elsif tensor.data_type == :string
|
101
|
-
tensor.
|
136
|
+
tensor.const_value.each do |v|
|
102
137
|
arr << " string_val: #{v.to_json}"
|
103
138
|
end
|
104
139
|
else
|
105
|
-
arr << " tensor_content: #{tensor.
|
140
|
+
arr << " tensor_content: #{tensor.const_value.flatten}"
|
106
141
|
end
|
107
142
|
else
|
108
143
|
val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
|
@@ -114,7 +149,7 @@ module TensorStream
|
|
114
149
|
else
|
115
150
|
"val"
|
116
151
|
end
|
117
|
-
arr << " #{val_type}: #{tensor.
|
152
|
+
arr << " #{val_type}: #{tensor.const_value.to_json}"
|
118
153
|
end
|
119
154
|
arr << "}"
|
120
155
|
arr
|
@@ -1,7 +1,7 @@
|
|
1
1
|
module TensorStream
|
2
2
|
class Serializer
|
3
|
-
def serialize(filename, tensor, session = nil)
|
4
|
-
File.write(filename, get_string(tensor, session))
|
3
|
+
# Serializes +tensor+ (or a graph) with #get_string and writes the text to
# +filename+, overwriting it.
#
# @param filename [String] destination path
# @param tensor [Tensor, Graph] what to serialize
# @param session [Session, nil] forwarded to #get_string
# @param graph_keys [Array<String>, nil] forwarded to #get_string; subclasses
#   use it to restrict which graph nodes are written
def serialize(filename, tensor, session = nil, graph_keys = nil)
  # Pass graph_keys through; the previous `graph_keys = nil` argument always sent nil.
  File.write(filename, get_string(tensor, session, graph_keys))
end
|
6
6
|
|
7
7
|
# Returns the serialized text for +tensor+; the base implementation returns nil
# and subclasses override it. The signature accepts +graph_keys+ so it matches
# the three arguments #serialize passes.
def get_string(tensor, session = nil, graph_keys = nil); end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module TensorStream
  # Serializes a graph's operations to YAML: an array holding Operation#to_h
  # for each operation node, in graph node order. Nodes that are not
  # Operations are skipped.
  class Yaml < TensorStream::Serializer
    include TensorStream::StringHelper
    include TensorStream::OpHelper

    ##
    # @param tensor_or_graph [Tensor, Graph] a tensor (its graph is used) or a graph
    # @param session [Session, nil] not used by this serializer
    # @param graph_keys [Array<String>, nil] when given, only node keys listed here are serialized
    # @yield [graph, key] optional; returns the node to serialize for +key+
    #   instead of looking it up with +get_tensor_by_name+
    # @return [String] the YAML document
    def get_string(tensor_or_graph, session = nil, graph_keys = nil)
      graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
      serialized_arr = []

      node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }

      node_keys.each do |k|
        node = block_given? ? yield(graph, k) : graph.get_tensor_by_name(k)
        next unless node.is_a?(Operation)

        serialized_arr << node.to_h
      end

      serialized_arr.to_yaml
    end
  end
end
|
@@ -9,6 +9,12 @@ module TensorStream
|
|
9
9
|
|
10
10
|
def self.infer_shape(tensor)
|
11
11
|
case tensor.operation
|
12
|
+
when :case, :case_grad
|
13
|
+
tensor.inputs[2].shape.shape if tensor.inputs[2]
|
14
|
+
when :const
|
15
|
+
shape_eval(tensor.options[:value])
|
16
|
+
when :variable_v2
|
17
|
+
tensor.shape ? tensor.shape.shape : nil
|
12
18
|
when :assign
|
13
19
|
possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
|
14
20
|
tensor.inputs[0].shape.shape
|
@@ -29,9 +35,9 @@ module TensorStream
|
|
29
35
|
s
|
30
36
|
when :arg_min, :argmax, :argmin
|
31
37
|
return nil unless tensor.inputs[0].shape.known?
|
32
|
-
return nil if tensor.inputs[1] && tensor.inputs[1].
|
38
|
+
return nil if tensor.inputs[1] && tensor.inputs[1].const_value.nil?
|
33
39
|
|
34
|
-
axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].
|
40
|
+
axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value
|
35
41
|
new_shape = tensor.inputs[0].shape.shape
|
36
42
|
new_shape.each_with_index.collect do |shape, index|
|
37
43
|
next nil if index == axis
|
@@ -61,7 +67,7 @@ module TensorStream
|
|
61
67
|
item
|
62
68
|
end.compact
|
63
69
|
when :reshape
|
64
|
-
new_shape = tensor.inputs[1] && tensor.inputs[1].
|
70
|
+
new_shape = tensor.inputs[1] && tensor.inputs[1].const_value ? tensor.inputs[1].const_value : nil
|
65
71
|
return nil if new_shape.nil?
|
66
72
|
return nil if tensor.inputs[0].shape.nil?
|
67
73
|
|
@@ -83,11 +89,11 @@ module TensorStream
|
|
83
89
|
tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
|
84
90
|
when :pad
|
85
91
|
return nil unless tensor.inputs[0].shape.known?
|
86
|
-
return nil unless tensor.inputs[1].
|
92
|
+
return nil unless tensor.inputs[1].const_value
|
87
93
|
|
88
94
|
size = tensor.inputs[0].shape.shape.reduce(:*) || 1
|
89
95
|
dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
|
90
|
-
shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].
|
96
|
+
shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value))
|
91
97
|
when :mat_mul
|
92
98
|
return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
|
93
99
|
return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
|
@@ -128,9 +134,9 @@ module TensorStream
|
|
128
134
|
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
129
135
|
rotated_shape.rotate! + new_shape
|
130
136
|
when :concat
|
131
|
-
return nil if tensor.inputs[0].
|
137
|
+
return nil if tensor.inputs[0].const_value.nil?
|
132
138
|
|
133
|
-
axis = tensor.inputs[0].
|
139
|
+
axis = tensor.inputs[0].const_value # get axis
|
134
140
|
|
135
141
|
axis_size = 0
|
136
142
|
|
@@ -196,9 +202,9 @@ module TensorStream
|
|
196
202
|
|
197
203
|
new_shape
|
198
204
|
when :conv2d_backprop_input
|
199
|
-
return nil unless tensor.inputs[0].
|
205
|
+
return nil unless tensor.inputs[0].const_value
|
200
206
|
|
201
|
-
tensor.inputs[0].
|
207
|
+
tensor.inputs[0].const_value
|
202
208
|
else
|
203
209
|
return nil if tensor.inputs[0].nil?
|
204
210
|
return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
|
@@ -3,9 +3,11 @@ module TensorStream
|
|
3
3
|
# module that contains helper functions useful for ops
|
4
4
|
module OpHelper
|
5
5
|
def _op(code, *args)
|
6
|
-
|
7
|
-
|
8
|
-
|
6
|
+
default_graph = Graph.get_default_graph
|
7
|
+
|
8
|
+
op = default_graph.add_op!(code.to_sym, *args)
|
9
|
+
if !default_graph.get_dependency_scope.nil?
|
10
|
+
i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
|
9
11
|
else
|
10
12
|
op
|
11
13
|
end
|
@@ -20,7 +22,15 @@ module TensorStream
|
|
20
22
|
end
|
21
23
|
|
22
24
|
args << options.merge(internal: true)
|
23
|
-
|
25
|
+
Graph.get_default_graph.add_op!(code.to_sym, *args)
|
26
|
+
end
|
27
|
+
|
28
|
+
# Creates a Variable of +data_type+, prepares it with +rank+, +shape+ and
# +variable_scope+, and registers it with add_variable! on the variable's graph.
# The returned op is stored on the variable.
#
# @return [Variable] the new variable
def i_var(data_type, rank, shape, variable_scope, options = {})
  new_var = Variable.new(data_type)
  new_var.prepare(rank, shape, variable_scope, options)
  # Use the variable's own shape. This helper is mixed into arbitrary classes,
  # so @shape would be the *caller's* instance variable, not the variable's.
  new_var.op = new_var.graph.add_variable!(new_var, options.merge(shape: new_var.shape, data_type: data_type))

  new_var
end
|
25
35
|
|
26
36
|
def cons(value, options = {})
|
@@ -55,8 +65,8 @@ module TensorStream
|
|
55
65
|
end
|
56
66
|
|
57
67
|
# Summarizes where an op came from, given a backtrace. The result is the first
# frame inside lib/tensor_stream/math_gradients (if any), followed by the first
# frame outside lib/tensor_stream (the caller's code), joined by newlines.
#
# @param trace [Array] backtrace entries (anything responding to #to_s)
# @return [String] the selected frames, or "" when none match
def format_source(trace)
  library_path = File.join('lib', 'tensor_stream')
  gradients_path = File.join(library_path, 'math_gradients')

  grad_source = trace.find { |c| c.to_s.include?(gradients_path) }
  source = trace.find { |c| !c.to_s.include?(library_path) }
  # Previously `source` was computed but the whole trace was joined instead.
  [grad_source, source].compact.join("\n")
end
|
62
72
|
|
@@ -82,6 +92,7 @@ module TensorStream
|
|
82
92
|
axes = TensorStream.range(0, input_rank) if axes.nil?
|
83
93
|
axes = (axes + input_rank) % input_rank
|
84
94
|
axes_shape = i_op(:shape, axes)
|
95
|
+
|
85
96
|
TensorStream.dynamic_stitch([TensorStream.range(0, input_rank), axes],
|
86
97
|
[input_shape, i_op(:fill, axes_shape, 1)])
|
87
98
|
end
|
@@ -1,6 +1,6 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# helper string methods usually found in ActiveSupport but
|
3
|
-
# need to replicate here
|
3
|
+
# need to replicate here since we don't want to use ActiveSupport
|
4
4
|
module StringHelper
|
5
5
|
def camelize(string, uppercase_first_letter = true)
|
6
6
|
string = if uppercase_first_letter
|
@@ -23,5 +23,36 @@ module TensorStream
|
|
23
23
|
[k.to_sym, v]
|
24
24
|
end.to_h
|
25
25
|
end
|
26
|
+
|
27
|
+
# Resolves a constant from its fully qualified name, e.g. "Process::Status"
# or "::Comparable" (mirrors ActiveSupport's String#constantize).
# A nested segment that is only defined at top level (in Object) is not
# accepted as a member of the enclosing scope; NameError is raised instead.
def constantize(camel_cased_word)
  parts = camel_cased_word.split('::')

  # Nothing to resolve (e.g. ""): let Object.const_get raise its own NameError,
  # which names the ill-formed constant.
  Object.const_get(camel_cased_word) if parts.empty?

  # A leading '::' (as in '::Foo') splits into a blank first segment.
  parts.shift if parts.size > 1 && parts.first.empty?

  parts.inject(Object) do |scope, part|
    next scope.const_get(part) if scope == Object

    found = scope.const_get(part)
    # Accept the inherited lookup when the scope owns the name, or when no
    # top-level constant of that name exists that it could have leaked from.
    next found if scope.const_defined?(part, false) || !Object.const_defined?(part)

    # Otherwise find the first ancestor (stopping before Object) that defines
    # the name directly...
    owner = scope.ancestors.inject do |current, ancestor|
      break current if ancestor == Object
      break ancestor if ancestor.const_defined?(part, false)

      current
    end
    # ...and look it up there without inheritance, which raises when the
    # only definition lives in Object.
    owner.const_get(part, false)
  end
end
|
26
57
|
end
|
27
58
|
end
|
@@ -0,0 +1,135 @@
|
|
1
|
+
module TensorStream
  # Operator overloads and helper methods mixed into tensor classes.
  #
  # Every method only builds graph nodes (through #_op or a TensorStream
  # module function); nothing is evaluated here. Including classes must
  # provide #_op and #data_type. Comparison operators such as #== return a
  # boolean tensor op, not a Ruby true/false.
  module TensorMixins
    # Element-wise addition (:add).
    def +(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:add, self, other)
    end

    # Indexing op (:index) on this tensor.
    def [](index)
      _op(:index, self, index)
    end

    # Element-wise multiplication; +other+ is converted to this tensor's data type.
    def *(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Element-wise power; +other+ is converted to this tensor's data type.
    def **(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Element-wise division; +other+ is converted to this tensor's data type.
    def /(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Element-wise subtraction; +other+ is converted to this tensor's data type.
    def -(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Unary minus (:negate).
    def -@
      _op(:negate, self)
    end

    def %(other)
      TensorStream.mod(self, other)
    end

    def floor(name: nil)
      TensorStream.floor(self, name: name)
    end

    def ceil(name: nil)
      TensorStream.ceil(self, name: name)
    end

    def round(name: nil)
      TensorStream.round(self, name: name)
    end

    def log(name: nil)
      TensorStream.log(self, name: name)
    end

    def reshape(shape, name: nil)
      TensorStream.reshape(self, shape, name: name)
    end

    # Element-wise test against a zero constant of this tensor's data type.
    def zero?
      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
    end

    # Element-wise equality op (not a Ruby boolean).
    def ==(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:equal, self, other)
    end

    def <(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:less, self, other)
    end

    def !=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:not_equal, self, other)
    end

    def >(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:greater, self, other)
    end

    def >=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:greater_equal, self, other)
    end

    def <=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:less_equal, self, other)
    end

    # Element-wise logical AND (:logical_and).
    def and(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:logical_and, self, other)
    end

    # Matrix multiplication (:mat_mul).
    def matmul(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mat_mul, self, other)
    end

    # Alias-style helper for #matmul.
    def dot(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mat_mul, self, other)
    end

    def cast(data_type = :float32, name: nil)
      TensorStream.cast(self, data_type, name: name)
    end

    # Wraps this tensor in a variable via TensorStream.variable.
    def var(name: nil)
      TensorStream.variable(self, name: name)
    end

    ##
    # Apply a reduction to tensor.
    #
    # @param op_type [Symbol, String] :+ or :sum, :* or :prod, or :mean
    # @param axis forwarded to TensorStream.reduce (nil reduces all axes there)
    # @raise [RuntimeError] for an unsupported op_type or when a block is given
    def reduce(op_type = :+, axis: nil, keepdims: false, name: nil)
      # :prod and :sum are accepted because the error message advertises them.
      reduce_op = case op_type.to_sym
                  when :+, :sum then :sum
                  when :*, :prod then :prod
                  when :mean then :mean
                  else
                    raise "unsupported reduce op type #{op_type} valid values are :+, :*, :prod, :mean"
                  end
      raise "blocks are not supported for tensors" if block_given?

      TensorStream.reduce(reduce_op, self, axis, keepdims: keepdims, name: name)
    end
  end
end
|