tensor_stream 0.9.8 → 0.9.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +31 -14
- data/lib/tensor_stream.rb +4 -0
- data/lib/tensor_stream/constant.rb +41 -0
- data/lib/tensor_stream/control_flow.rb +2 -1
- data/lib/tensor_stream/dynamic_stitch.rb +3 -1
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
- data/lib/tensor_stream/graph.rb +61 -12
- data/lib/tensor_stream/graph_builder.rb +3 -3
- data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
- data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
- data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
- data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
- data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
- data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
- data/lib/tensor_stream/helpers/op_helper.rb +17 -6
- data/lib/tensor_stream/helpers/string_helper.rb +32 -1
- data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
- data/lib/tensor_stream/math_gradients.rb +19 -12
- data/lib/tensor_stream/monkey_patches/float.rb +7 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
- data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
- data/lib/tensor_stream/nn/nn_ops.rb +1 -1
- data/lib/tensor_stream/operation.rb +98 -36
- data/lib/tensor_stream/ops.rb +65 -13
- data/lib/tensor_stream/placeholder.rb +2 -2
- data/lib/tensor_stream/session.rb +15 -3
- data/lib/tensor_stream/tensor.rb +15 -172
- data/lib/tensor_stream/tensor_shape.rb +3 -1
- data/lib/tensor_stream/train/saver.rb +12 -10
- data/lib/tensor_stream/trainer.rb +7 -2
- data/lib/tensor_stream/utils.rb +13 -11
- data/lib/tensor_stream/utils/freezer.rb +37 -0
- data/lib/tensor_stream/variable.rb +17 -11
- data/lib/tensor_stream/variable_scope.rb +3 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +3 -4
- data/samples/linear_regression.rb +9 -5
- data/samples/logistic_regression.rb +11 -9
- data/samples/mnist_data.rb +8 -10
- metadata +8 -4
@@ -27,11 +27,11 @@ module TensorStream
|
|
27
27
|
rank = dimension.size
|
28
28
|
options[:value] = value
|
29
29
|
options[:const] = true
|
30
|
-
TensorStream::
|
30
|
+
TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
|
31
31
|
when 'VariableV2'
|
32
32
|
# evaluate options
|
33
33
|
shape = options[:shape]
|
34
|
-
|
34
|
+
i_var(options[:dtype] || options[:T], nil, shape, nil, options)
|
35
35
|
when 'Placeholder'
|
36
36
|
shape = options[:shape]
|
37
37
|
TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
|
@@ -57,7 +57,7 @@ module TensorStream
|
|
57
57
|
end
|
58
58
|
|
59
59
|
options[:data_type] = options.delete(:T)
|
60
|
-
|
60
|
+
Graph.get_default_graph.add_op!(op, *inputs, options)
|
61
61
|
end
|
62
62
|
end
|
63
63
|
|
@@ -0,0 +1,38 @@
|
|
1
|
+
module TensorStream
  # Rebuilds graph operations (and any variables they hold) from a YAML
  # document whose entries are symbol-keyed op definitions — presumably the
  # output of TensorStream::Yaml; confirm when changing either side.
  class YamlLoader
    # @param graph [Object, nil] graph to load into; defaults to
    #   TensorStream.get_default_graph
    def initialize(graph = nil)
      @graph = graph || TensorStream.get_default_graph
    end

    # Parses +buffer+ and adds every serialized op to the graph.
    #
    # Each entry is read as a Hash with the keys :name, :op, :inputs and
    # :attrs. When attrs contains a :container, a Variable is recreated from
    # attrs[:container][:shape] / [:options], registered on the graph, and
    # handed to the op as its :container option.
    #
    # @param buffer [String] YAML text
    # @return [Array<Hash>] the parsed op definitions ([] for an empty document)
    # @raise [Psych::DisallowedClass] if the YAML references classes other than Symbol
    def load_from_string(buffer)
      # safe_load returns false/nil for an empty document.
      serialized_ops = safe_load_yaml(buffer) || []
      serialized_ops.each do |op_def|
        inputs = (op_def[:inputs] || []).map { |i| @graph.get_tensor_by_name(i) }
        options = {}

        if op_def.dig(:attrs, :container)
          new_var = Variable.new(op_def.dig(:attrs, :data_type))
          var_shape = op_def.dig(:attrs, :container, :shape)
          # A container serialized without options must not crash on []=.
          var_options = op_def.dig(:attrs, :container, :options) || {}
          var_options[:name] = op_def[:name]

          new_var.prepare(var_shape.size, var_shape, TensorStream.get_variable_scope, var_options)
          options[:container] = new_var

          @graph.add_variable(new_var, var_options)
        end

        new_op = Operation.new(@graph, inputs: inputs, options: (op_def[:attrs] || {}).merge(options))
        new_op.operation = op_def[:op].to_sym
        new_op.name = op_def[:name]
        new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
        new_op.rank = new_op.shape.rank
        new_op.data_type = new_op.set_data_type(op_def.dig(:attrs, :data_type))
        new_op.is_const = new_op.infer_const
        new_op.given_name = new_op.name

        @graph.add_node(new_op)
      end
    end

    private

    # Safely parses YAML, permitting only Symbol beyond the core types
    # (op definitions use symbol keys). Psych >= 3.1 takes the allow-list as
    # the `permitted_classes:` keyword and Psych 4 (Ruby 3.1+) accepts only
    # that form; older Psych accepts only the positional form.
    def safe_load_yaml(buffer)
      if YAML.method(:safe_load).parameters.include?([:key, :permitted_classes])
        YAML.safe_load(buffer, permitted_classes: [Symbol])
      else
        YAML.safe_load(buffer, [Symbol])
      end
    end
  end
