tensor_stream 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rake_tasks~ +0 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +21 -0
- data/README.md +123 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/tensor_stream.rb +138 -0
- data/lib/tensor_stream/control_flow.rb +23 -0
- data/lib/tensor_stream/evaluator/evaluator.rb +7 -0
- data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +32 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +749 -0
- data/lib/tensor_stream/graph.rb +98 -0
- data/lib/tensor_stream/graph_keys.rb +5 -0
- data/lib/tensor_stream/helpers/op_helper.rb +58 -0
- data/lib/tensor_stream/math_gradients.rb +161 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +0 -0
- data/lib/tensor_stream/nn/nn_ops.rb +17 -0
- data/lib/tensor_stream/operation.rb +195 -0
- data/lib/tensor_stream/ops.rb +225 -0
- data/lib/tensor_stream/placeholder.rb +21 -0
- data/lib/tensor_stream/session.rb +66 -0
- data/lib/tensor_stream/tensor.rb +317 -0
- data/lib/tensor_stream/tensor_shape.rb +25 -0
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +23 -0
- data/lib/tensor_stream/train/saver.rb +61 -0
- data/lib/tensor_stream/trainer.rb +7 -0
- data/lib/tensor_stream/types.rb +17 -0
- data/lib/tensor_stream/variable.rb +52 -0
- data/lib/tensor_stream/version.rb +7 -0
- data/samples/iris.data +150 -0
- data/samples/iris.rb +117 -0
- data/samples/linear_regression.rb +55 -0
- data/samples/raw_neural_net_sample.rb +54 -0
- data/tensor_stream.gemspec +40 -0
- metadata +185 -0
@@ -0,0 +1,32 @@
|
|
1
|
+
# http://creativecommons.org/publicdomain/zero/1.0/
|
2
|
+
# Source of normally distributed pseudo-random numbers.
#
# Each call to the class-level generator yields two samples at once; the
# instance hands out the first immediately and keeps the second for the
# following call, so the uniform source is only consulted every other draw.
class RandomGaussian
  # mean        - centre of the distribution
  # stddev      - standard deviation of the distribution
  # rand_helper - callable returning a uniform sample; defaults to Kernel.rand
  def initialize(mean, stddev, rand_helper = lambda { Kernel.rand })
    @rand_helper = rand_helper
    @mean = mean
    @stddev = stddev
    @valid = false # true while @next holds an unused sample
    @next = 0
  end

  # Returns the next sample, serving the cached one when available.
  def rand
    if @valid
      @valid = false
      return @next
    end

    @valid = true
    first, @next = self.class.gaussian(@mean, @stddev, @rand_helper)
    first
  end

  # Produces a pair of samples from two calls to +rand+ (a callable).
  # The first call supplies the angle, the second the radius.
  # Left public: #rand reaches it through self.class.
  def self.gaussian(mean, stddev, rand)
    angle = 2 * Math::PI * rand.call
    # 1 - u keeps the argument of log strictly positive when u may be 0
    radius = stddev * Math.sqrt(-2 * Math.log(1 - rand.call))
    [mean + radius * Math.cos(angle), mean + radius * Math.sin(angle)]
  end
end
|
@@ -0,0 +1,749 @@
|
|
1
|
+
require "tensor_stream/evaluator/operation_helpers/random_gaussian"
|
2
|
+
require 'tensor_stream/math_gradients'
|
3
|
+
|
4
|
+
module TensorStream
|
5
|
+
module Evaluator
|
6
|
+
# Signals that an operand was still a Tensor after evaluation, so the
# operation could not be reduced to plain Ruby values.
class FullEvalNotPossible < RuntimeError; end
|
8
|
+
|
9
|
+
# Errors during graph evaluation
|
10
|
+
# Wraps an error raised while evaluating a tensor, keeping both the
# original exception and the tensor that was being evaluated.
class EvaluatorExcecutionException < RuntimeError
  # tensor            - the tensor under evaluation when the error occurred
  # wrapped_exception - the original exception
  attr_reader :tensor, :wrapped_exception

  def initialize(exception, tensor)
    @wrapped_exception = exception
    @tensor = tensor
  end
end
|
22
|
+
|
23
|
+
## PURE ruby evaluator used for testing and development
|
24
|
+
class RubyEvaluator
|
25
|
+
attr_accessor :retain
|
26
|
+
|
27
|
+
include TensorStream::OpHelper
|
28
|
+
|
29
|
+
# session     - stored as-is in @session
# context     - hash-like store used for cached results and placeholder
#               values; its :retain entry (default []) lists tensors that
#               must be returned unevaluated
# thread_pool - executor used for :flow_group; falls back to
#               Concurrent::ImmediateExecutor when not given
def initialize(session, context, thread_pool: nil)
  @session, @context = session, context
  @retain = context[:retain] || []
  @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
