tensor_stream 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (41) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rake_tasks~ +0 -0
  4. data/.rspec +2 -0
  5. data/.travis.yml +5 -0
  6. data/CODE_OF_CONDUCT.md +74 -0
  7. data/Gemfile +4 -0
  8. data/LICENSE.txt +21 -0
  9. data/README.md +123 -0
  10. data/Rakefile +6 -0
  11. data/bin/console +14 -0
  12. data/bin/setup +8 -0
  13. data/lib/tensor_stream.rb +138 -0
  14. data/lib/tensor_stream/control_flow.rb +23 -0
  15. data/lib/tensor_stream/evaluator/evaluator.rb +7 -0
  16. data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +32 -0
  17. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +749 -0
  18. data/lib/tensor_stream/graph.rb +98 -0
  19. data/lib/tensor_stream/graph_keys.rb +5 -0
  20. data/lib/tensor_stream/helpers/op_helper.rb +58 -0
  21. data/lib/tensor_stream/math_gradients.rb +161 -0
  22. data/lib/tensor_stream/monkey_patches/integer.rb +0 -0
  23. data/lib/tensor_stream/nn/nn_ops.rb +17 -0
  24. data/lib/tensor_stream/operation.rb +195 -0
  25. data/lib/tensor_stream/ops.rb +225 -0
  26. data/lib/tensor_stream/placeholder.rb +21 -0
  27. data/lib/tensor_stream/session.rb +66 -0
  28. data/lib/tensor_stream/tensor.rb +317 -0
  29. data/lib/tensor_stream/tensor_shape.rb +25 -0
  30. data/lib/tensor_stream/train/gradient_descent_optimizer.rb +23 -0
  31. data/lib/tensor_stream/train/saver.rb +61 -0
  32. data/lib/tensor_stream/trainer.rb +7 -0
  33. data/lib/tensor_stream/types.rb +17 -0
  34. data/lib/tensor_stream/variable.rb +52 -0
  35. data/lib/tensor_stream/version.rb +7 -0
  36. data/samples/iris.data +150 -0
  37. data/samples/iris.rb +117 -0
  38. data/samples/linear_regression.rb +55 -0
  39. data/samples/raw_neural_net_sample.rb +54 -0
  40. data/tensor_stream.gemspec +40 -0
  41. metadata +185 -0
@@ -0,0 +1,7 @@
1
+
2
+ require 'tensor_stream/evaluator/ruby_evaluator'
3
+
4
module TensorStream
  # Namespace grouping the graph evaluators (e.g. RubyEvaluator).
  module Evaluator; end
end
@@ -0,0 +1,32 @@
1
# http://creativecommons.org/publicdomain/zero/1.0/
#
# Normal (Gaussian) random number generator using the Box-Muller transform.
# Each transform yields two independent samples; the second is cached and
# handed out by the following call to #rand.
class RandomGaussian
  # @param mean [Numeric] mean of the distribution
  # @param stddev [Numeric] standard deviation of the distribution
  # @param rand_helper [#call] source of uniform samples; assumed to return
  #   values in [0, 1) like Kernel.rand
  def initialize(mean, stddev, rand_helper = lambda { Kernel.rand })
    @rand_helper = rand_helper
    @mean = mean
    @stddev = stddev
    @valid = false # true while @next holds an unused cached sample
    @next = 0
  end

  # @return [Float] the next normally distributed sample
  def rand
    if @valid
      @valid = false
      @next
    else
      @valid = true
      x, @next = self.class.gaussian(@mean, @stddev, @rand_helper)
      x
    end
  end

  # Box-Muller transform: turns two uniform samples into two independent
  # normal samples with the given mean and standard deviation.
  #
  # This is a public class method: a bare `private` keyword (which the original
  # placed above it) does not apply to `def self.` methods, and #rand calls it
  # through `self.class`, so it cannot be made private without restructuring.
  #
  # @return [Array(Float, Float)]
  def self.gaussian(mean, stddev, rand)
    theta = 2 * Math::PI * rand.call
    # 1 - u keeps the log argument in (0, 1] for u in [0, 1), avoiding log(0).
    rho = Math.sqrt(-2 * Math.log(1 - rand.call))
    scale = stddev * rho
    [mean + scale * Math.cos(theta), mean + scale * Math.sin(theta)]
  end