end
|
@@ -22,6 +22,14 @@ module TensorStream
|
|
22
22
|
value.pack('C*')
|
23
23
|
when :boolean
|
24
24
|
value.map { |v| v ? 1 : 0 }.pack('C*')
|
25
|
+
when :string
|
26
|
+
if value.is_a?(Array)
|
27
|
+
value.to_yaml
|
28
|
+
else
|
29
|
+
value
|
30
|
+
end
|
31
|
+
else
|
32
|
+
raise "unknown type #{data_type}"
|
25
33
|
end
|
26
34
|
|
27
35
|
byte_value
|
@@ -4,11 +4,19 @@ module TensorStream
|
|
4
4
|
include TensorStream::StringHelper
|
5
5
|
include TensorStream::OpHelper
|
6
6
|
|
7
|
-
def get_string(tensor_or_graph, session = nil)
|
7
|
+
def get_string(tensor_or_graph, session = nil, graph_keys = nil)
|
8
8
|
graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
|
9
9
|
@lines = []
|
10
|
-
|
11
|
-
|
10
|
+
|
11
|
+
node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) }
|
12
|
+
|
13
|
+
node_keys.each do |k|
|
14
|
+
node = if block_given?
|
15
|
+
yield graph, k
|
16
|
+
else
|
17
|
+
graph.get_tensor_by_name(k)
|
18
|
+
end
|
19
|
+
|
12
20
|
@lines << "node {"
|
13
21
|
@lines << " name: #{node.name.to_json}"
|
14
22
|
if node.is_a?(TensorStream::Operation)
|
@@ -20,16 +28,13 @@ module TensorStream
|
|
20
28
|
end
|
21
29
|
# type
|
22
30
|
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
@lines << " op: \"VariableV2\""
|
31
|
-
pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
|
32
|
-
pb_attr('shape', shape_buf(node, 'shape'))
|
31
|
+
|
32
|
+
case node.operation.to_s
|
33
|
+
when 'const'
|
34
|
+
pb_attr('value', tensor_value(node))
|
35
|
+
when 'variable_v2'
|
36
|
+
pb_attr('shape', shape_buf(node, 'shape'))
|
37
|
+
end
|
33
38
|
process_options(node)
|
34
39
|
end
|
35
40
|
@lines << "}"
|
@@ -44,23 +49,53 @@ module TensorStream
|
|
44
49
|
|
45
50
|
# Writes one protobuf-text `attr` entry into @lines for every option of
# +node+ whose value is not nil, except the name / internal_name / data_type
# keys and any key beginning with "__". The value part of each entry is
# produced by #attr_value at an indent of 6.
#
# @param node [Object] anything exposing #options (a Hash, or nil to emit nothing)
def process_options(node)
  opts = node.options
  return if opts.nil?

  skipped_keys = %w[name internal_name data_type]
  present = opts.reject { |_key, value| value.nil? }
  present.each do |key, value|
    key_name = key.to_s
    next if skipped_keys.include?(key_name) || key_name.start_with?('__')

    @lines.push(" attr {", " key: \"#{key}\"", " value {")
    attr_value(value, 6)
    @lines.push(" }", " }")
  end