end
|
35
|
+
|
36
|
+
# Evaluates +tensor+ one step.
#
# Arrays are evaluated element by element. Anything listed in +retain+ is
# handed back untouched. Otherwise the node is dispatched by class to the
# matching eval_* / resolve_placeholder method using a copy of
# +execution_context+, and the copy's :returns entry is merged back into
# the caller's context afterwards.
def run(tensor, execution_context)
  return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array)

  # retained tensors are returned as objects, not as values
  return tensor if retain.include?(tensor)

  child_context = execution_context.dup
  res = case tensor
        when Operation   then eval_operation(tensor, child_context)
        when Variable    then eval_variable(tensor, child_context)
        when Placeholder then resolve_placeholder(tensor, child_context)
        else                  eval_tensor(tensor, child_context)
        end
  execution_context.deep_merge!(returns: child_context[:returns])
  res
end
|
54
|
+
|
55
|
+
# Keeps calling #run until the result is no longer a Tensor, or until a
# step returns the very same object it was given (no further progress).
# When a step yields a non-empty array whose first element is a Tensor,
# each element is completely evaluated as well.
def complete_eval(tensor, context)
  current = tensor
  Kernel.loop do
    previous = current
    current = run(previous, context)

    if current.is_a?(Array) && !current.empty? && current[0].is_a?(Tensor)
      current = current.map { |item| complete_eval(item, context) }
    end

    return current if previous.equal?(current) || !current.is_a?(Tensor)
  end
end
|
68
|
+
|
69
|
+
protected
|
70
|
+
|
71
|
+
# Evaluates a variable's current value and records it under
# child_context[:returns][:vars] as { name:, val: }.
#
# Raises a RuntimeError when the variable has no value yet.
def eval_variable(tensor, child_context)
  raise "variable #{tensor.name} not initalized" if tensor.value.nil?

  val = eval_tensor(tensor.value, child_context)
  returns = (child_context[:returns] ||= {})
  (returns[:vars] ||= []) << { name: tensor.name, val: val }
  val