end
@@ -0,0 +1,749 @@
1
+ require "tensor_stream/evaluator/operation_helpers/random_gaussian"
2
+ require 'tensor_stream/math_gradients'
3
+
4
+ module TensorStream
5
+ module Evaluator
6
# Raised when an expression still contains tensors that cannot be reduced to
# plain Ruby values yet, so the caller must build a graph op instead.
FullEvalNotPossible = Class.new(RuntimeError)
8
+
9
# Errors during graph evaluation: wraps the original error together with the
# tensor that was being evaluated when it occurred.
class EvaluatorExcecutionException < RuntimeError
  # @return the tensor whose evaluation failed
  attr_reader :tensor

  # @param exception [Exception] the underlying error
  # @param tensor the tensor being evaluated when +exception+ was raised
  def initialize(exception, tensor)
    # Default the message to the wrapped error's message. Previously super was
    # never called, so #message was just the class name unless the instance was
    # raised with an explicit message (`raise obj, "msg"` still overrides this).
    super(exception.respond_to?(:message) ? exception.message : nil)
    @exception = exception
    @tensor = tensor
  end

  # @return [Exception] the original error that triggered this one
  def wrapped_exception
    @exception
  end
end
22
+
23
+ ## PURE ruby evaluator used for testing and development
24
+ class RubyEvaluator
25
+ attr_accessor :retain
26
+
27
+ include TensorStream::OpHelper
28
+
29
def initialize(session, context, thread_pool: nil)
  @session, @context = session, context
  # Tensors listed under :retain are returned unevaluated by #run.
  @retain = context[:retain] || []
  # Falls back to Concurrent::ImmediateExecutor when no pool is supplied.
  @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
end
35
+
36
# Evaluates +tensor+ one step, dispatching on its kind.
# Arrays are evaluated element-wise; retained tensors are returned as-is.
def run(tensor, execution_context)
  return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array)
  return tensor if retain.include?(tensor) # if var is in retain don't eval to value

  child_context = execution_context.dup
  result = case tensor
           when Operation then eval_operation(tensor, child_context)
           when Variable then eval_variable(tensor, child_context)
           when Placeholder then resolve_placeholder(tensor, child_context)
           else eval_tensor(tensor, child_context)
           end
  # Propagate whatever the child evaluation recorded under :returns to the caller.
  execution_context.deep_merge!(returns: child_context[:returns])
  result
end
54
+
55
# Repeatedly runs +tensor+ until it stops changing or becomes a non-Tensor value.
# An array whose first element is a Tensor is fully evaluated element by element.
def complete_eval(tensor, context)
  loop do
    previous = tensor
    tensor = run(tensor, context)

    if tensor.is_a?(Array) && !tensor.empty? && tensor[0].is_a?(Tensor)
      tensor = tensor.map { |t| complete_eval(t, context) }
    end

    # Stop when evaluation made no progress or produced a plain value.
    break tensor if previous.equal?(tensor) || !tensor.is_a?(Tensor)
  end
end
68
+
69
+ protected
70
+
71
# Evaluates a variable's current value and records it in
# child_context[:returns][:vars] as { name:, val: }.
# Raises if the variable has no value yet.
def eval_variable(tensor, child_context)
  raise "variable #{tensor.name} not initalized" if tensor.value.nil?

  val = eval_tensor(tensor.value, child_context)
  returns = (child_context[:returns] ||= {})
  (returns[:vars] ||= []) << { name: tensor.name, val: val }
  val