end
|
63
62
|
|
63
|
+
# Appends the protobuf-text rendering of one attribute value to @lines.
#
# Scalars become one tagged line (b:, i:, s:, f:, sym:). Arrays become a
# `list { ... }` block whose items are rendered recursively two spaces
# deeper. A TensorStream::TensorShape becomes a `shape { ... }` block with one
# `dim { size: N }` per entry of val.shape (none when val.shape is nil).
# A TensorStream::Variable emits nothing.
#
# @param val [Object] value to render
# @param indent [Integer] number of leading spaces for the emitted lines
# @raise [RuntimeError] for any other value type
def attr_value(val, indent = 0)
  spaces = " " * indent
  # Dispatch on the value itself. Keep every `when` free of trailing commas:
  # a dangling comma turns the next line into an extra `when` operand that
  # is evaluated (and appended to @lines) for every later value type.
  case val
  when true, false
    @lines << "#{spaces}b: #{val}"
  when Integer
    @lines << "#{spaces}i: #{val}"
  when String
    @lines << "#{spaces}s: #{val}"
  when Float
    @lines << "#{spaces}f: #{val}"
  when Symbol
    @lines << "#{spaces}sym: #{val}"
  when Array
    @lines << "#{spaces}list {"
    val.each { |v_item| attr_value(v_item, indent + 2) }
    @lines << "#{spaces}}"
  when TensorStream::TensorShape
    @lines << "#{spaces}shape {"
    if val.shape
      val.shape.each do |dim|
        @lines << "#{spaces} dim {"
        @lines << "#{spaces} size: #{dim}"
        @lines << "#{spaces} }"
      end
    end
    @lines << "#{spaces}}"
  when TensorStream::Variable
    # Variables (e.g. a variable_v2 op's :container option) produce no value lines.
    nil
  else
    raise "unknown type #{val.class}"
  end
end
|
98
|
+
|
64
99
|
# Encodes a (possibly nested) array of floats as native single-precision
# bytes, rendered for a protobuf-text string literal: printable bytes are
# kept verbatim, every other byte becomes a backslash plus 3-digit octal.
#
# @param float_arr [Array] nested arrays of Floats
# @return [String]
def pack_arr_float(float_arr)
  raw_bytes = float_arr.flatten.pack('f*').bytes
  escaped = raw_bytes.map do |byte|
    char = byte.chr
    if char =~ /[^[:print:]]/
      "\\" + format('%03o', byte)
    else
      char
    end
  end
  escaped.join