end
|
79
|
+
|
80
|
+
# Evaluates one operation node and caches the result in @context under the
# node's name (a cached entry short-circuits the whole method).
#
# The first two entries of tensor.items, when present, are passed through
# resolve_placeholder and are available to every branch as +a+ and +b+.
# After the branch has run, tensor.breakpoint (if set) is called with the
# node, the fully evaluated operands and the fully evaluated result.
#
# Note: the :reshape branch may return early for the single-element /
# empty-shape case; that path skips the breakpoint and the cache write.
#
# Raises EvaluatorExcecutionException wrapping any StandardError raised
# while evaluating; an EvaluatorExcecutionException coming from a nested
# evaluation is passed through unchanged. Raises a RuntimeError (wrapped as
# above) for an unknown operation.
def eval_operation(tensor, child_context)
  return @context[tensor.name] if @context.key?(tensor.name)

  a = resolve_placeholder(tensor.items[0], child_context) if tensor.items && tensor.items[0]
  b = resolve_placeholder(tensor.items[1], child_context) if tensor.items && tensor.items[1]

  case tensor.operation
  when :argmax
    a = complete_eval(a, child_context)
    axis = tensor.options[:axis] || 0

    get_max_with_axis(a, axis, 0, tensor.data_type)
  when :cast
    a = complete_eval(a, child_context)

    call_op(:cast, a, child_context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
  when :sign
    a = complete_eval(a, child_context)

    func = lambda { |x, _b|
      if x.zero? || (x.is_a?(Float) && x.nan?)
        0
      elsif x < 0
        -1
      elsif x > 0
        1
      else
        fail 'assert: cannot be here'
      end
    }

    call_op(:sign, a, child_context, func)
  when :equal
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    # the op name is what gets rebuilt symbolically when full evaluation is
    # not possible, so it has to match the operation being evaluated
    call_vector_op(:equal, a, b, child_context, ->(t, u) { t == u })
  when :not_equal
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:not_equal, a, b, child_context, ->(t, u) { t != u })
  when :index
    f = run(a, child_context)
    index = run(b, child_context)

    f[index]
  when :slice
    input = complete_eval(a, child_context)
    start = complete_eval(b, child_context)
    size = complete_eval(tensor.options[:size], child_context)
    fail "start index and size not of the same shape #{start.size} != #{size.size}" if start.size != size.size
    slice_tensor(input, start, size)
  when :negate
    call_vector_op(:negate, a, nil, child_context, ->(t, _u) { -t })
  when :add
    call_vector_op(:add, a, b, child_context, ->(t, u) { t + u })
  when :sub
    call_vector_op(:sub, a, b, child_context, ->(t, u) { t - u })
  when :mul
    call_vector_op(:mul, a, b, child_context, ->(t, u) { t * u })
  when :pow
    call_vector_op(:pow, a, b, child_context, ->(t, u) { t**u })
  when :concat
    values = complete_eval(a, child_context)
    concat_array(values, tensor.options[:axis])
  when :abs
    call_op(:abs, a, child_context, ->(t, _b) { t.abs })
  when :tanh
    call_op(:tanh, a, child_context, ->(t, _b) { Math.tanh(t) })
  when :tan
    call_op(:tan, a, child_context, ->(t, _b) { Math.tan(t) })
  when :sec
    # Ruby's Math module has no sec; use the identity sec(t) = 1 / cos(t)
    call_op(:sec, a, child_context, ->(t, _b) { 1.0 / Math.cos(t) })
  when :sin
    call_op(:sin, a, child_context, ->(t, _b) { Math.sin(t) })
  when :cos
    call_op(:cos, a, child_context, ->(t, _b) { Math.cos(t) })
  when :log
    # Math.log raises for negative input; yield NaN instead
    call_op(:log, a, child_context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
  when :exp
    call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
  when :sqrt
    call_op(:sqrt, a, child_context, ->(t, _b) { Math.sqrt(t) })
  when :square
    call_op(:square, a, child_context, ->(t, _b) { t * t })
  when :stop_gradient
    run(a, child_context)
  when :random_uniform
    maxval = tensor.options.fetch(:maxval, 1)
    minval = tensor.options.fetch(:minval, 0)

    generator = -> { rand * (maxval - minval) + minval }
    generate_vector(tensor.options[:shape], generator: generator)
  when :random_normal
    r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev))
    generator = -> { r.rand }

    generate_vector(tensor.options[:shape], generator: generator)
  when :flow_group
    threads = tensor.items.collect { |item| Concurrent::Future.execute(executor: @thread_pool) { run(item, child_context) } }
    threads.collect(&:value)
  when :assign
    assign = tensor.items[0] || tensor
    assign.value = complete_eval(tensor.items[1], child_context)
    assign.value
  when :assign_add
    tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t + u })
    tensor.items[0].value
  when :assign_sub
    tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t - u })
    tensor.items[0].value
  when :reduce_mean
    c = tensor.data_type == :float ? 0.0 : 0 # result for an empty input
    func = lambda { |v|
      if v.is_a?(Array)
        v.empty? ? c : (v.reduce(:+) / v.size)
      else
        v
      end
    }

    reduction(child_context, tensor, func)
  when :reduce_sum
    c = tensor.data_type == :float ? 0.0 : 0 # result for an empty input
    func = lambda { |v|
      if v.is_a?(Array)
        v.empty? ? c : v.reduce(:+)
      else
        v
      end
    }

    reduction(child_context, tensor, func)
  when :reduce_prod
    c = tensor.data_type == :float ? 1.0 : 1 # result for an empty input
    func = lambda { |v|
      if v.is_a?(Array)
        v.empty? ? c : v.reduce(:*)
      else
        v
      end
    }

    reduction(child_context, tensor, func)
  when :transpose
    matrix_a = complete_eval(a, child_context)
    matrix_a.transpose
  when :eye
    rows = complete_eval(a, child_context)
    columns = complete_eval(b, child_context)

    Array.new(rows) do |i|
      Array.new(columns) do |col|
        if tensor.data_type == :float32
          i == col ? 1.0 : 0.0
        else
          i == col ? 1 : 0
        end
      end
    end
  when :cond
    pred = complete_eval(tensor.options[:pred], child_context)

    if is_all_true(pred)
      complete_eval(a, child_context)
    else
      complete_eval(b, child_context)
    end
  when :where
    pred = complete_eval(tensor.options[:pred], child_context)
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_3way_vector_op(pred, a, b, child_context, ->(t, u, v) { t ? u : v })
  when :less
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:less, a, b, child_context, ->(t, u) { t < u })
  when :greater
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:greater, a, b, child_context, ->(t, u) { t > u })
  when :zeros, :ones, :zeros_like, :ones_like
    shape = if %i[zeros_like ones_like].include?(tensor.operation)
              a = complete_eval(a, child_context)
              shape_eval(a)
            else
              complete_eval(a, child_context) || tensor.shape.shape
            end

    func = if %i[zeros zeros_like].include?(tensor.operation)
             -> { tensor.data_type == :int32 ? 0 : 0.0 }
           else
             -> { tensor.data_type == :int32 ? 1 : 1.0 }
           end

    if shape.is_a?(Array) && shape.size.zero?
      func.call
    else
      shape = [shape.to_i] unless shape.is_a?(Array)
      generate_vector(shape, generator: func)
    end
  when :shape
    input = complete_eval(a, child_context)

    shape_eval(input)
  when :matmul
    matrix_a = complete_eval(a, child_context)
    matrix_b = complete_eval(b, child_context)

    rank_a = get_rank(matrix_a)
    rank_b = get_rank(matrix_b)

    raise "#{a.name} rank must be greater than 1" if rank_a < 2
    raise "#{b.name} rank must be greater than 1" if rank_b < 2

    matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
    matrix_b = matrix_b.transpose if tensor.options[:transpose_b]

    # handle matrix multiplication with constants like 1 or 0
    matrix_a = matmul_const_transform(matrix_a, matrix_b, tensor)
    matrix_b = matmul_const_transform(matrix_b, matrix_a, tensor)

    # check matrix dimensions
    raise "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size

    (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
  when :gradients
    b.collect do |xs|
      fail "#{xs} passed is not a tensor object" unless xs.is_a?(Tensor)
      xs_val = complete_eval(xs, child_context)
      target_shape = shape_eval(xs_val)

      stops = tensor.options[:stop_gradients] ? tensor.options[:stop_gradients].map(&:name).join('_') : ''
      gradient_program_name = "grad_#{tensor.name}_#{xs.name}_#{stops}".to_sym

      # the derivative graph is built once per (tensor, xs, stops) and then
      # looked up by name on later evaluations
      tensor_program = if tensor.graph.node_added?(gradient_program_name)
                         tensor.graph.get_node(gradient_program_name)
                       else
                         derivative_ops = TensorStream::MathGradients.derivative(a, xs, graph: tensor.graph, stop_gradients: tensor.options[:stop_gradients], target_shape: target_shape)
                         unit_matrix = op(:ones_like, xs)
                         tensor.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
                       end

      complete_eval(tensor_program, child_context)
    end
  when :identity
    complete_eval(a, child_context)
  when :print
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)
    puts "#{tensor.options.fetch(:message, '')} #{b}"
    a
  when :rank
    a = complete_eval(a, child_context)
    get_rank(a)
  when :div
    process_vector_math_op(a, b, child_context, ->(t, u) { t / u })
  when :reshape
    arr = complete_eval(a, child_context)
    new_shape = complete_eval(b, child_context)

    flat_arr = arr.flatten
    return flat_arr[0] if new_shape.size == 0 && flat_arr.size == 1

    new_shape = fix_inferred_elements(new_shape, flat_arr.size)

    reshape(flat_arr, new_shape)
  when :pad
    a = complete_eval(a, child_context)
    paddings = complete_eval(tensor.options[:paddings], child_context)

    arr_pad(a, paddings, tensor.data_type)
  when :max
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:max, a, b, child_context, ->(t, u) { [t, u].max })
  else
    fail "unknown op #{tensor.operation}"
  end.tap do |result|
    if tensor.breakpoint
      a = complete_eval(a, child_context)
      b = complete_eval(b, child_context)

      tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
    end
    @context[tensor.name] = result
  end