end
79
+
80
# Evaluates a single graph Operation and memoizes the result in @context under
# the operation's name. Ops whose operands cannot be fully evaluated fall back
# (via call_op / call_vector_op) to building the equivalent TensorStream op.
#
# @raise [EvaluatorExcecutionException] wrapping any StandardError raised while evaluating
def eval_operation(tensor, child_context)
  return @context[tensor.name] if @context.key?(tensor.name)

  # First two inputs, with placeholders substituted.
  a = resolve_placeholder(tensor.items[0], child_context) if tensor.items && tensor.items[0]
  b = resolve_placeholder(tensor.items[1], child_context) if tensor.items && tensor.items[1]

  case tensor.operation
  when :argmax
    a = complete_eval(a, child_context)
    axis = tensor.options[:axis] || 0

    get_max_with_axis(a, axis, 0, tensor.data_type)
  when :cast
    a = complete_eval(a, child_context)

    call_op(:cast, a, child_context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
  when :sign
    a = complete_eval(a, child_context)

    func = lambda { |x, _b|
      if x.zero? || (x.is_a?(Float) && x.nan?)
        0
      elsif x < 0
        -1
      elsif x > 0
        1
      else
        fail 'assert: cannot be here'
      end
    }

    call_op(:sign, a, child_context, func)
  when :equal
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    # The op name is used for the graph fallback; it previously said :greater.
    call_vector_op(:equal, a, b, child_context, ->(t, u) { t == u })
  when :not_equal
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:not_equal, a, b, child_context, ->(t, u) { t != u })
  when :index
    f = run(a, child_context)
    index = run(b, child_context)

    f[index]
  when :slice
    input = complete_eval(a, child_context)
    start = complete_eval(b, child_context)
    size = complete_eval(tensor.options[:size], child_context)
    fail "start index and size not of the same shape #{start.size} != #{size.size}" if start.size != size.size
    slice_tensor(input, start, size)
  when :negate
    call_vector_op(:negate, a, nil, child_context, ->(t, _u) { -t })
  when :add
    call_vector_op(:add, a, b, child_context, ->(t, u) { t + u })
  when :sub
    call_vector_op(:sub, a, b, child_context, ->(t, u) { t - u })
  when :mul
    # A leftover `binding.pry` debugger hook used to live in this lambda.
    call_vector_op(:mul, a, b, child_context, ->(t, u) { t * u })
  when :pow
    call_vector_op(:pow, a, b, child_context, ->(t, u) { t**u })
  when :concat
    values = complete_eval(a, child_context)
    concat_array(values, tensor.options[:axis])
  when :abs
    call_op(:abs, a, child_context, ->(t, _b) { t.abs })
  when :tanh
    call_op(:tanh, a, child_context, ->(t, _b) { Math.tanh(t) })
  when :tan
    call_op(:tan, a, child_context, ->(t, _b) { Math.tan(t) })
  when :sec
    # Ruby's Math module has no `sec`; compute it as 1 / cos.
    call_op(:sec, a, child_context, ->(t, _b) { 1 / Math.cos(t) })
  when :sin
    call_op(:sin, a, child_context, ->(t, _b) { Math.sin(t) })
  when :cos
    call_op(:cos, a, child_context, ->(t, _b) { Math.cos(t) })
  when :log
    call_op(:log, a, child_context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
  when :exp
    call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
  when :sqrt
    # The op name is used for the graph fallback; it previously said :exp.
    call_op(:sqrt, a, child_context, ->(t, _b) { Math.sqrt(t) })
  when :square
    call_op(:square, a, child_context, ->(t, _b) { t * t })
  when :stop_gradient
    run(a, child_context)
  when :random_uniform
    maxval = tensor.options.fetch(:maxval, 1)
    minval = tensor.options.fetch(:minval, 0)

    generator = -> { rand * (maxval - minval) + minval }
    generate_vector(tensor.options[:shape], generator: generator)
  when :random_normal
    r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev))
    generator = -> { r.rand }

    generate_vector(tensor.options[:shape], generator: generator)
  when :flow_group
    threads = tensor.items.collect { |item| Concurrent::Future.execute(executor: @thread_pool) { run(item, child_context) } }
    threads.collect(&:value)
  when :assign
    assign = tensor.items[0] || tensor
    assign.value = complete_eval(tensor.items[1], child_context)
    assign.value
  when :assign_add
    tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t + u })
    tensor.items[0].value
  when :assign_sub
    tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t - u })
    tensor.items[0].value
  when :reduce_mean
    c = tensor.data_type == :float ? 0.0 : 0
    func = lambda { |v|
      if v.is_a?(Array)
        v.empty? ? c : (v.reduce(:+) / v.size)
      else
        v
      end
    }

    reduction(child_context, tensor, func)
  when :reduce_sum
    c = tensor.data_type == :float ? 0.0 : 0
    func = lambda { |v|
      if v.is_a?(Array)
        v.empty? ? c : v.reduce(:+)
      else
        v
      end
    }

    reduction(child_context, tensor, func)
  when :reduce_prod
    c = tensor.data_type == :float ? 1.0 : 1
    func = lambda { |v|
      if v.is_a?(Array)
        v.empty? ? c : v.reduce(:*)
      else
        v
      end
    }

    reduction(child_context, tensor, func)
  when :transpose
    matrix_a = complete_eval(a, child_context)
    matrix_a.transpose
  when :eye
    rows = complete_eval(a, child_context)
    columns = complete_eval(b, child_context)

    Array.new(rows) do |i|
      Array.new(columns) do |col|
        if tensor.data_type == :float32
          i == col ? 1.0 : 0.0
        else
          i == col ? 1 : 0
        end
      end
    end
  when :cond
    pred = complete_eval(tensor.options[:pred], child_context)

    if is_all_true(pred)
      complete_eval(a, child_context)
    else
      complete_eval(b, child_context)
    end
  when :where
    pred = complete_eval(tensor.options[:pred], child_context)
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_3way_vector_op(pred, a, b, child_context, ->(t, u, v) { t ? u : v })
  when :less
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    # The op name is used for the graph fallback; it previously said :greater.
    call_vector_op(:less, a, b, child_context, ->(t, u) { t < u })
  when :greater
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:greater, a, b, child_context, ->(t, u) { t > u })
  when :zeros, :ones, :zeros_like, :ones_like
    shape = if %i[zeros_like ones_like].include?(tensor.operation)
              a = complete_eval(a, child_context)
              shape_eval(a)
            else
              complete_eval(a, child_context) || tensor.shape.shape
            end

    func = if %i[zeros zeros_like].include?(tensor.operation)
             -> { tensor.data_type == :int32 ? 0 : 0.0 }
           else
             -> { tensor.data_type == :int32 ? 1 : 1.0 }
           end

    if shape.is_a?(Array) && shape.size.zero?
      func.call
    else
      shape = [shape.to_i] unless shape.is_a?(Array)
      generate_vector(shape, generator: func)
    end
  when :shape
    input = complete_eval(a, child_context)

    shape_eval(input)
  when :matmul
    matrix_a = complete_eval(a, child_context)
    matrix_b = complete_eval(b, child_context)

    rank_a = get_rank(matrix_a)
    rank_b = get_rank(matrix_b)

    raise "#{a.name} rank must be greater than 1" if rank_a < 2
    raise "#{b.name} rank must be greater than 1" if rank_b < 2

    matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
    matrix_b = matrix_b.transpose if tensor.options[:transpose_b]

    # handle matrix multiplication with constants like 1 or 0
    matrix_a = matmul_const_transform(matrix_a, matrix_b, tensor)
    matrix_b = matmul_const_transform(matrix_b, matrix_a, tensor)

    # check matrix dimensions
    raise "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size

    (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
  when :gradients
    b.collect do |xs|
      fail "#{xs} passed is not a tensor object" unless xs.is_a?(Tensor)
      xs_val = complete_eval(xs, child_context)
      target_shape = shape_eval(xs_val)

      stops = tensor.options[:stop_gradients] ? tensor.options[:stop_gradients].map(&:name).join('_') : ''
      gradient_program_name = "grad_#{tensor.name}_#{xs.name}_#{stops}".to_sym

      # The derivative program is built once per (target, variable, stops) and cached in the graph.
      tensor_program = if tensor.graph.node_added?(gradient_program_name)
                         tensor.graph.get_node(gradient_program_name)
                       else
                         derivative_ops = TensorStream::MathGradients.derivative(a, xs, graph: tensor.graph, stop_gradients: tensor.options[:stop_gradients], target_shape: target_shape)
                         unit_matrix = op(:ones_like, xs)
                         tensor.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
                       end

      complete_eval(tensor_program, child_context)
    end
  when :identity
    complete_eval(a, child_context)
  when :print
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)
    puts "#{tensor.options.fetch(:message, '')} #{b}"
    a
  when :rank
    a = complete_eval(a, child_context)
    get_rank(a)
  when :div
    process_vector_math_op(a, b, child_context, ->(t, u) { t / u })
  when :reshape
    arr = complete_eval(a, child_context)
    new_shape = complete_eval(b, child_context)

    flat_arr = arr.flatten
    # A single element reshaped to [] becomes a scalar. This used to be an early
    # `return`, which skipped the breakpoint hook and memoization below.
    if new_shape.size.zero? && flat_arr.size == 1
      flat_arr[0]
    else
      reshape(flat_arr, fix_inferred_elements(new_shape, flat_arr.size))
    end
  when :pad
    a = complete_eval(a, child_context)
    p = complete_eval(tensor.options[:paddings], child_context)

    arr_pad(a, p, tensor.data_type)
  when :max
    a = complete_eval(a, child_context)
    b = complete_eval(b, child_context)

    call_vector_op(:max, a, b, child_context, ->(t, u) { [t, u].max })
  else
    fail "unknown op #{tensor.operation}"
  end.tap do |result|
    if tensor.breakpoint
      a = complete_eval(a, child_context)
      b = complete_eval(b, child_context)

      tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
    end
    @context[tensor.name] = result
  end