end
|
@@ -92,17 +127,17 @@ module TensorStream
|
|
92
127
|
|
93
128
|
if tensor.rank > 0
|
94
129
|
if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type)
|
95
|
-
packed = pack_arr_float(tensor.
|
130
|
+
packed = pack_arr_float(tensor.const_value)
|
96
131
|
arr << " tensor_content: \"#{packed}\""
|
97
132
|
elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
|
98
|
-
packed = pack_arr_int(tensor.
|
133
|
+
packed = pack_arr_int(tensor.const_value)
|
99
134
|
arr << " tensor_content: \"#{packed}\""
|
100
135
|
elsif tensor.data_type == :string
|
101
|
-
tensor.
|
136
|
+
tensor.const_value.each do |v|
|
102
137
|
arr << " string_val: #{v.to_json}"
|
103
138
|
end
|
104
139
|
else
|
105
|
-
arr << " tensor_content: #{tensor.
|
140
|
+
arr << " tensor_content: #{tensor.const_value.flatten}"
|
106
141
|
end
|
107
142
|
else
|
108
143
|
val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type)
|
@@ -114,7 +149,7 @@ module TensorStream
|
|
114
149
|
else
|
115
150
|
"val"
|
116
151
|
end
|
117
|
-
arr << " #{val_type}: #{tensor.
|
152
|
+
arr << " #{val_type}: #{tensor.const_value.to_json}"
|
118
153
|
end
|
119
154
|
arr << "}"
|
120
155
|
arr
|
@@ -1,7 +1,7 @@
|
|
1
1
|
module TensorStream
|
2
2
|
class Serializer
|
3
|
-
# Writes the string produced by #get_string for +tensor+ to +filename+
# (overwriting any existing file).
#
# @param filename [String] destination path
# @param tensor [Object] tensor or graph passed through to #get_string
# @param session [Object, nil] passed through to #get_string
# @param graph_keys [Array, nil] optional subset of node names, passed
#   through to #get_string unchanged
def serialize(filename, tensor, session = nil, graph_keys = nil)
  # Forward the caller's graph_keys; the previous `graph_keys = nil`
  # argument reassigned it and always passed nil.
  File.write(filename, get_string(tensor, session, graph_keys))
end
|
6
6
|
|
7
7
|
# Abstract hook returning the serialized text for +tensor+; subclasses
# override it. Accepts the same (tensor, session, graph_keys) arguments that
# #serialize passes, and returns nil in this base class.
#
# @param tensor [Object] tensor or graph to serialize
# @param session [Object, nil]
# @param graph_keys [Array, nil] optional subset of node names
# @return [nil]
def get_string(tensor, session = nil, graph_keys = nil); end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module TensorStream
  # Serializes a graph to YAML: the document is an array holding #to_h of
  # every Operation node; nodes that are not Operations are skipped.
  class Yaml < TensorStream::Serializer
    include TensorStream::StringHelper
    include TensorStream::OpHelper

    # Returns the YAML text for +tensor_or_graph+.
    #
    # @param tensor_or_graph [Object] a Tensor (its #graph is used) or a graph
    # @param session [Object, nil] not used by this serializer
    # @param graph_keys [Array, nil] when given, only node keys contained in it
    #   are serialized (in graph.node_keys order); nil serializes every node
    # @yield [graph, key] optional lookup returning the node for +key+;
    #   without a block graph.get_tensor_by_name(key) is used
    # @return [String] YAML document
    def get_string(tensor_or_graph, session = nil, graph_keys = nil)
      graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph
      node_keys = if graph_keys.nil?
                    graph.node_keys
                  else
                    graph.node_keys.select { |k| graph_keys.include?(k) }
                  end

      node_keys.each_with_object([]) do |k, serialized|
        node = block_given? ? yield(graph, k) : graph.get_tensor_by_name(k)
        serialized << node.to_h if node.is_a?(Operation)
      end.to_yaml
    end
  end
end
|
@@ -9,6 +9,12 @@ module TensorStream
|
|
9
9
|
|
10
10
|
def self.infer_shape(tensor)
|
11
11
|
case tensor.operation
|
12
|
+
when :case, :case_grad
|
13
|
+
tensor.inputs[2].shape.shape if tensor.inputs[2]
|
14
|
+
when :const
|
15
|
+
shape_eval(tensor.options[:value])
|
16
|
+
when :variable_v2
|
17
|
+
tensor.shape ? tensor.shape.shape : nil
|
12
18
|
when :assign
|
13
19
|
possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
|
14
20
|
tensor.inputs[0].shape.shape
|
@@ -29,9 +35,9 @@ module TensorStream
|
|
29
35
|
s
|
30
36
|
when :arg_min, :argmax, :argmin
|
31
37
|
return nil unless tensor.inputs[0].shape.known?
|
32
|
-
return nil if tensor.inputs[1] && tensor.inputs[1].
|
38
|
+
return nil if tensor.inputs[1] && tensor.inputs[1].const_value.nil?
|
33
39
|
|
34
|
-
axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].
|
40
|
+
axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value
|
35
41
|
new_shape = tensor.inputs[0].shape.shape
|
36
42
|
new_shape.each_with_index.collect do |shape, index|
|
37
43
|
next nil if index == axis
|
@@ -61,7 +67,7 @@ module TensorStream
|
|
61
67
|
item
|
62
68
|
end.compact
|
63
69
|
when :reshape
|
64
|
-
new_shape = tensor.inputs[1] && tensor.inputs[1].
|
70
|
+
new_shape = tensor.inputs[1] && tensor.inputs[1].const_value ? tensor.inputs[1].const_value : nil
|
65
71
|
return nil if new_shape.nil?
|
66
72
|
return nil if tensor.inputs[0].shape.nil?
|
67
73
|
|
@@ -83,11 +89,11 @@ module TensorStream
|
|
83
89
|
tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
|
84
90
|
when :pad
|
85
91
|
return nil unless tensor.inputs[0].shape.known?
|
86
|
-
return nil unless tensor.inputs[1].
|
92
|
+
return nil unless tensor.inputs[1].const_value
|
87
93
|
|
88
94
|
size = tensor.inputs[0].shape.shape.reduce(:*) || 1
|
89
95
|
dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
|
90
|
-
shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].
|
96
|
+
shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value))
|
91
97
|
when :mat_mul
|
92
98
|
return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
|
93
99
|
return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
|
@@ -128,9 +134,9 @@ module TensorStream
|
|
128
134
|
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
129
135
|
rotated_shape.rotate! + new_shape
|
130
136
|
when :concat
|
131
|
-
return nil if tensor.inputs[0].
|
137
|
+
return nil if tensor.inputs[0].const_value.nil?
|
132
138
|
|
133
|
-
axis = tensor.inputs[0].
|
139
|
+
axis = tensor.inputs[0].const_value # get axis
|
134
140
|
|
135
141
|
axis_size = 0
|
136
142
|
|
@@ -196,9 +202,9 @@ module TensorStream
|
|
196
202
|
|
197
203
|
new_shape
|
198
204
|
when :conv2d_backprop_input
|
199
|
-
return nil unless tensor.inputs[0].
|
205
|
+
return nil unless tensor.inputs[0].const_value
|
200
206
|
|
201
|
-
tensor.inputs[0].
|
207
|
+
tensor.inputs[0].const_value
|
202
208
|
else
|
203
209
|
return nil if tensor.inputs[0].nil?
|
204
210
|
return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
|
@@ -3,9 +3,11 @@ module TensorStream
|
|
3
3
|
# module that contains helper functions useful for ops
|
4
4
|
module OpHelper
|
5
5
|
def _op(code, *args)
|
6
|
-
|
7
|
-
|
8
|
-
|
6
|
+
default_graph = Graph.get_default_graph
|
7
|
+
|
8
|
+
op = default_graph.add_op!(code.to_sym, *args)
|
9
|
+
if !default_graph.get_dependency_scope.nil?
|
10
|
+
i_op(:identity, op, default_graph.get_dependency_scope, name: [op.name, 'tuple', 'control_dependency'].join('/'))
|
9
11
|
else
|
10
12
|
op
|
11
13
|
end
|
@@ -20,7 +22,15 @@ module TensorStream
|
|
20
22
|
end
|
21
23
|
|
22
24
|
args << options.merge(internal: true)
|
23
|
-
|
25
|
+
Graph.get_default_graph.add_op!(code.to_sym, *args)
|
26
|
+
end
|
27
|
+
|
28
|
+
# Creates a Variable of +data_type+, prepares it, registers it on its graph
# via add_variable!, and stores the returned op on the variable.
#
# @param data_type [Symbol] element type of the variable
# @param rank [Integer, nil] passed to Variable#prepare
# @param shape [Array, nil] passed to Variable#prepare
# @param variable_scope [Object, nil] passed to Variable#prepare
# @param options [Hash] passed to Variable#prepare and, merged with :shape and
#   :data_type, to the graph's add_variable!
# @return [Variable] the new variable
def i_var(data_type, rank, shape, variable_scope, options = {})
  new_var = Variable.new(data_type)
  new_var.prepare(rank, shape, variable_scope, options)
  # Use the variable's own shape. A bare `@shape` here reads an ivar of the
  # object that mixes in this helper and overwrote options[:shape] with nil.
  # NOTE(review): assumes Variable#prepare populates new_var.shape — confirm.
  new_var.op = new_var.graph.add_variable!(new_var, options.merge(shape: new_var.shape, data_type: data_type))

  new_var