rescue EvaluatorExcecutionException
  # already wrapped by a nested evaluation; pass it through untouched
  raise
rescue StandardError => e
  puts e.backtrace.join("\n")
  raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math} defined at #{tensor.source}"
end
|
386
|
+
|
387
|
+
# Resolves a plain tensor to its value and caches it in @context by name.
#
# Non-Tensor arguments are returned unchanged. A cached entry for the
# tensor's name is returned directly. Otherwise the tensor's value is
# used; any Tensor found as the value, or as a direct element of an array
# value, is evaluated with #run.
def eval_tensor(tensor, child_context)
  return tensor unless tensor.is_a?(Tensor)
  return @context[tensor.name] if @context.key?(tensor.name)

  resolve = ->(entry) { entry.is_a?(Tensor) ? run(entry, child_context) : entry }
  value = tensor.value
  result = value.is_a?(Array) ? value.map(&resolve) : resolve.call(value)
  @context[tensor.name] = result
end
|
401
|
+
|
402
|
+
private
|
403
|
+
|
404
|
+
# Returns the index (or indices) of the largest element along +target_axis+.
#
# Recurses into +a+ until +current_axis+ reaches +target_axis+. At that
# level, a nested array yields one index per column (the row holding the
# column's largest entry); a flat array yields a single index. Ties go to
# the earliest position. Every index is passed through Tensor.cast_dtype
# with +output_type+.
def get_max_with_axis(a, target_axis, current_axis, output_type)
  unless target_axis == current_axis
    return a.collect { |row| get_max_with_axis(row, target_axis, current_axis + 1, output_type) }
  end

  # position of the first strictly greatest entry of a flat list
  first_max_index = lambda do |values|
    best = nil
    best_at = 0
    values.each_with_index do |value, position|
      next unless best.nil? || value > best

      best = value
      best_at = position
    end
    best_at
  end

  if a[0].is_a?(Array)
    (0...a[0].size).collect do |column_index|
      column = a.map { |row| row[column_index] }
      Tensor.cast_dtype(first_max_index.call(column), output_type)
    end
  else
    Tensor.cast_dtype(first_max_index.call(a), output_type)
  end