rescue EvaluatorExcecutionException => e
  # Already wrapped by a nested evaluation; don't wrap twice.
  raise e
rescue StandardError => e
  puts e.backtrace.join("\n")
  raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math} defined at #{tensor.source}"
end
386
+
387
# Resolves a constant/plain tensor to its value, running any nested tensors.
# Non-Tensor inputs are returned unchanged; results are memoized in @context by name.
def eval_tensor(tensor, child_context)
  return tensor unless tensor.is_a?(Tensor)
  return @context[tensor.name] if @context.key?(tensor.name)

  value = tensor.value
  resolve = ->(item) { item.is_a?(Tensor) ? run(item, child_context) : item }
  @context[tensor.name] = value.is_a?(Array) ? value.map(&resolve) : resolve.call(value)
end
401
+
402
+ private
403
+
404
# Index of the maximum value along +target_axis+ (argmax), cast to +output_type+.
# On ties the first maximal index wins (strict `>` comparison).
def get_max_with_axis(a, target_axis, current_axis, output_type)
  # Not at the requested axis yet: descend into every sub-array.
  return a.collect { |row| get_max_with_axis(row, target_axis, current_axis + 1, output_type) } unless target_axis == current_axis

  argmax = lambda do |values|
    best = nil
    best_index = 0
    values.each_with_index do |value, idx|
      next unless best.nil? || value > best

      best = value
      best_index = idx
    end
    Tensor.cast_dtype(best_index, output_type)
  end

  if a[0].is_a?(Array)
    # 2-D at this level: argmax down each column.
    (0...a[0].size).map { |col| argmax.call(a.map { |row| row[col] }) }
  else
    argmax.call(a)
  end