end
|
25
35
|
|
26
36
|
def cons(value, options = {})
|
@@ -55,8 +65,8 @@ module TensorStream
|
|
55
65
|
end
|
56
66
|
|
57
67
|
# Builds the source-location text recorded for an op from a backtrace.
#
# @param trace [Array] backtrace entries (anything responding to #to_s)
# @return [String] the first entry containing lib/tensor_stream/math_gradients
#   (when one exists) followed by every entry of +trace+, joined by newlines
def format_source(trace)
  gradient_marker = File.join('lib', 'tensor_stream', 'math_gradients')
  grad_source = trace.find { |c| c.to_s.include?(gradient_marker) }
  # NOTE(review): this used to also compute the first frame outside
  # lib/tensor_stream but never used it; the whole trace is what gets
  # reported. Confirm whether only that caller frame was intended.
  [grad_source, trace].compact.join("\n")
end
|
62
72
|
|
@@ -82,6 +92,7 @@ module TensorStream
|
|
82
92
|
axes = TensorStream.range(0, input_rank) if axes.nil?
|
83
93
|
axes = (axes + input_rank) % input_rank
|
84
94
|
axes_shape = i_op(:shape, axes)
|
95
|
+
|
85
96
|
TensorStream.dynamic_stitch([TensorStream.range(0, input_rank), axes],
|
86
97
|
[input_shape, i_op(:fill, axes_shape, 1)])
|
87
98
|
end
|
@@ -1,6 +1,6 @@
|
|
1
1
|
module TensorStream
|
2
2
|
# helper string methods usually found in ActiveSupport but
|
3
|
-
# need to replicate here
|
3
|
+
# need to replicate here since we don't want to use ActiveSupport
|
4
4
|
module StringHelper
|
5
5
|
def camelize(string, uppercase_first_letter = true)
|
6
6
|
string = if uppercase_first_letter
|
@@ -23,5 +23,36 @@ module TensorStream
|
|
23
23
|
[k.to_sym, v]
|
24
24
|
end.to_h
|
25
25
|
end
|
26
|
+
|
27
|
+
# Resolves a "A::B::C" constant path to the constant it names, mirroring
# ActiveSupport's String#constantize.
#
# A leading "::" is ignored. Below the top level, a segment that is visible
# only because the top-level Object defines it (not the enclosing namespace
# or one of its non-Object ancestors) raises NameError instead of silently
# resolving to the top-level constant.
#
# @param camel_cased_word [String] constant path
# @return [Object] the referenced constant
# @raise [NameError] for malformed or unresolvable paths
def constantize(camel_cased_word)
  segments = camel_cased_word.split('::')

  # An empty path: let const_get raise a NameError that names the input.
  Object.const_get(camel_cased_word) if segments.empty?

  # '::Foo' splits into ['', 'Foo']; drop the blank root marker.
  segments.shift if segments.size > 1 && segments.first.empty?

  segments.inject(Object) do |scope, segment|
    next scope.const_get(segment) if scope == Object

    found = scope.const_get(segment)
    next found if scope.const_defined?(segment, false) || !Object.const_defined?(segment)

    # Walk the ancestors for a direct owner of the segment, stopping at
    # Object (or the end of the chain).
    owner = scope.ancestors.inject do |acc, ancestor|
      break acc if ancestor == Object
      break ancestor if ancestor.const_defined?(segment, false)

      acc
    end

    # Owned only by Object: this lookup raises NameError.
    owner.const_get(segment, false)
  end
