tensor_stream 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rake_tasks~ +0 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +21 -0
- data/README.md +123 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/tensor_stream.rb +138 -0
- data/lib/tensor_stream/control_flow.rb +23 -0
- data/lib/tensor_stream/evaluator/evaluator.rb +7 -0
- data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +32 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +749 -0
- data/lib/tensor_stream/graph.rb +98 -0
- data/lib/tensor_stream/graph_keys.rb +5 -0
- data/lib/tensor_stream/helpers/op_helper.rb +58 -0
- data/lib/tensor_stream/math_gradients.rb +161 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +0 -0
- data/lib/tensor_stream/nn/nn_ops.rb +17 -0
- data/lib/tensor_stream/operation.rb +195 -0
- data/lib/tensor_stream/ops.rb +225 -0
- data/lib/tensor_stream/placeholder.rb +21 -0
- data/lib/tensor_stream/session.rb +66 -0
- data/lib/tensor_stream/tensor.rb +317 -0
- data/lib/tensor_stream/tensor_shape.rb +25 -0
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +23 -0
- data/lib/tensor_stream/train/saver.rb +61 -0
- data/lib/tensor_stream/trainer.rb +7 -0
- data/lib/tensor_stream/types.rb +17 -0
- data/lib/tensor_stream/variable.rb +52 -0
- data/lib/tensor_stream/version.rb +7 -0
- data/samples/iris.data +150 -0
- data/samples/iris.rb +117 -0
- data/samples/linear_regression.rb +55 -0
- data/samples/raw_neural_net_sample.rb +54 -0
- data/tensor_stream.gemspec +40 -0
- metadata +185 -0
@@ -0,0 +1,225 @@
|
|
1
|
+
module TensorStream
  # Factory methods that build graph operations. Every method here delegates to
  # `op(code, a, b, options)`, which is expected to be supplied by the object this
  # module is mixed into, and returns whatever `op` returns.
  module Ops
    # Data types accepted by the floating-point-only math ops below.
    FLOATING_POINT_TYPES = %i[float32 float64].freeze
    # Data types accepted by ops that work on any numeric input.
    NUMERIC_TYPES = %i[int32 int64 float32 float64].freeze

    # Builds an :argmax op on +input+. +axis+, +dimension+ and +name+ are forwarded
    # as options; +output_type+ is forwarded as the op's :data_type.
    def argmax(input, axis = nil, name: nil, dimension: nil, output_type: :int32)
      op(:argmax, input, nil, axis: axis, name: name, dimension: dimension, data_type: output_type)
    end

    # Builds a :gradients op with +ys+ and +xs+ as operands.
    # Only +stop_gradients+ is forwarded; the remaining keyword arguments are
    # accepted but not used by this method.
    def gradients(ys, xs, grad_ys: nil,
                  name: 'gradients',
                  colocate_gradients_with_ops: false,
                  gate_gradients: false,
                  aggregation_method: nil,
                  stop_gradients: nil)
      options = { stop_gradients: stop_gradients }
      op(:gradients, ys, xs, options)
    end

    # Builds a :random_uniform op; all arguments are forwarded as options.
    def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
      options = { shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
      op(:random_uniform, nil, nil, options)
    end

    # Builds a :random_normal op; all arguments are forwarded as options.
    def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
      options = { shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
      op(:random_normal, nil, nil, options)
    end

    # Builds a :stop_gradient op on +tensor+, forwarding +options+ untouched.
    def stop_gradient(tensor, options = {})
      op(:stop_gradient, tensor, nil, options)
    end

    # Builds an :eye op. +num_columns+ falls back to +num_rows+ when omitted.
    def eye(num_rows, num_columns: nil, dtype: :float32, name: nil)
      op(:eye, num_rows, num_columns || num_rows, data_type: dtype, name: name, preserve_params_type: true)
    end

    # Builds a :shape op on +input+.
    # NOTE(review): +out_type+ is accepted but not forwarded to the op.
    def shape(input, name: nil, out_type: :int32)
      op(:shape, input, nil, name: name)
    end

    # Builds a :rank op on +input+.
    # The explicit nil keeps the options hash out of the second-operand slot,
    # matching every other single-operand op in this module.
    def rank(input, name: nil)
      op(:rank, input, nil, name: name)
    end

    # Builds a :zeros op with no operands, forwarding +options+ untouched.
    def zeros_initializer(options = {})
      op(:zeros, nil, nil, options)
    end

    # Builds a :slice op with +input+ and +start+ as operands and +size+ as an option.
    def slice(input, start, size, name: nil)
      op(:slice, input, start, size: size, name: name)
    end

    # Builds a :zeros op with +shape+ as its operand.
    def zeros(shape, dtype: :float32, name: nil)
      op(:zeros, shape, nil, data_type: dtype, name: name)
    end

    # Builds a :ones op with +shape+ as its operand.
    def ones(shape, dtype: :float32, name: nil)
      op(:ones, shape, nil, data_type: dtype, name: name)
    end

    # Builds a :less comparison op on +a+ and +b+.
    def less(a, b, name: nil)
      op(:less, a, b, name: name)
    end

    # Builds a :greater comparison op on +a+ and +b+.
    def greater(a, b, name: nil)
      op(:greater, a, b, name: name)
    end

    # Builds a :reduce_mean op on +input_tensor+; +axis+ and +keepdims+ are options.
    def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
      op(:reduce_mean, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
    end

    # Builds a :reduce_sum op on +input_tensor+; +axis+ and +keepdims+ are options.
    def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
      op(:reduce_sum, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
    end

    # Builds a :reduce_prod op on +input+; +axis+ and +keepdims+ are options.
    def reduce_prod(input, axis = nil, keepdims: false, name: nil)
      op(:reduce_prod, input, nil, axis: axis, keepdims: keepdims, name: name)
    end

    # Builds a :concat op on +values+ with +axis+ as an option.
    def concat(values, axis, name: 'concat')
      op(:concat, values, nil, axis: axis, name: name)
    end

    # Builds a :reshape op with +tensor+ and +shape+ as operands.
    def reshape(tensor, shape, name: nil)
      op(:reshape, tensor, shape, name: name)
    end

    # Builds a :square op on +tensor+.
    def square(tensor, name: nil)
      op(:square, tensor, nil, name: name)
    end

    # Builds a :cond op with +true_fn+ and +false_fn+ as operands and +pred+ as an option.
    def cond(pred, true_fn, false_fn, name: nil)
      op(:cond, true_fn, false_fn, pred: pred, name: name)
    end

    # Builds a :where op with +x+ and +y+ as operands and +condition+ as the :pred option.
    # +name+ is forwarded like in every other op (it used to be dropped).
    def where(condition, x = nil, y = nil, name: nil)
      op(:where, x, y, pred: condition, name: name)
    end

    # Builds an :add op on +a+ and +b+.
    def add(a, b, name: nil)
      op(:add, a, b, name: name)
    end

    # Builds a :sub op on +a+ and +b+.
    def sub(a, b, name: nil)
      op(:sub, a, b, name: name)
    end

    # Builds a :max op on +a+ and +b+ after checking both against NUMERIC_TYPES.
    def max(a, b, name: nil)
      check_allowed_types(a, NUMERIC_TYPES)
      check_allowed_types(b, NUMERIC_TYPES)

      op(:max, a, b, name: name)
    end

    # Builds a :cast op on +input+ with +dtype+ as the op's :data_type.
    def cast(input, dtype, name: nil)
      op(:cast, input, nil, data_type: dtype, name: name)
    end

    # Builds a :print op with +input+ and +data+ as operands.
    def print(input, data, message: nil, name: nil)
      op(:print, input, data, message: message, name: name)
    end

    # Builds a :negate op on +a+, forwarding +options+ untouched.
    def negate(a, options = {})
      op(:negate, a, nil, options)
    end

    # Builds an :equal comparison op on +a+ and +b+.
    def equal(a, b, name: nil)
      op(:equal, a, b, name: name)
    end

    # Builds a :not_equal comparison op on +a+ and +b+.
    def not_equal(a, b, name: nil)
      op(:not_equal, a, b, name: name)
    end

    # Builds a :zeros_like op on +tensor+.
    def zeros_like(tensor, dtype: nil, name: nil)
      op(:zeros_like, tensor, nil, data_type: dtype, name: name)
    end

    # Builds a :ones_like op on +tensor+.
    def ones_like(tensor, dtype: nil, name: nil)
      op(:ones_like, tensor, nil, data_type: dtype, name: name)
    end

    # Builds an :identity op on +input+.
    def identity(input, name: nil)
      op(:identity, input, nil, name: name)
    end

    # Builds a :mul op on +a+ and +b+.
    def multiply(a, b, name: nil)
      op(:mul, a, b, name: name)
    end

    # Builds a :pow op with base +a+ and exponent +e+ as operands.
    def pow(a, e, name: nil)
      op(:pow, a, e, name: name)
    end

    # Builds an :abs op on +x+.
    def abs(x, name: nil)
      op(:abs, x, nil, name: name)
    end

    # Builds a :sign op on +x+.
    def sign(x, name: nil)
      op(:sign, x, nil, name: name)
    end

    # Builds a :sin op on +a+ (checked against FLOATING_POINT_TYPES).
    # :data_type defaults to :float32; the caller's hash is copied, not mutated.
    def sin(a, options = {})
      options = options.merge(data_type: options[:data_type] || :float32)
      check_allowed_types(a, FLOATING_POINT_TYPES)
      op(:sin, a, nil, options)
    end

    # Builds a :cos op on +a+ (checked against FLOATING_POINT_TYPES).
    # :data_type defaults to :float32; the caller's hash is copied, not mutated.
    def cos(a, options = {})
      options = options.merge(data_type: options[:data_type] || :float32)
      check_allowed_types(a, FLOATING_POINT_TYPES)
      op(:cos, a, nil, options)
    end

    # Builds a :tan op on +a+ (checked against FLOATING_POINT_TYPES).
    # :data_type defaults to :float32; the caller's hash is copied, not mutated.
    def tan(a, options = {})
      options = options.merge(data_type: options[:data_type] || :float32)
      check_allowed_types(a, FLOATING_POINT_TYPES)
      op(:tan, a, nil, options)
    end

    # Builds a :tanh op on +a+ (checked against FLOATING_POINT_TYPES).
    # :data_type defaults to :float32; the caller's hash is copied, not mutated.
    def tanh(a, options = {})
      options = options.merge(data_type: options[:data_type] || :float32)
      check_allowed_types(a, FLOATING_POINT_TYPES)
      op(:tanh, a, nil, options)
    end

    # Builds a :sqrt op on +a+ (checked against FLOATING_POINT_TYPES).
    # The op's :data_type is read from +a+, so +a+ must respond to #data_type.
    def sqrt(a, name: nil)
      options = {
        data_type: a.data_type,
        name: name
      }
      check_allowed_types(a, FLOATING_POINT_TYPES)
      op(:sqrt, a, nil, options)
    end

    # Builds a :log op on +input+ (checked against FLOATING_POINT_TYPES).
    # :data_type defaults to :float32; the caller's hash is copied, not mutated.
    def log(input, options = {})
      options = options.merge(data_type: options[:data_type] || :float32)
      check_allowed_types(input, FLOATING_POINT_TYPES)
      op(:log, input, nil, options)
    end

    # Builds an :exp op on +a+ (checked against FLOATING_POINT_TYPES).
    # :data_type defaults to :float32; the caller's hash is copied, not mutated.
    def exp(a, options = {})
      options = options.merge(data_type: options[:data_type] || :float32)
      check_allowed_types(a, FLOATING_POINT_TYPES)
      op(:exp, a, nil, options)
    end

    # Builds a :matmul op on +input_a+ and +input_b+; the transpose flags are options.
    def matmul(input_a, input_b, transpose_a: false,
               transpose_b: false,
               name: nil)
      op(:matmul, input_a, input_b, transpose_a: transpose_a, transpose_b: transpose_b, name: name)
    end

    # Builds a :transpose op on +tensor+ with +perm+ as an option.
    def transpose(tensor, perm: nil, name: 'transpose')
      op(:transpose, tensor, nil, perm: perm, name: name)
    end

    # Builds a :pad op on +tensor+ with +paddings+ and +mode+ as options.
    def pad(tensor, paddings, mode: 'CONSTANT', name: nil)
      op(:pad, tensor, nil, paddings: paddings, mode: mode, name: name)
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
module TensorStream
  # A Tensor with no value of its own and with is_const set to false.
  # The constructor sets its state directly; it does not call Tensor#initialize.
  class Placeholder < Tensor
    # @param data_type stored as-is in @data_type
    # @param rank stored as-is and used in the generated name
    # @param shape wrapped, together with rank, in a TensorShape
    # @param options [Hash] :graph overrides the default graph, :name overrides the generated name
    def initialize(data_type, rank, shape, options = {})
      @value = nil
      @is_const = false
      @data_type = data_type
      @rank = rank
      @shape = TensorShape.new(shape, rank)
      # caller_locations is taken here, in the constructor itself, so the trace starts at this frame
      @source = set_source(caller_locations)
      @graph = options[:graph] || TensorStream.get_default_graph
      # build_name is only invoked (and the placeholder counter only advanced) when no name is given
      @name = options[:name] || build_name
      @graph.add_node(self)
    end

    private

    # Generated name: "Placeholder", the counter suffix from Tensor.placeholder_name, ":" and the rank.
    def build_name
      ['Placeholder', Tensor.placeholder_name, ':', @rank].join
    end
  end
end
|
@@ -0,0 +1,66 @@
|
|
1
|
+
module TensorStream
  # Evaluates tensors through an evaluator class and remembers the context
  # hash of the most recent run so it can be inspected afterwards.
  class Session
    # Context hash used by the most recent #run, or nil if #run was never called.
    attr_reader :last_session_context

    # @param evaluator [Symbol, String] snake_case name resolved to a class under TensorStream::Evaluator
    # @param thread_pool_class [Class] instantiated once with no arguments and handed to every evaluator
    def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
      @evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
      @thread_pool = thread_pool_class.new
    end

    # Lazily created session shared through the Session class.
    def self.default_session
      @session ||= Session.new
    end

    # Evaluates every positional argument with a fresh evaluator instance.
    #
    # A trailing Hash is treated as options:
    #   :feed_dict - Hash whose Placeholder keys are copied into the context under
    #                their name as a Symbol; keys that are not Placeholders are ignored
    #   :retain    - passed to the evaluator through the context
    #
    # @return the single result when one tensor was given, otherwise an Array of results
    def run(*args)
      options = args.last.is_a?(Hash) ? args.pop : {}
      context = {}

      # scan for placeholders and assign value
      (options[:feed_dict] || {}).each do |key, value|
        context[key.name.to_sym] = value if key.is_a?(Placeholder)
      end

      # merge! returns the same hash, so the evaluator and last_session_context share one object
      evaluator = @evaluator_class.new(self, context.merge!(retain: options[:retain]), thread_pool: @thread_pool)

      execution_context = {}
      result = args.collect { |e| evaluator.run(e, execution_context) }
      @last_session_context = context
      result.size == 1 ? result.first : result
    end

    # Describes the graph nodes of +tensor+ that are Tensors flagged internal?.
    def dump_internal_ops(tensor)
      dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && n.internal? })
    end

    # Describes the graph nodes of +tensor+ that are Tensors not flagged internal?.
    def dump_user_ops(tensor)
      dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && !n.internal? })
    end

    # Returns one "key math = value" line per selected graph node that has a truthy
    # entry (looked up by node.name) in the last session context.
    # Returns an empty Array when no run has happened yet instead of failing on nil.
    def dump_ops(tensor, selector)
      return [] if @last_session_context.nil?

      graph = tensor.graph
      graph.nodes.select { |k, v| selector.call(k, v) }.collect do |k, node|
        next unless @last_session_context[node.name]

        "#{k} #{node.to_math(true, 1)} = #{@last_session_context[node.name]}"
      end.compact
    end

    private

    # Converts a snake_case / path-like string to CamelCase ("a_b/c" => "AB::C").
    def camelize(string, uppercase_first_letter = true)
      string = if uppercase_first_letter
                 string.sub(/^[a-z\d]*/) { Regexp.last_match(0).capitalize }
               else
                 string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { Regexp.last_match(0).downcase }
               end
      string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{Regexp.last_match(1)}#{Regexp.last_match(2).capitalize}" }.gsub('/', '::')
    end
  end