end
436
+
437
# Shared driver for reduce_* ops: evaluates the first input and reduces it with
# +func+ along tensor.options[:axis] (nil, a single axis, or an array of axes).
def reduction(child_context, tensor, func)
  val = complete_eval(tensor.items[0], child_context)
  axis = tensor.options[:axis]
  keep_dims = tensor.options[:keepdims]

  return reduce_axis(axis, val, keep_dims, child_context, func) unless axis.is_a?(Array)

  # Several axes: reduce them one after another, then collapse what remains.
  # NOTE(review): the final func.call over the flattened value reduces to a single
  # value even when only some axes were listed — confirm this is intended.
  reduced = axis.reduce(val) { |acc, ax| reduce_axis(ax, acc, keep_dims, child_context, func) }
  func.call(reduced.flatten)
end
453
+
454
# Zero-pads a nested array. paddings[rank] is [before, after] for dimension +rank+.
# The pad value is 0.0 for :float32 and 0 for any other data_type.
#
# @raise [RuntimeError] when a padding entry does not have exactly two elements
def arr_pad(arr, paddings, data_type = :float32, rank = 0)
  fail "padding #{paddings[rank]} needs to have two elements [before, after]" if paddings[rank].size != 2

  before, after = paddings[rank]
  zero = data_type == :float32 ? 0.0 : 0

  if arr[0].is_a?(Array)
    padded_rows = arr.collect { |row| arr_pad(row, paddings, data_type, rank + 1) }
    # Build a separate zero-filled row for every padding slot. Reusing a single
    # object made all padding rows aliases of one another.
    new_padding_row = -> { deep_dup_array(padded_rows[0], zero) }
    before.times.map { new_padding_row.call } + padded_rows + after.times.map { new_padding_row.call }
  else
    before.times.map { zero } + arr + after.times.map { zero }
  end