end
|
436
|
+
|
437
|
+
# Applies the reducer +func+ to the first item of +tensor+.
#
# Reads :axis and :keepdims from tensor.options. A single axis (or nil) is
# delegated to reduce_axis. An array of axes is applied one after another
# through reduce_axis, and +func+ is then called on the flattened outcome.
def reduction(child_context, tensor, func)
  val = complete_eval(tensor.items[0], child_context)
  axis = tensor.options[:axis]
  keep_dims = tensor.options[:keepdims]

  return reduce_axis(axis, val, keep_dims, child_context, func) unless axis.is_a?(Array)

  reduced = axis.inject(val) { |acc, x| reduce_axis(x, acc, keep_dims, child_context, func) }
  func.call(reduced.flatten)
end
|
453
|
+
|
454
|
+
# Pads a (possibly nested) array with zeros.
#
# paddings[rank] is a [before, after] pair giving how many zero entries to
# add on each side at nesting level +rank+. The zero is 0.0 when
# +data_type+ is :float32 and 0 otherwise. Nested levels are padded first;
# the filler rows for this level are zeroed copies of the first padded
# child.
#
# Raises a RuntimeError when paddings[rank] does not have two elements.
def arr_pad(arr, paddings, data_type = :float32, rank = 0)
  fail "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2

  lead = paddings[rank][0]
  trail = paddings[rank][1]
  zero = ->() { data_type == :float32 ? 0.0 : 0 }

  unless arr[0].is_a?(Array)
    return lead.times.map { zero.call } + arr + trail.times.map { zero.call }
  end

  padded_children = arr.collect { |child| arr_pad(child, paddings, data_type, rank + 1) }
  filler = deep_dup_array(padded_children[0], zero.call)
  lead.times.map { filler } + padded_children + trail.times.map { filler }
end
|
468
|
+
|
469
|
+
# Returns a new nested array with the same structure as +arr+.
# Leaves are kept as they are unless +value+ is given (non-nil), in which
# case every leaf is replaced by +value+.
def deep_dup_array(arr, value = nil)
  return (value.nil? ? arr : value) unless arr.is_a?(Array)

  arr.dup.map { |element| deep_dup_array(element, value) }
end
|
478
|
+
|
479
|
+
# Extracts a sub-block from a nested array.
#
# input - nested array to slice
# start - array with the first index to take at each nesting level
# size  - array with the number of elements to take at each nesting level
#
# The first entries of +start+ and +size+ apply to +input+ itself, the
# remaining entries to each nested element. Neither +start+ nor +size+ is
# modified (the arrays handed in by the caller may be cached values).
def slice_tensor(input, start, size)
  start_index = start[0]
  dimen_size = start_index + size[0]
  inner_start = start.drop(1)
  inner_size = size.drop(1)

  input[start_index...dimen_size].collect do |item|
    item.is_a?(Array) ? slice_tensor(item, inner_start, inner_size) : item
  end
end
|
491
|
+
|
492
|
+
# Lets a non-array operand take part in a matrix multiplication.
#
# Arrays are returned unchanged. Any other +mat+ is expanded into an array
# whose shape is the reversed shape of +mat_b+, every entry being +mat+
# converted with to_i when tensor.data_type is :int32 and to_f otherwise.
def matmul_const_transform(mat, mat_b, tensor)
  return mat if mat.is_a?(Array)

  fill = -> { tensor.data_type == :int32 ? mat.to_i : mat.to_f }
  generate_vector(shape_eval(mat_b).reverse, generator: fill)