end
|
@@ -0,0 +1,317 @@
|
|
1
|
+
require 'ostruct'
|
2
|
+
|
3
|
+
module TensorStream
  # Base graph node: holds a data type, rank, shape and an optional value,
  # registers itself with a graph, and overloads operators so that expressions
  # on tensors build operation nodes.
  class Tensor
    include OpHelper

    attr_accessor :name, :data_type, :shape, :rank, :native_buffer, :is_const, :value, :breakpoint, :internal, :source, :given_name, :graph

    # Suffix for generated constant names: "" for the first call, then "_1", "_2", ...
    def self.const_name
      @const_counter ||= 0
      suffix = @const_counter.zero? ? '' : "_#{@const_counter}"
      @const_counter += 1
      suffix
    end

    # Suffix for generated variable names: "" for the first call, then "_2", "_3", ...
    def self.var_name
      @var_counter ||= 0
      @var_counter += 1
      @var_counter == 1 ? '' : "_#{@var_counter}"
    end

    # Suffix for generated placeholder names: "" for the first call, then "_2", "_3", ...
    def self.placeholder_name
      @placeholder_counter ||= 0
      @placeholder_counter += 1
      @placeholder_counter == 1 ? '' : "_#{@placeholder_counter}"
    end

    # @param data_type data type symbol stored in @data_type
    # @param rank rank stored in @rank and used in generated names
    # @param shape array of dimensions, wrapped in a TensorShape
    # @param options [Hash] recognised keys: :const, :internal, :graph, :name, :value
    def initialize(data_type, rank, shape, options = {})
      @data_type = data_type
      @rank = rank
      @breakpoint = false
      @shape = TensorShape.new(shape, rank)
      @value = nil
      @source = set_source(caller_locations)
      @is_const = options[:const] || false
      @internal = options[:internal]
      @graph = options[:graph] || TensorStream.get_default_graph
      @name = options[:name] || build_name
      @given_name = @name

      # nil means "no value"; an explicit false is a real value and must be kept.
      # A local copy is used so the caller's options hash is never modified.
      value = options[:value]
      unless value.nil?
        if value.kind_of?(Array)
          # check if single dimension array is passed
          if shape.size >= 2 && value.size > 0 && !value[0].kind_of?(Array)
            value = reshape(value, shape.reverse.dup)
          end

          # NOTE(review): non-Tensor elements are stored without casting to data_type — confirm intended
          @value = value.collect do |v|
            v.kind_of?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
          end
        elsif shape.size > 0
          # a scalar given for a non-scalar shape is repeated to fill the shape
          @value = reshape(Tensor.cast_dtype(value, @data_type), shape.dup)
        else
          @value = Tensor.cast_dtype(value, @data_type)
        end
      end

      @graph.add_node(self)
    end

    # True when the tensor was created with a truthy :internal option.
    def internal?
      !!@internal
    end

    # Alias-style reader for the data type.
    def dtype
      @data_type
    end

    # Resets the counters behind const_name, var_name and placeholder_name.
    def self.reset_counters
      @const_counter = 0
      @var_counter = 0
      @placeholder_counter = 0
    end

    # Allocates @native_buffer as an NArray.sfloat; only float32 tensors of rank 2 or 0
    # are handled, anything else raises.
    def build_buffer
      @native_buffer = if @data_type == :float32 && @rank == 2
                         NArray.sfloat(@shape.cols * @shape.rows)
                       elsif @data_type == :float32 && @rank == 0
                         NArray.sfloat(1)
                       else
                         raise "Invalid data type #{@data_type}"
                       end
    end

    # Arithmetic operators build an Operation node; a non-tensor operand is wrapped first.
    def +(operand)
      TensorStream::Operation.new(:add, self, auto_wrap(operand))
    end

    # Builds an :index operation; +index+ is passed through unwrapped.
    def [](index)
      TensorStream::Operation.new(:index, self, index)
    end

    def *(operand)
      TensorStream::Operation.new(:mul, self, auto_wrap(operand))
    end

    def **(operand)
      TensorStream::Operation.new(:pow, self, auto_wrap(operand))
    end

    def /(operand)
      TensorStream::Operation.new(:div, self, auto_wrap(operand))
    end

    def -(operand)
      TensorStream::Operation.new(:sub, self, auto_wrap(operand))
    end

    def -@
      TensorStream::Operation.new(:negate, self, nil)
    end

    # Comparison operators build an op instead of returning a boolean,
    # so `a == b` on tensors yields whatever `op` returns.
    def ==(operand)
      op(:equal, self, operand)
    end

    def <(operand)
      op(:less, self, operand)
    end

    def !=(operand)
      op(:not_equal, self, operand)
    end

    def <=(operand)
      op(:less_equal, self, operand)
    end

    def >(operand)
      op(:greater, self, operand)
    end

    def >=(operand)
      op(:greater_equal, self, operand)
    end

    # Delegates to #collect on the stored value.
    def collect(&block)
      @value.collect(&block)
    end

    # The tensor's name.
    def to_s
      @name
    end

    # def to_ary
    #   if rank == 2
    #     @native_buffer.to_a.each_slice(shape.cols).collect { |slice| slice }
    #   else
    #     raise "Invalid rank"
    #   end
    # end

    # open cl methods

    # Creates (once) a buffer on +context+ sized from @native_buffer and initialised from it.
    def open_cl_buffer(context)
      @cl_buffer ||= context.create_buffer(@native_buffer.size * @native_buffer.element_size, :flags => OpenCL::Mem::COPY_HOST_PTR, :host_ptr => @native_buffer)
    end

    # Enqueues a read of @cl_buffer into @native_buffer on +queue+, waiting on +events+.
    def sync_cl_buffer(queue, events = [])
      queue.enqueue_read_buffer(@cl_buffer, @native_buffer, :event_wait_list => events)
    end

    # Runs this tensor on the default session, passing +options+ along.
    def eval(options = {})
      Session.default_session.run(self, options)
    end

    # Hash description of the tensor; nested tensors inside the value are converted too.
    def to_h
      {
        name: @name,
        value: hashify_tensor(@value),
        dtype: @data_type,
        shape: @shape,
        const: !!is_const
      }
    end

    # to_i, to_a and to_f all return the stored value unchanged (no conversion is done).
    def to_i
      @value
    end

    def to_a
      @value
    end

    def to_f
      @value
    end

    # Builds an :index op selecting element 0.
    def first
      op(:index, self, 0)
    end

    # Textual form: the name when name_only is set, the depth is exhausted or there is
    # no value; for array values each element (tensors recursively); otherwise the value
    # for constants and the name for everything else.
    def to_math(name_only = false, max_depth = 99)
      return @name if max_depth == 0 || name_only || @value.nil?

      if @value.kind_of?(Array)
        @value.collect { |v| v.kind_of?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
      else
        is_const ? @value : @name
      end
    end

    # to_math for tensors, identity for anything else.
    def auto_math(tensor, name_only = false, max_depth = 99)
      tensor.kind_of?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
    end

    # Maps a Ruby value to a data type symbol; anything unrecognised is :float32.
    def self.detect_type(value)
      case value
      when String then :string
      when Float then :float32
      when Integer then :int32
      when Array then :array
      else :float32
      end
    end

    # Converts +val+ to the Ruby representation of +dtype+.
    # Tensors are returned untouched and arrays are converted element by element.
    # true/false become 1/0 (or 1.0/0.0) for numeric types.
    #
    # @raise [RuntimeError] when +dtype+ is not a known data type
    def self.cast_dtype(val, dtype)
      return val if val.is_a?(Tensor)

      if val.is_a?(Array)
        return val.collect do |v|
          cast_dtype(v, dtype)
        end
      end

      boolean = val.equal?(true) || val.equal?(false)

      case dtype.to_sym
      when :float64, :float32, :float
        # float64 is accepted because the ops module lists it as a supported floating point type
        boolean ? (val ? 1.0 : 0.0) : val.to_f
      when :string
        val.to_s
      when :int64, :int32, :int16
        # int64 is accepted because the ops module lists it as a supported numeric type
        boolean ? (val ? 1 : 0) : val.to_i
      when :boolean
        !!val
      when :unknown
        val
      else
        raise "unknown data_type #{dtype} passed"
      end
    end

    # Stores +block+ as the tensor's breakpoint and returns self for chaining.
    def breakpoint!(&block)
      @breakpoint = block
      self
    end

    protected

    # First entry of +trace+ whose text does not contain the lib/tensor_stream path.
    def set_source(trace)
      trace.find { |c| !c.to_s.include?(File.join("lib", "tensor_stream")) }
    end

    # Recursively converts tensors (also inside arrays) to hashes; other values pass through.
    def hashify_tensor(tensor)
      if tensor.kind_of?(Tensor)
        tensor.to_h
      elsif tensor.kind_of?(Array)
        tensor.collect { |t| hashify_tensor(t) }
      else
        tensor
      end
    end

    # For an array: slices it by the first entry of +shape+ and recurses on each slice.
    # For a scalar: repeats it to fill +shape+.
    # +shape+ is consumed with #shift, so callers pass a copy.
    # NOTE(review): in the array branch the same shape array is shared by all sibling
    # slices; this looks right for two dimensions only — confirm for higher ranks.
    def reshape(arr, shape)
      if arr.is_a?(Array)
        return arr if shape.size < 2

        slice = shape.shift
        arr.each_slice(slice).collect do |s|
          reshape(s, shape)
        end
      else
        return arr if shape.size < 1

        slice = shape.shift
        return arr if slice.nil?

        slice.times.collect do |_s|
          reshape(arr, shape.dup)
        end
      end
    end

    # Turns an operand into a tensor: Procs are called and their result wrapped,
    # tensors pass through, anything else becomes a constant via i_cons using this
    # tensor's data type (or the detected type when this tensor has none).
    def auto_wrap(operand)
      return auto_wrap(operand.call) if operand.is_a?(Proc)

      if !operand.is_a?(Tensor)
        i_cons(operand, dtype: @data_type || Tensor.detect_type(operand))
      else
        operand
      end
    end

    # Generated name: "Const<suffix>:<rank>" for constants, "Variable<suffix>:<rank>" otherwise.
    def build_name
      if @is_const
        "Const#{Tensor.const_name}:#{@rank}"
      else
        "Variable#{Tensor.var_name}:#{@rank}"
      end
    end
  end
end
|