end
468
+
469
# Recursively copies a nested array. When +value+ is given, every leaf is
# replaced by it; otherwise leaves are kept as-is.
def deep_dup_array(arr, value = nil)
  return arr.map { |elem| deep_dup_array(elem, value) } if arr.is_a?(Array)

  value.nil? ? arr : value
end
478
+
479
# Extracts a sub-array: along each dimension i, takes size[i] elements starting at start[i].
#
# +start+ and +size+ are destructured rather than `shift`ed. The previous version
# mutated the caller's arrays, which may be values memoized in the evaluator context.
def slice_tensor(input, start, size)
  start_index, *inner_start = start
  length, *inner_size = size

  input[start_index...(start_index + length)].collect do |item|
    item.is_a?(Array) ? slice_tensor(item, inner_start, inner_size) : item
  end
end
491
+
492
# For matmul with a scalar operand: expands the scalar +mat+ into a matrix whose
# shape is the reverse of mat_b's shape, filled with the scalar (to_i for :int32,
# to_f otherwise). Array operands are returned unchanged.
def matmul_const_transform(mat, mat_b, tensor)
  return mat if mat.is_a?(Array)

  fill = -> { tensor.data_type == :int32 ? mat.to_i : mat.to_f }
  generate_vector(shape_eval(mat_b).reverse, generator: fill)
end
502
+
503
# Replaces a -1 entry in +shape+ with the size inferred from +total_size+,
# i.e. total_size divided by the product of the positive dimensions.
def fix_inferred_elements(shape, total_size)
  return shape if shape.empty?

  known = shape.select { |dim| dim > 0 }.reduce(1, :*)
  inferred = total_size / known
  shape.map { |dim| dim == -1 ? inferred : dim }
end
510
+
511
# Reshapes the flat array +arr+ into nested arrays of +new_shape+.
# An empty shape returns +arr+ unchanged.
#
# +new_shape+ is destructured instead of `shift`ed, so the caller's shape array
# is no longer mutated.
#
# @raise [RuntimeError] when the innermost dimension does not match the element count
def reshape(arr, new_shape)
  return arr if new_shape.empty?

  s, *inner_shape = new_shape

  if inner_shape.empty?
    fail "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s
    return arr
  end

  dim = arr.size / s
  arr.each_slice(dim).collect { |slice| reshape(slice, inner_shape) }
end
526
+
527
# Applies the unary +func+ to the evaluated +a+. If the operand cannot be fully
# evaluated (including during complete_eval itself), builds the TensorStream
# op named +op+ instead.
def call_op(op, a, child_context, func)
  operand = a
  operand = complete_eval(operand, child_context)
  process_function_op(operand, child_context, func)
rescue FullEvalNotPossible
  TensorStream.send(op.to_sym, operand)
end
533
+
534
# Applies the binary +func+ to +a+ and +b+. If either operand cannot be fully
# evaluated, builds the TensorStream op named +op+ instead.
def call_vector_op(op, a, b, child_context, func)
  begin
    process_vector_math_op(a, b, child_context, func)
  rescue FullEvalNotPossible
    TensorStream.send(op.to_sym, a, b)
  end
end
539
+
540
# Evaluates +a+ and +b+ and combines them element-wise with +op+, broadcasting
# scalars over arrays.
#
# @raise [FullEvalNotPossible] if either operand still evaluates to a Tensor
def process_vector_math_op(a, b, child_context, op)
  eval_a = complete_eval(a, child_context) unless a.nil?
  eval_b = complete_eval(b, child_context) unless b.nil?

  if eval_a.is_a?(Tensor) || eval_b.is_a?(Tensor)
    # Name the operand that did not resolve. The previous version always used
    # a.name, which raised NoMethodError when only b was unresolved and a was a
    # plain Ruby value, so the caller's FullEvalNotPossible fallback never ran.
    culprit = eval_a.is_a?(Tensor) ? a : b
    label = culprit.respond_to?(:name) ? culprit.name : culprit.inspect
    fail FullEvalNotPossible.new, "full eval not possible for #{label}"
  end

  if get_rank(eval_a).zero?
    if get_rank(eval_b).zero?
      op.call(eval_a, eval_b) # ruby scalar op ruby scalar
    else
      # scalar a broadcast over b; switch=true keeps the operand order (a op b)
      constant_op(eval_b, eval_a, child_context, op, true)
    end
  elsif get_rank(eval_b) > 0
    vector_op(eval_a, eval_b, child_context, op)
  else
    constant_op(eval_a, eval_b, child_context, op)
  end