end
|
502
|
+
|
503
|
+
# Replaces every -1 in +shape+ with the size implied by +total_size+.
#
# The implied size is +total_size+ divided (integer division) by the
# product of the positive entries of +shape+. An empty shape is returned
# as is.
def fix_inferred_elements(shape, total_size)
  return shape if shape.empty?

  known_size = shape.select { |dim| dim > 0 }.reduce(1) { |product, dim| product * dim }
  inferred_size = total_size / known_size
  shape.map { |dim| dim == -1 ? inferred_size : dim }
end
|
510
|
+
|
511
|
+
# Regroups the flat array +arr+ into nested arrays of shape +new_shape+.
#
# An empty +new_shape+ returns +arr+ unchanged. +new_shape+ itself is not
# modified.
#
# Raises a RuntimeError when the number of elements does not fit the
# requested shape: either the innermost length differs, or the elements
# cannot be split evenly into the requested number of groups.
def reshape(arr, new_shape)
  return arr if new_shape.empty?

  s, *remaining = new_shape

  if remaining.empty?
    fail "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s
    return arr
  end

  # an uneven split would silently produce more (or fewer) groups than asked for
  fail "reshape dimen mismatch #{arr.size} cannot be split into #{s} groups" if s.zero? || arr.size % s != 0

  dim = arr.size / s
  arr.each_slice(dim).collect do |slice|
    reshape(slice, remaining)
  end
end
|
526
|
+
|
527
|
+
# Applies the unary function +func+ to the fully evaluated operand +a+.
# When that raises FullEvalNotPossible, the operation named +op+ is built
# through TensorStream instead and returned.
def call_op(op, a, child_context, func)
  begin
    a = complete_eval(a, child_context)
    process_function_op(a, child_context, func)
  rescue FullEvalNotPossible
    TensorStream.send(op.to_sym, a)
  end
end
|
533
|
+
|
534
|
+
# Applies the binary function +func+ to operands +a+ and +b+ via
# process_vector_math_op. When that raises FullEvalNotPossible, the
# operation named +op+ is built through TensorStream instead and returned.
def call_vector_op(op, a, b, child_context, func)
  begin
    process_vector_math_op(a, b, child_context, func)
  rescue FullEvalNotPossible
    TensorStream.send(op.to_sym, a, b)
  end
end
|
539
|
+
|
540
|
+
# Evaluates both operands and combines them with the binary callable +op+.
#
# Dispatch is by rank (as reported by get_rank):
#   scalar, scalar -> op.call directly
#   scalar, array  -> constant_op with the operand order switched
#   array,  scalar -> constant_op
#   array,  array  -> vector_op
#
# Raises FullEvalNotPossible when either operand is still a Tensor after
# evaluation.
def process_vector_math_op(a, b, child_context, op)
  eval_a = a.nil? ? nil : complete_eval(a, child_context)
  eval_b = b.nil? ? nil : complete_eval(b, child_context)

  if eval_a.is_a?(Tensor) || eval_b.kind_of?(Tensor)
    fail FullEvalNotPossible.new, "full eval not possible for #{a.name}"
  end

  a_is_scalar = get_rank(eval_a) == 0
  b_is_scalar = get_rank(eval_b) == 0

  if a_is_scalar && b_is_scalar
    op.call(eval_a, eval_b)
  elsif a_is_scalar
    # iterate over the array operand but keep the scalar on the left of op
    constant_op(eval_b, eval_a, child_context, op, true)
  elsif b_is_scalar
    constant_op(eval_a, eval_b, child_context, op)
  else
    vector_op(eval_a, eval_b, child_context, op)
  end
end
|
561
|
+
|
562
|
+
# Counts the nesting depth of +value+, following the first element at each
# level and starting the count at +rank+. A non-array has depth 0; an
# empty array counts as one level.
def get_rank(value, rank = 0)
  depth = rank
  node = value
  while node.is_a?(Array)
    depth += 1
    break if node.size == 0

    node = node[0]
  end
  depth
end
|
568
|
+
|
569
|
+
# Concatenates all arrays in +values+ along +axis+, left to right.
#
# An +axis+ of -1 means the innermost axis of the first array. Returns nil
# when +values+ is empty. +values+ itself is not modified.
def concat_array(values, axis)
  first = values.first
  axis = get_rank(first) - 1 if axis == -1

  values.drop(1).reduce(first) do |combined, v|
    concat(combined, v, axis)
  end