end
|
26
57
|
end
|
27
58
|
end
|
@@ -0,0 +1,135 @@
|
|
1
|
+
module TensorStream
  # Operator overloads and convenience methods mixed into tensor-like
  # classes. Every method delegates to #_op or to the matching TensorStream
  # module function instead of computing a value directly.
  # NOTE(review): including classes must provide #_op and #data_type (both
  # called below) — presumably via TensorStream::OpHelper / Tensor; confirm.
  module TensorMixins
    # Addition: :add op of self and +other+ (after check_data_types).
    def +(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:add, self, other)
    end

    # Indexing: :index op selecting +index+ from self.
    def [](index)
      _op(:index, self, index)
    end

    # Multiplication: :mul op; +other+ is converted to self's data_type.
    def *(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Power: :pow op; +other+ is converted to self's data_type.
    def **(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Division: :div op; +other+ is converted to self's data_type.
    def /(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Subtraction: :sub op; +other+ is converted to self's data_type.
    def -(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
    end

    # Unary minus: :negate op.
    def -@
      _op(:negate, self)
    end

    # Modulo: delegates to TensorStream.mod.
    def %(other)
      TensorStream.mod(self, other)
    end

    # Delegates to TensorStream.floor.
    def floor(name: nil)
      TensorStream.floor(self, name: name)
    end

    # Delegates to TensorStream.ceil.
    def ceil(name: nil)
      TensorStream.ceil(self, name: name)
    end

    # Delegates to TensorStream.round.
    def round(name: nil)
      TensorStream.round(self, name: name)
    end

    # Delegates to TensorStream.log.
    def log(name: nil)
      TensorStream.log(self, name: name)
    end

    # Delegates to TensorStream.reshape.
    def reshape(shape, name: nil)
      TensorStream.reshape(self, shape, name: name)
    end

    # :equal op comparing self with a zero constant of self's data_type.
    def zero?
      _op(:equal, self, TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?'))
    end

    # :equal op — note this returns the op, not a Boolean.
    def ==(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:equal, self, other)
    end

    # :less op.
    def <(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:less, self, other)
    end

    # :not_equal op — returns the op, not a Boolean.
    def !=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:not_equal, self, other)
    end

    # :greater op.
    def >(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:greater, self, other)
    end

    # :greater_equal op.
    def >=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:greater_equal, self, other)
    end

    # :less_equal op.
    def <=(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:less_equal, self, other)
    end

    # :logical_and op.
    def and(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:logical_and, self, other)
    end

    # Matrix multiplication: :mat_mul op.
    def matmul(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mat_mul, self, other)
    end

    # Alias-style entry point for #matmul (:mat_mul op).
    def dot(other)
      _a, other = TensorStream.check_data_types(self, other)
      _op(:mat_mul, self, other)
    end

    # Delegates to TensorStream.cast (default target :float32).
    def cast(data_type = :float32, name: nil)
      TensorStream.cast(self, data_type, name: name)
    end

    # Delegates to TensorStream.variable, using self as the initial value.
    def var(name: nil)
      TensorStream.variable(self, name: name)
    end

    ##
    # Applies a reduction to the tensor via TensorStream.reduce.
    #
    # @param op_type [Symbol, String] :+ or :sum (sum), :* or :prod (product),
    #   or :mean
    # @param axis [Object, nil] passed to TensorStream.reduce
    # @param keepdims [Boolean] passed to TensorStream.reduce
    # @param name [String, nil] passed to TensorStream.reduce
    # @raise [RuntimeError] if a block is given or +op_type+ is unsupported
    def reduce(op_type = :+, axis: nil, keepdims: false, name: nil)
      raise "blocks are not supported for tensors" if block_given?

      # :prod was advertised in the error message but previously rejected.
      reduce_op = case op_type.to_sym
                  when :+, :sum then :sum
                  when :*, :prod then :prod
                  when :mean then :mean
                  else
                    raise "unsupported reduce op type #{op_type} valid values are :+, :sum, :*, :prod, :mean"
                  end

      TensorStream.reduce(reduce_op, self, axis, keepdims: keepdims, name: name)
    end
  end
end
|