end
561
+
562
# Nesting depth of +value+, found by following first elements (scalars have rank 0).
# An empty array counts as one level; +rank+ is the starting offset.
def get_rank(value, rank = 0)
  while value.is_a?(Array)
    rank += 1
    break if value.empty?

    value = value[0]
  end
  rank
end
568
+
569
# Concatenates the arrays in +values+ along +axis+ (-1 means the last axis of
# the first array).
#
# +values+ is destructured instead of `shift`ed, so the caller's array, which
# may be a result memoized in @context, is no longer mutated.
def concat_array(values, axis)
  combined_array, *rest = values
  axis = get_rank(combined_array) - 1 if axis == -1

  rest.reduce(combined_array) { |acc, v| concat(acc, v, axis) }
end
578
+
579
# Concatenates +a+ and +b+ along +axis+. Axis 0 appends; deeper axes recurse
# pairwise into matching sub-arrays.
def concat(a, b, axis)
  return a + b if axis == 0

  a.map.with_index { |row, idx| concat(row, b[idx], axis - 1) }
end
588
+
589
# Applies the unary +op+ (called as op.(x, 0)) to +a+: element-wise for arrays
# or non-scalar tensors, directly for scalars.
#
# @raise [FullEvalNotPossible] if a scalar operand still evaluates to a non-const Tensor
def process_function_op(a, child_context, op)
  is_tensor = a.kind_of?(Tensor)
  return constant_op(a, 0, child_context, op) if (is_tensor && a.shape.rank > 0) || a.kind_of?(Array)

  fail 'cannot be here' unless !is_tensor || a.shape.rank == 0

  # ruby scalar
  value = run(a, child_context)
  fail FullEvalNotPossible.new, "full eval not possible for #{value.name}" if value.is_a?(Tensor) && !value.is_const

  op.call(value, 0)
end
602
+
603
# Substitutes a Placeholder with its value from @context (keyed by name as a
# symbol) and casts Tensor inputs to their declared data_type. Non-Tensor values
# pass through unchanged; retained tensors are returned as-is.
#
# @raise [RuntimeError] when a placeholder has no value in @context
def resolve_placeholder(placeholder, execution_context = {})
  return nil if placeholder.nil?
  return placeholder if retain.include?(placeholder)

  value = placeholder
  if placeholder.kind_of?(Placeholder)
    value = @context[placeholder.name.to_sym]
    raise "missing placeholder #{placeholder.name}" if value.nil?
  end

  placeholder.kind_of?(Tensor) ? Tensor.cast_dtype(value, placeholder.data_type) : value
end
620
+
621
# Reduces +val+ with +op+ over all elements (axis nil), columns (axis 0) or rows
# (axis 1). With keep_dims each reduced value is wrapped in a one-element array.
#
# @raise [RuntimeError] for any other axis
def reduce_axis(axis, val, keep_dims, child_context, op = ->(v) { v.kind_of?(Array) ? v.reduce(:+) : v })
  val = run(val, child_context)
  wrap = ->(v) { keep_dims ? [op.call(v)] : op.call(v) }

  case axis
  when nil then val.is_a?(Array) ? op.call(val.flatten) : val
  when 0 then val.transpose.map(&wrap)
  when 1 then val.map(&wrap)
  else fail "can't handle with axis > 1 :("
  end
end
629
+
630
# Adds +constant+ to every leaf of the (possibly nested) +vector+; leaves that
# respond to #value contribute their value.
#
# The previous version called run(vector) with one argument, but #run requires an
# execution context, so it always raised ArgumentError. The context is now an
# optional trailing parameter, which keeps existing two-argument calls valid.
def constant_add(vector, constant, child_context = {})
  run(vector, child_context).collect do |item|
    if item.is_a?(Array)
      constant_add(item, constant, child_context)
    elsif item.respond_to?(:value)
      item.value + constant
    else
      item + constant
    end
  end