end
|
578
|
+
|
579
|
+
# Joins +a+ and +b+ along +axis+: at axis 0 the arrays are appended,
# otherwise each element of +a+ is joined with the element of +b+ at the
# same position, one axis further in.
def concat(a, b, axis)
  return a + b if axis == 0

  a.each_with_index.map do |element, position|
    concat(element, b[position], axis - 1)
  end
end
|
588
|
+
|
589
|
+
# Applies the callable +op+ (invoked as op.call(value, 0)) to +a+.
#
# Arrays and tensors of rank > 0 are mapped element-wise through
# constant_op. Everything else is evaluated with #run and passed to +op+
# directly.
#
# Raises FullEvalNotPossible when the evaluated scalar is still a Tensor
# that is not constant.
def process_function_op(a, child_context, op)
  tensor_operand = a.kind_of?(Tensor)

  if (tensor_operand && a.shape.rank > 0) || a.kind_of?(Array)
    return constant_op(a, 0, child_context, op)
  end

  fail 'cannot be here' if tensor_operand && a.shape.rank != 0

  # ruby scalar
  v = run(a, child_context)
  fail FullEvalNotPossible.new, "full eval not possible for #{v.name}" if v.is_a?(Tensor) && !v.is_const

  op.call(v, 0)
end
|
602
|
+
|
603
|
+
# Substitutes a Placeholder with the value stored for it in @context
# (keyed by the placeholder's name as a symbol).
#
# nil and anything listed in +retain+ are returned untouched. Values that
# are not Placeholders pass through; when the argument is a Tensor the
# outcome is additionally cast with Tensor.cast_dtype to its data type.
#
# Raises a RuntimeError when a Placeholder has no value in @context.
def resolve_placeholder(placeholder, execution_context = {})
  return nil if placeholder.nil?
  return placeholder if retain.include?(placeholder)

  var = placeholder
  if placeholder.kind_of?(Placeholder)
    var = @context[placeholder.name.to_sym]
    raise "missing placeholder #{placeholder.name}" if var.nil?
  end

  placeholder.kind_of?(Tensor) ? Tensor.cast_dtype(var, placeholder.data_type) : var
end
|
620
|
+
|
621
|
+
# Reduces +val+ (after one #run step) along +axis+ using +op+.
#
#   axis nil -> op applied to the flattened array (non-arrays are returned
#               unchanged)
#   axis 0   -> op applied to each column (rows of the transpose)
#   axis 1   -> op applied to each row
#
# With +keep_dims+ each reduced value is wrapped in a one-element array.
# Raises a RuntimeError for any other axis.
def reduce_axis(axis, val, keep_dims, child_context, op = ->(v) { v.kind_of?(Array) ? v.reduce(:+) : v })
  val = run(val, child_context)

  if axis.nil?
    return val.is_a?(Array) ? op.call(val.flatten) : val
  end

  collapse = ->(slice) { keep_dims ? [op.call(slice)] : op.call(slice) }

  if axis == 0
    val.transpose.collect { |slice| collapse.call(slice) }
  elsif axis == 1
    val.collect { |slice| collapse.call(slice) }
  else
    fail "can't handle with axis > 1 :("
  end
end
|
629
|
+
|
630
|
+
# Adds +constant+ to every element of the (possibly nested) +vector+.
#
# vector        - array or tensor; evaluated with #run first
# constant      - value added to each leaf
# child_context - context handed to #run. Optional so existing two-argument
#                 callers keep working; #run itself requires a context, so
#                 one is always passed along.
#
# Leaves responding to #value contribute that value.
def constant_add(vector, constant, child_context = {})
  run(vector, child_context).collect do |item|
    if item.is_a?(Array)
      constant_add(item, constant, child_context)
    elsif item.respond_to?(:value)
      item.value + constant
    else
      item + constant
    end
  end
end
|
643
|
+
|
644
|
+
# Combines every leaf of +vector+ with +constant+ using the binary +op+.
#
# Both arguments are completely evaluated first. When +constant+ is an
# array, the entry at the current position is used for that element /
# sub-array. +switch+ swaps the operand order, i.e. op.(constant, leaf)
# instead of op.(leaf, constant). Leaves responding to #value contribute
# that value.
#
# Raises FullEvalNotPossible when either argument is still a Tensor after
# evaluation.
def constant_op(vector, constant, child_context, op = ->(a,b) { a + b }, switch = false)
  eval_vector = complete_eval(vector, child_context)
  constant = complete_eval(constant, child_context)

  fail FullEvalNotPossible.new, "full eval not possible for #{eval_vector.name}" if eval_vector.kind_of?(Tensor) || constant.kind_of?(Tensor)

  apply = switch ? ->(leaf, c) { op.(c, leaf) } : ->(leaf, c) { op.(leaf, c) }

  eval_vector.each_with_index.map do |item, index|
    c = constant.is_a?(Array) ? constant[index] : constant
    next constant_op(item, c, child_context, op, switch) if item.is_a?(Array)

    apply.call(item.respond_to?(:value) ? item.value : item, c)
  end
end
|
663
|
+
|
664
|
+
# Applies the three-argument callable +op+ element-wise.
#
# v_a drives the iteration: when it is not an array, op is called once
# with the three arguments as given. Otherwise op is applied position by
# position, recursing into nested arrays. v_b and v_c may be arrays shaped
# like v_a or plain values; a plain value is reused for every position
# (indexing an Integer with [] would return one of its bits instead of the
# number itself).
def call_3way_vector_op(v_a, v_b, v_c, child_context, op = ->(a,b,c) { a + b + c})
  return op.call(v_a, v_b, v_c) unless v_a.is_a?(Array)

  v_a.each_with_index.collect do |v1, index|
    v2 = v_b.is_a?(Array) ? v_b[index] : v_b
    v3 = v_c.is_a?(Array) ? v_c[index] : v_c
    if v1.is_a?(Array)
      call_3way_vector_op(v1, v2, v3, child_context, op)
    else
      op.call(v1, v2, v3)
    end
  end
end
|
677
|
+
|
678
|
+
# Combines two arrays element-wise with the binary +op+, with limited
# broadcasting.
#
# Both operands go through one #run step. If the left operand has the
# lower rank it is repeated once per element of the right operand and the
# call restarts. While the left operand has the higher rank, each of its
# sub-arrays is combined with the whole right operand. Otherwise elements
# are paired by position; when the right operand is shorter its first
# element is used. Paired sub-arrays go through constant_op; paired
# leaves are passed to +op+ (via #value on both sides when the left leaf
# responds to it).
def vector_op(vector, vector2, child_context, op = ->(a,b) { a + b })
  lhs = run(vector, child_context)
  rhs = run(vector2, child_context)
  lhs_rank = get_rank(lhs)
  rhs_rank = get_rank(rhs)

  if lhs_rank < rhs_rank # upgrade rank of the left operand
    return vector_op(Array.new(rhs.size) { lhs }, rhs, child_context, op)
  end

  lhs.each_with_index.map do |item, index|
    nested = item.is_a?(Array)
    next vector_op(item, rhs, child_context, op) if nested && lhs_rank > rhs_rank

    counterpart = index < rhs.size ? rhs[index] : rhs[0]

    if nested
      constant_op(item, counterpart, child_context, op)
    elsif item.respond_to?(:value)
      op.call(item.value, counterpart.value)
    else
      op.call(item, counterpart)
    end
  end
end
|
701
|
+
|
702
|
+
# True when +arr+ is truthy, or - for arrays - when every (nested) element
# is. An empty array counts as true.
def is_all_true(arr)
  return !!arr unless arr.is_a?(Array)

  arr.all? { |element| is_all_true(element) }
end
|
712
|
+
|
713
|
+
# Adds two equally shaped (possibly nested) arrays element by element.
#
# Both operands go through one #run step. Sub-arrays are added to the
# sub-array at the same position of the other operand; leaves responding
# to #value are added through their values.
def vector_add(vector, vector2, child_context)
  v_a = run(vector, child_context)
  v_b = run(vector2, child_context)

  v_a.each_with_index.collect do |item, index|
    other = v_b[index]
    if item.is_a?(Array)
      # pair nested rows with each other instead of treating them as constants
      vector_add(item, other, child_context)
    elsif item.respond_to?(:value)
      item.value + other.value
    else
      item + other
    end
  end
end
|
729
|
+
|
730
|
+
# Builds a nested array of the given +shape+, calling +generator+ once for
# every leaf.
#
# shape     - an Integer (flat array of that length) or an array of
#             dimension sizes; an empty array yields a single generated
#             value
# dtype     - accepted and forwarded to nested calls; not otherwise used
# generator - zero-argument callable producing each leaf
def generate_vector(shape, dtype: :float32, generator: )
  return shape.times.map { generator.call } if shape.is_a?(Integer)

  dimensions = shape.size
  if dimensions > 1
    inner_shape = shape[1..shape.size]
    shape[0].times.map { generate_vector(inner_shape, generator: generator, dtype: dtype) }
  elsif dimensions == 1
    shape[0].times.map { generator.call }
  elsif dimensions == 0
    generator.call
  end
end
|
747
|
+
end
|
748
|
+
end
|
749
|
+
end
|