end
643
+
644
# Applies +op+ between every leaf of +vector+ and +constant+. An array constant
# is matched by index at each level. With +switch+ the operands are passed as
# op.(constant, item) to keep the original left/right order.
#
# @raise [FullEvalNotPossible] if either side still evaluates to a Tensor
def constant_op(vector, constant, child_context, op = ->(a, b) { a + b }, switch = false)
  eval_vector = complete_eval(vector, child_context)
  constant = complete_eval(constant, child_context)

  # Name whichever side is unresolved. The previous version always used
  # eval_vector.name, which raised NoMethodError when only the constant was a Tensor.
  unresolved = [eval_vector, constant].find { |v| v.kind_of?(Tensor) }
  fail FullEvalNotPossible.new, "full eval not possible for #{unresolved.name}" if unresolved

  eval_vector.each_with_index.collect do |item, index|
    c = constant.is_a?(Array) ? constant[index] : constant
    if item.is_a?(Array)
      constant_op(item, c, child_context, op, switch)
    else
      leaf = item.respond_to?(:value) ? item.value : item
      switch ? op.call(c, leaf) : op.call(leaf, c)
    end
  end
end
663
+
664
# Element-wise ternary op over three same-shaped nested arrays (used by :where).
# Scalars are passed straight to +op+.
def call_3way_vector_op(v_a, v_b, v_c, child_context, op = ->(a, b, c) { a + b + c })
  return op.call(v_a, v_b, v_c) unless v_a.is_a?(Array)

  v_a.map.with_index do |first, idx|
    second = v_b[idx]
    third = v_c[idx]
    if first.is_a?(Array)
      call_3way_vector_op(first, second, third, child_context, op)
    else
      op.call(first, second, third)
    end
  end
end
677
+
678
# Element-wise binary op between two arrays with simple broadcasting:
# - a lower-rank left operand is tiled once per element of the right operand;
# - a higher-rank left operand is recursed into row by row;
# - a right operand shorter than the left reuses its first element.
def vector_op(vector, vector2, child_context, op = ->(a, b) { a + b })
  lhs = run(vector, child_context)
  rhs = run(vector2, child_context)
  lhs_rank = get_rank(lhs)
  rhs_rank = get_rank(rhs)

  # upgrade rank of A
  return vector_op(Array.new(rhs.size) { lhs }, rhs, child_context, op) if lhs_rank < rhs_rank

  lhs.each_with_index.map do |elem, idx|
    if elem.is_a?(Array) && lhs_rank > rhs_rank
      vector_op(elem, rhs, child_context, op)
    else
      other = idx < rhs.size ? rhs[idx] : rhs[0]
      if elem.is_a?(Array)
        constant_op(elem, other, child_context, op)
      elsif elem.respond_to?(:value)
        op.call(elem.value, other.value)
      else
        op.call(elem, other)
      end
    end
  end
end
701
+
702
# True when every leaf of a (possibly nested) array is truthy; for a non-array,
# its truthiness. An empty array counts as all true.
def is_all_true(arr)
  return !!arr unless arr.is_a?(Array)

  arr.all? { |elem| is_all_true(elem) }
end
712
+
713
# Adds two same-shaped (possibly nested) arrays element-wise; leaves that
# respond to #value contribute their value.
#
# The previous version called constant_add(item, constant) for nested rows, but
# `constant` was undefined there and raised NameError. Nested rows now recurse
# pairwise against the matching row of +vector2+, mirroring the scalar branch.
def vector_add(vector, vector2, child_context)
  v_a = run(vector, child_context)
  v_b = run(vector2, child_context)

  v_a.each_with_index.collect do |item, index|
    if item.is_a?(Array)
      vector_add(item, v_b[index], child_context)
    elsif item.respond_to?(:value)
      item.value + v_b[index].value
    else
      item + v_b[index]
    end
  end
end
729
+
730
# Builds a nested array of the given +shape+, filling every slot with
# generator.call. An Integer shape means a 1-D array of that length; an empty
# shape yields a single scalar. +dtype+ is forwarded but not used here.
def generate_vector(shape, dtype: :float32, generator:)
  return shape.times.map { generator.call } if shape.is_a?(Integer)

  case shape.size
  when 0 then generator.call
  when 1 then shape[0].times.map { generator.call }
  else shape[0].times.map { generate_vector(shape.drop(1), generator: generator, dtype: dtype) }
  end
end
747
+ end
748
+ end
749
+ end