tensor_stream 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rake_tasks~ +0 -0
  4. data/.rspec +2 -0
  5. data/.travis.yml +5 -0
  6. data/CODE_OF_CONDUCT.md +74 -0
  7. data/Gemfile +4 -0
  8. data/LICENSE.txt +21 -0
  9. data/README.md +123 -0
  10. data/Rakefile +6 -0
  11. data/bin/console +14 -0
  12. data/bin/setup +8 -0
  13. data/lib/tensor_stream.rb +138 -0
  14. data/lib/tensor_stream/control_flow.rb +23 -0
  15. data/lib/tensor_stream/evaluator/evaluator.rb +7 -0
  16. data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +32 -0
  17. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +749 -0
  18. data/lib/tensor_stream/graph.rb +98 -0
  19. data/lib/tensor_stream/graph_keys.rb +5 -0
  20. data/lib/tensor_stream/helpers/op_helper.rb +58 -0
  21. data/lib/tensor_stream/math_gradients.rb +161 -0
  22. data/lib/tensor_stream/monkey_patches/integer.rb +0 -0
  23. data/lib/tensor_stream/nn/nn_ops.rb +17 -0
  24. data/lib/tensor_stream/operation.rb +195 -0
  25. data/lib/tensor_stream/ops.rb +225 -0
  26. data/lib/tensor_stream/placeholder.rb +21 -0
  27. data/lib/tensor_stream/session.rb +66 -0
  28. data/lib/tensor_stream/tensor.rb +317 -0
  29. data/lib/tensor_stream/tensor_shape.rb +25 -0
  30. data/lib/tensor_stream/train/gradient_descent_optimizer.rb +23 -0
  31. data/lib/tensor_stream/train/saver.rb +61 -0
  32. data/lib/tensor_stream/trainer.rb +7 -0
  33. data/lib/tensor_stream/types.rb +17 -0
  34. data/lib/tensor_stream/variable.rb +52 -0
  35. data/lib/tensor_stream/version.rb +7 -0
  36. data/samples/iris.data +150 -0
  37. data/samples/iris.rb +117 -0
  38. data/samples/linear_regression.rb +55 -0
  39. data/samples/raw_neural_net_sample.rb +54 -0
  40. data/tensor_stream.gemspec +40 -0
  41. metadata +185 -0
data/lib/tensor_stream/evaluator/evaluator.rb
@@ -0,0 +1,7 @@
+
+ require 'tensor_stream/evaluator/ruby_evaluator'
+
+ module TensorStream
+   module Evaluator
+   end
+ end
data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb
@@ -0,0 +1,32 @@
+ # Gaussian (normal) random number generator based on the Box-Muller transform
+ # http://creativecommons.org/publicdomain/zero/1.0/
+ class RandomGaussian
+   def initialize(mean, stddev, rand_helper = lambda { Kernel.rand })
+     @rand_helper = rand_helper
+     @mean = mean
+     @stddev = stddev
+     @valid = false
+     @next = 0
+   end
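+
+   # Box-Muller yields samples in pairs; rand returns one per call and caches
+   # the spare in @next, with @valid flagging whether the cached value is usable.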
+   def rand
+     if @valid
+       @valid = false
+       @next
+     else
+       @valid = true
+       x, y = self.class.gaussian(@mean, @stddev, @rand_helper)
+       @next = y
+       x
+     end
+   end
+
+   private
+
+   # NOTE: `private` does not apply to `def self.` methods, so gaussian
+   # remains callable via self.class in #rand above
+   def self.gaussian(mean, stddev, rand)
+     theta = 2 * Math::PI * rand.call
+     rho = Math.sqrt(-2 * Math.log(1 - rand.call))
+     scale = stddev * rho
+     x = mean + scale * Math.cos(theta)
+     y = mean + scale * Math.sin(theta)
+     [x, y]
+   end
+ end
data/lib/tensor_stream/evaluator/ruby_evaluator.rb
@@ -0,0 +1,749 @@
+ require 'matrix'     # stdlib Matrix is used by :matmul below
+ require 'concurrent' # concurrent-ruby backs the thread pool and :flow_group
+ require 'tensor_stream/evaluator/operation_helpers/random_gaussian'
+ require 'tensor_stream/math_gradients'
+
+ module TensorStream
+   module Evaluator
+     # Raised when a tensor cannot yet be reduced to a concrete value
+     class FullEvalNotPossible < RuntimeError
+     end
+
+     # Errors during graph evaluation
+     class EvaluatorExcecutionException < RuntimeError
+       attr_reader :tensor
+
+       def initialize(exception, tensor)
+         @exception = exception
+         @tensor = tensor
+       end
+
+       def wrapped_exception
+         @exception
+       end
+     end
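+
+     # Usage sketch (illustrative; the Session plumbing lives elsewhere in
+     # this gem, not in this hunk):
+     #
+     #   evaluator = TensorStream::Evaluator::RubyEvaluator.new(session, {})
+     #   value = evaluator.complete_eval(some_tensor, {})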
+
+     ## Pure Ruby evaluator, used for testing and development
+     class RubyEvaluator
+       attr_accessor :retain
+
+       include TensorStream::OpHelper
+
+       def initialize(session, context, thread_pool: nil)
+         @session = session
+         @context = context
+         @retain = context[:retain] || []
+         @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
+       end
+
+       def run(tensor, execution_context)
+         return tensor.map { |t| run(t, execution_context) } if tensor.is_a?(Array)
+
+         return tensor if retain.include?(tensor) # if var is in retain, don't eval to a value
+
+         child_context = execution_context.dup
+         res = if tensor.is_a?(Operation)
+                 eval_operation(tensor, child_context)
+               elsif tensor.is_a?(Variable)
+                 eval_variable(tensor, child_context)
+               elsif tensor.is_a?(Placeholder)
+                 resolve_placeholder(tensor, child_context)
+               else
+                 eval_tensor(tensor, child_context)
+               end
+         execution_context.deep_merge!(returns: child_context[:returns])
+         res
+       end
+
+       def complete_eval(tensor, context)
+         Kernel.loop do
+           old_tensor = tensor
+           tensor = run(tensor, context)
+
+           if tensor.is_a?(Array) && !tensor.empty? && tensor[0].is_a?(Tensor)
+             tensor = tensor.map { |t| complete_eval(t, context) }
+           end
+
+           return tensor if old_tensor.equal?(tensor)
+           return tensor unless tensor.is_a?(Tensor)
+         end
+       end
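+
+       # complete_eval drives run to a fixed point: it loops until the result
+       # stops changing and is no longer a Tensor, so ops that return other
+       # unevaluated tensors are folded all the way down to plain Ruby values.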
+
+       protected
+
+       def eval_variable(tensor, child_context)
+         raise "variable #{tensor.name} not initialized" if tensor.value.nil?
+         eval_tensor(tensor.value, child_context).tap do |val|
+           child_context[:returns] ||= {}
+           child_context[:returns][:vars] ||= []
+           child_context[:returns][:vars] << { name: tensor.name, val: val }
+         end
+       end
+
+       def eval_operation(tensor, child_context)
+         return @context[tensor.name] if @context.key?(tensor.name)
+
+         a = resolve_placeholder(tensor.items[0], child_context) if tensor.items && tensor.items[0]
+         b = resolve_placeholder(tensor.items[1], child_context) if tensor.items && tensor.items[1]
+
+         case tensor.operation
+         when :argmax
+           a = complete_eval(a, child_context)
+           axis = tensor.options[:axis] || 0
+
+           get_max_with_axis(a, axis, 0, tensor.data_type)
+         when :cast
+           a = complete_eval(a, child_context)
+
+           call_op(:cast, a, child_context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
+         when :sign
+           a = complete_eval(a, child_context)
+
+           func = lambda { |x, _b|
+             if x.zero? || (x.is_a?(Float) && x.nan?)
+               0
+             elsif x < 0
+               -1
+             elsif x > 0
+               1
+             else
+               fail 'assert: cannot be here'
+             end
+           }
+
+           call_op(:sign, a, child_context, func)
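+
+         # Dispatch note: unary ops go through call_op and binary ops through
+         # call_vector_op; both apply their lambda elementwise and fall back to
+         # the symbolic TensorStream op when an operand is still an unevaluated
+         # Tensor (see call_op/call_vector_op below).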
+         when :equal
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+
+           call_vector_op(:equal, a, b, child_context, ->(t, u) { t == u })
+         when :not_equal
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+
+           call_vector_op(:not_equal, a, b, child_context, ->(t, u) { t != u })
+         when :index
+           f = run(a, child_context)
+           index = run(b, child_context)
+
+           f[index]
+         when :slice
+           input = complete_eval(a, child_context)
+           start = complete_eval(b, child_context)
+           size = complete_eval(tensor.options[:size], child_context)
+           fail "start index and size not of the same shape #{start.size} != #{size.size}" if start.size != size.size
+           slice_tensor(input, start, size)
+         when :negate
+           call_vector_op(:negate, a, nil, child_context, ->(t, _u) { -t })
+         when :add
+           call_vector_op(:add, a, b, child_context, ->(t, u) { t + u })
+         when :sub
+           call_vector_op(:sub, a, b, child_context, ->(t, u) { t - u })
+         when :mul
+           call_vector_op(:mul, a, b, child_context, ->(t, u) { t * u })
+         when :pow
+           call_vector_op(:pow, a, b, child_context, ->(t, u) { t**u })
+         when :concat
+           values = complete_eval(a, child_context)
+           concat_array(values, tensor.options[:axis])
+         when :abs
+           call_op(:abs, a, child_context, ->(t, _b) { t.abs })
+         when :tanh
+           call_op(:tanh, a, child_context, ->(t, _b) { Math.tanh(t) })
+         when :tan
+           call_op(:tan, a, child_context, ->(t, _b) { Math.tan(t) })
+         when :sec
+           # Ruby's Math module has no sec(), so compute the secant directly
+           call_op(:sec, a, child_context, ->(t, _b) { 1.0 / Math.cos(t) })
+         when :sin
+           call_op(:sin, a, child_context, ->(t, _b) { Math.sin(t) })
+         when :cos
+           call_op(:cos, a, child_context, ->(t, _b) { Math.cos(t) })
+         when :log
+           call_op(:log, a, child_context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
+         when :exp
+           call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
+         when :sqrt
+           call_op(:sqrt, a, child_context, ->(t, _b) { Math.sqrt(t) })
+         when :square
+           call_op(:square, a, child_context, ->(t, _b) { t * t })
+         when :stop_gradient
+           run(a, child_context)
+         when :random_uniform
+           maxval = tensor.options.fetch(:maxval, 1)
+           minval = tensor.options.fetch(:minval, 0)
+
+           generator = -> { rand * (maxval - minval) + minval }
+           generate_vector(tensor.options[:shape], generator: generator)
+         when :random_normal
+           r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev))
+           generator = -> { r.rand }
+
+           generate_vector(tensor.options[:shape], generator: generator)
+         when :flow_group
+           threads = tensor.items.collect { |item| Concurrent::Future.execute(executor: @thread_pool) { run(item, child_context) } }
+           threads.collect(&:value)
+         when :assign
+           assign = tensor.items[0] || tensor
+           assign.value = complete_eval(tensor.items[1], child_context)
+           assign.value
+         when :assign_add
+           tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t + u })
+           tensor.items[0].value
+         when :assign_sub
+           tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t - u })
+           tensor.items[0].value
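+
+         # The reductions below fold values along an axis via reduce_axis;
+         # e.g. reduce_sum on [[1, 2], [3, 4]] gives [4, 6] for axis 0,
+         # [3, 7] for axis 1, and 10 when no axis is given.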
+         when :reduce_mean
+           c = tensor.data_type == :float ? 0.0 : 0
+           func = lambda { |v|
+             if v.is_a?(Array)
+               v.empty? ? c : (v.reduce(:+) / v.size)
+             else
+               v
+             end
+           }
+
+           reduction(child_context, tensor, func)
+         when :reduce_sum
+           c = tensor.data_type == :float ? 0.0 : 0
+           func = lambda { |v|
+             if v.is_a?(Array)
+               v.empty? ? c : v.reduce(:+)
+             else
+               v
+             end
+           }
+
+           reduction(child_context, tensor, func)
+         when :reduce_prod
+           c = tensor.data_type == :float ? 1.0 : 1
+           func = lambda { |v|
+             if v.is_a?(Array)
+               v.empty? ? c : v.reduce(:*)
+             else
+               v
+             end
+           }
+
+           reduction(child_context, tensor, func)
+         when :transpose
+           matrix_a = complete_eval(a, child_context)
+           matrix_a.transpose
+         when :eye
+           rows = complete_eval(a, child_context)
+           columns = complete_eval(b, child_context)
+
+           Array.new(rows) do |i|
+             Array.new(columns) do |col|
+               if tensor.data_type == :float32
+                 i == col ? 1.0 : 0.0
+               else
+                 i == col ? 1 : 0
+               end
+             end
+           end
+         when :cond
+           pred = complete_eval(tensor.options[:pred], child_context)
+
+           if is_all_true(pred)
+             complete_eval(a, child_context)
+           else
+             complete_eval(b, child_context)
+           end
+         when :where
+           pred = complete_eval(tensor.options[:pred], child_context)
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+
+           call_3way_vector_op(pred, a, b, child_context, ->(t, u, v) { t ? u : v })
+         when :less
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+
+           call_vector_op(:less, a, b, child_context, ->(t, u) { t < u })
+         when :greater
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+
+           call_vector_op(:greater, a, b, child_context, ->(t, u) { t > u })
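+
+         # `where` above selects elementwise (pred [true, false] with a [1, 2]
+         # and b [9, 9] yields [1, 9]); `cond` instead picks a whole branch,
+         # and only when every element of pred is true.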
+         when :zeros, :ones, :zeros_like, :ones_like
+           shape = if %i[zeros_like ones_like].include?(tensor.operation)
+                     a = complete_eval(a, child_context)
+                     shape_eval(a)
+                   else
+                     complete_eval(a, child_context) || tensor.shape.shape
+                   end
+
+           func = if %i[zeros zeros_like].include?(tensor.operation)
+                    -> { tensor.data_type == :int32 ? 0 : 0.0 }
+                  else
+                    -> { tensor.data_type == :int32 ? 1 : 1.0 }
+                  end
+
+           if shape.is_a?(Array) && shape.size.zero?
+             func.call
+           else
+             shape = [shape.to_i] unless shape.is_a?(Array)
+             generate_vector(shape, generator: func)
+           end
+         when :shape
+           input = complete_eval(a, child_context)
+
+           shape_eval(input)
+         when :matmul
+           matrix_a = complete_eval(a, child_context)
+           matrix_b = complete_eval(b, child_context)
+
+           rank_a = get_rank(matrix_a)
+           rank_b = get_rank(matrix_b)
+
+           raise "#{a.name} rank must be greater than 1" if rank_a < 2
+           raise "#{b.name} rank must be greater than 1" if rank_b < 2
+
+           matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
+           matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
+
+           # handle matrix multiplication with constants like 1 or 0
+           matrix_a = matmul_const_transform(matrix_a, matrix_b, tensor)
+           matrix_b = matmul_const_transform(matrix_b, matrix_a, tensor)
+
+           # check matrix dimensions
+           raise "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size
+
+           (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
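+
+         # matmul_const_transform (defined below) lifts a bare scalar operand
+         # into a compatibly shaped matrix so that constants like 1 or 0 can
+         # take part in the stdlib Matrix multiplication above.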
+         when :gradients
+           b.collect do |xs|
+             fail "#{xs} passed is not a tensor object" unless xs.is_a?(Tensor)
+             xs_val = complete_eval(xs, child_context)
+             target_shape = shape_eval(xs_val)
+
+             stops = tensor.options[:stop_gradients] ? tensor.options[:stop_gradients].map(&:name).join('_') : ''
+             gradient_program_name = "grad_#{tensor.name}_#{xs.name}_#{stops}".to_sym
+
+             tensor_program = if tensor.graph.node_added?(gradient_program_name)
+                                tensor.graph.get_node(gradient_program_name)
+                              else
+                                derivative_ops = TensorStream::MathGradients.derivative(a, xs, graph: tensor.graph, stop_gradients: tensor.options[:stop_gradients], target_shape: target_shape)
+                                unit_matrix = op(:ones_like, xs)
+                                tensor.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
+                              end
+
+             complete_eval(tensor_program, child_context)
+           end
+         when :identity
+           complete_eval(a, child_context)
+         when :print
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+           puts "#{tensor.options.fetch(:message, '')} #{b}"
+           a
+         when :rank
+           a = complete_eval(a, child_context)
+           get_rank(a)
+         when :div
+           process_vector_math_op(a, b, child_context, ->(t, u) { t / u })
+         when :reshape
+           arr = complete_eval(a, child_context)
+           new_shape = complete_eval(b, child_context)
+
+           flat_arr = arr.flatten
+           return flat_arr[0] if new_shape.empty? && flat_arr.size == 1
+
+           new_shape = fix_inferred_elements(new_shape, flat_arr.size)
+
+           reshape(flat_arr, new_shape)
+         when :pad
+           a = complete_eval(a, child_context)
+           paddings = complete_eval(tensor.options[:paddings], child_context)
+
+           arr_pad(a, paddings, tensor.data_type)
+         when :max
+           a = complete_eval(a, child_context)
+           b = complete_eval(b, child_context)
+
+           call_vector_op(:max, a, b, child_context, ->(t, u) { [t, u].max })
+         else
+           fail "unknown op #{tensor.operation}"
+         end.tap do |result|
+           if tensor.breakpoint
+             a = complete_eval(a, child_context)
+             b = complete_eval(b, child_context)
+
+             tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
+           end
+           @context[tensor.name] = result
+         end
+       rescue EvaluatorExcecutionException => e
+         raise e
+       rescue StandardError => e
+         puts e.backtrace.join("\n")
+         raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math} defined at #{tensor.source}"
+       end
+
+       def eval_tensor(tensor, child_context)
+         return tensor unless tensor.is_a?(Tensor)
+         return @context[tensor.name] if @context.key?(tensor.name)
+
+         if tensor.value.is_a?(Array)
+           tensor.value.collect do |item|
+             item.is_a?(Tensor) ? run(item, child_context) : item
+           end
+         else
+           tensor.value.is_a?(Tensor) ? run(tensor.value, child_context) : tensor.value
+         end.tap do |result|
+           @context[tensor.name] = result
+         end
+       end
+
+       private
+
+       def get_max_with_axis(a, target_axis, current_axis, output_type)
+         if target_axis == current_axis
+           if a[0].is_a?(Array)
+             (0...a[0].size).each.collect do |column_index|
+               max = nil
+               max_index = 0
+               a.each_with_index do |row, row_index|
+                 if max.nil? || row[column_index] > max
+                   max = row[column_index]
+                   max_index = row_index
+                 end
+               end
+
+               Tensor.cast_dtype(max_index, output_type)
+             end
+           else
+             max = nil
+             max_index = 0
+             a.each_with_index do |x, index|
+               if max.nil? || x > max
+                 max = x
+                 max_index = index
+               end
+             end
+             Tensor.cast_dtype(max_index, output_type)
+           end
+         else
+           a.collect do |row|
+             get_max_with_axis(row, target_axis, current_axis + 1, output_type)
+           end
+         end
+       end
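+
+       # e.g. get_max_with_axis([[1, 9], [3, 4]], 0, 0, :int32) compares down
+       # each column and returns the row index of each maximum: [1, 0]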
+
+       def reduction(child_context, tensor, func)
+         val = complete_eval(tensor.items[0], child_context)
+         axis = tensor.options[:axis]
+         keep_dims = tensor.options[:keepdims]
+
+         if axis.is_a?(Array)
+           axis.each do |x|
+             val = reduce_axis(x, val, keep_dims, child_context, func)
+           end
+
+           func.call(val.flatten)
+         else
+           reduce_axis(axis, val, keep_dims, child_context, func)
+         end
+       end
+
+       def arr_pad(arr, paddings, data_type = :float32, rank = 0)
+         fail "padding #{paddings[rank]} needs to have two elements [before, after]" if paddings[rank].size != 2
+
+         before = paddings[rank][0]
+         after = paddings[rank][1]
+
+         if arr[0].is_a?(Array)
+           next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
+           padding = deep_dup_array(next_dim_elem[0], data_type == :float32 ? 0.0 : 0)
+           before.times.map { padding } + next_dim_elem + after.times.map { padding }
+         else
+           before.times.map { data_type == :float32 ? 0.0 : 0 } + arr + after.times.map { data_type == :float32 ? 0.0 : 0 }
+         end
+       end
+
+       def deep_dup_array(arr, value = nil)
+         if arr.is_a?(Array)
+           arr.dup.collect do |a|
+             deep_dup_array(a, value)
+           end
+         else
+           value.nil? ? arr : value
+         end
+       end
+
+       def slice_tensor(input, start, size)
+         start_index = start.shift
+         dimen_size = start_index + size.shift
+
+         input[start_index...dimen_size].collect do |item|
+           if item.is_a?(Array)
+             slice_tensor(item, start.dup, size.dup)
+           else
+             item
+           end
+         end
+       end
+
+       def matmul_const_transform(mat, mat_b, tensor)
+         if !mat.is_a?(Array)
+           compat_shape = shape_eval(mat_b).reverse
+           func = -> { tensor.data_type == :int32 ? mat.to_i : mat.to_f }
+
+           generate_vector(compat_shape, generator: func)
+         else
+           mat
+         end
+       end
+
+       def fix_inferred_elements(shape, total_size)
+         return shape if shape.empty?
+
+         current_size = shape.inject(1) { |product, n| n > 0 ? product * n : product }
+         inferred_size = total_size / current_size
+         shape.map { |s| s == -1 ? inferred_size : s }
+       end
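+
+       # A -1 dimension is inferred from the total element count, e.g.
+       #   fix_inferred_elements([-1, 2], 6) #=> [3, 2]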
+
+       def reshape(arr, new_shape)
+         return arr if new_shape.empty?
+
+         s = new_shape.shift
+
+         if new_shape.empty?
+           fail "reshape dimension mismatch #{arr.size} != #{s}" if arr.size != s
+           return arr
+         end
+
+         dim = (arr.size / s)
+         arr.each_slice(dim).collect do |slice|
+           reshape(slice, new_shape.dup)
+         end
+       end
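+
+       # e.g. reshape([1, 2, 3, 4, 5, 6], [2, 3]) #=> [[1, 2, 3], [4, 5, 6]]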
+
+       def call_op(op, a, child_context, func)
+         a = complete_eval(a, child_context)
+         process_function_op(a, child_context, func)
+       rescue FullEvalNotPossible
+         TensorStream.send(op.to_sym, a)
+       end
+
+       def call_vector_op(op, a, b, child_context, func)
+         process_vector_math_op(a, b, child_context, func)
+       rescue FullEvalNotPossible
+         TensorStream.send(op.to_sym, a, b)
+       end
+
+       def process_vector_math_op(a, b, child_context, op)
+         eval_a = complete_eval(a, child_context) unless a.nil?
+         eval_b = complete_eval(b, child_context) unless b.nil?
+
+         fail FullEvalNotPossible.new, "full eval not possible for #{a.name}" if eval_a.is_a?(Tensor) || eval_b.is_a?(Tensor)
+
+         # ruby scalar
+         if get_rank(eval_a) == 0
+           if get_rank(eval_b) == 0
+             op.call(eval_a, eval_b)
+           else
+             constant_op(eval_b, eval_a, child_context, op, true)
+           end
+         elsif get_rank(eval_a) > 0
+           if get_rank(eval_b) > 0
+             vector_op(eval_a, eval_b, child_context, op)
+           else
+             constant_op(eval_a, eval_b, child_context, op)
+           end
+         end
+       end
+
+       def get_rank(value, rank = 0)
+         return rank unless value.is_a?(Array)
+         return rank + 1 if value.empty?
+
+         get_rank(value[0], rank + 1)
+       end
+
+       def concat_array(values, axis)
+         combined_array = values.shift
+         axis = get_rank(combined_array) - 1 if axis == -1
+
+         values.each do |v|
+           combined_array = concat(combined_array, v, axis)
+         end
+         combined_array
+       end
+
+       def concat(a, b, axis)
+         if axis == 0
+           a + b
+         else
+           a.each_with_index.collect do |sub, index|
+             concat(sub, b[index], axis - 1)
+           end
+         end
+       end
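+
+       # Concatenation recurses down to the requested axis, e.g.
+       #   concat([1, 2], [3, 4], 0)          #=> [1, 2, 3, 4]
+       #   concat([[1], [2]], [[3], [4]], 1)  #=> [[1, 3], [2, 4]]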
+
+       def process_function_op(a, child_context, op)
+         # ruby scalar
+         if (a.is_a?(Tensor) && a.shape.rank > 0) || a.is_a?(Array)
+           constant_op(a, 0, child_context, op)
+         elsif !a.is_a?(Tensor) || a.shape.rank == 0
+           v = run(a, child_context)
+           fail FullEvalNotPossible.new, "full eval not possible for #{v.name}" if v.is_a?(Tensor) && !v.is_const
+
+           op.call(v, 0)
+         else
+           fail 'cannot be here'
+         end
+       end
603
+ def resolve_placeholder(placeholder, execution_context = {})
604
+ return nil if placeholder.nil?
605
+ return placeholder if retain.include?(placeholder)
606
+
607
+ var = if placeholder.kind_of?(Placeholder)
608
+ @context[placeholder.name.to_sym].tap do |c|
609
+ if c.nil?
610
+ raise "missing placeholder #{placeholder.name}"
611
+ end
612
+ end
613
+ else
614
+ placeholder
615
+ end
616
+
617
+ return var unless placeholder.kind_of?(Tensor)
618
+ Tensor.cast_dtype(var, placeholder.data_type)
619
+ end
+
+       def reduce_axis(axis, val, keep_dims, child_context, op = ->(v) { v.is_a?(Array) ? v.reduce(:+) : v })
+         val = run(val, child_context)
+         return val.is_a?(Array) ? op.call(val.flatten) : val if axis.nil?
+         return val.transpose.collect { |v| keep_dims ? [op.call(v)] : op.call(v) } if axis == 0
+         return val.collect { |v| keep_dims ? [op.call(v)] : op.call(v) } if axis == 1
+
+         fail "can't handle axis > 1 :("
+       end
+
+       def constant_add(vector, constant)
+         # run requires an execution context; none is threaded through here
+         run(vector, {}).collect do |item|
+           if item.is_a?(Array)
+             constant_add(item, constant)
+           elsif item.respond_to?(:value)
+             item.value + constant
+           else
+             item + constant
+           end
+         end
+       end
+
+       def constant_op(vector, constant, child_context, op = ->(a, b) { a + b }, switch = false)
+         eval_vector = complete_eval(vector, child_context)
+         constant = complete_eval(constant, child_context)
+
+         fail FullEvalNotPossible.new, "full eval not possible for #{eval_vector.name}" if eval_vector.is_a?(Tensor) || constant.is_a?(Tensor)
+
+         eval_vector.each_with_index.collect do |item, index|
+           c = constant.is_a?(Array) ? constant[index] : constant
+           if item.is_a?(Array)
+             constant_op(item, c, child_context, op, switch)
+           elsif item.respond_to?(:value)
+             switch ? op.call(c, item.value) : op.call(item.value, c)
+           else
+             switch ? op.call(c, item) : op.call(item, c)
+           end
+         end
+       end
+
+       def call_3way_vector_op(v_a, v_b, v_c, child_context, op = ->(a, b, c) { a + b + c })
+         return op.call(v_a, v_b, v_c) unless v_a.is_a?(Array)
+
+         v_a.each_with_index.collect do |v1, index|
+           v2 = v_b[index]
+           v3 = v_c[index]
+           if v1.is_a?(Array)
+             call_3way_vector_op(v1, v2, v3, child_context, op)
+           else
+             op.call(v1, v2, v3)
+           end
+         end
+       end
+
+       def vector_op(vector, vector2, child_context, op = ->(a, b) { a + b })
+         v_a = run(vector, child_context)
+         v_b = run(vector2, child_context)
+
+         if get_rank(v_a) < get_rank(v_b) # upgrade rank of A
+           duplicated = v_b.size.times.collect do
+             v_a
+           end
+           return vector_op(duplicated, v_b, child_context, op)
+         end
+
+         v_a.each_with_index.collect do |item, index|
+           next vector_op(item, v_b, child_context, op) if item.is_a?(Array) && get_rank(v_a) > get_rank(v_b)
+
+           z = index < v_b.size ? v_b[index] : v_b[0]
+
+           if item.is_a?(Array)
+             constant_op(item, z, child_context, op)
+           else
+             item.respond_to?(:value) ? op.call(item.value, z.value) : op.call(item, z)
+           end
+         end
+       end
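+
+       # Broadcasting sketch: the lower-rank operand is duplicated up to the
+       # higher rank (or recursed into), so with the default addition op
+       #   vector_op([[1, 2], [3, 4]], [10, 20], {}) #=> [[11, 22], [13, 24]]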
+
+       def is_all_true(arr)
+         return arr.all? { |a| is_all_true(a) } if arr.is_a?(Array)
+
+         !!arr
+       end
+
+       def vector_add(vector, vector2, child_context)
+         v_a = run(vector, child_context)
+         v_b = run(vector2, child_context)
+
+         v_a.each_with_index.collect do |item, index|
+           if item.is_a?(Array)
+             # recurse into nested rows, pairing each with its counterpart in v_b
+             vector_add(item, v_b[index], child_context)
+           elsif item.respond_to?(:value)
+             item.value + v_b[index].value
+           else
+             item + v_b[index]
+           end
+         end
+       end
+
+       def generate_vector(shape, dtype: :float32, generator:)
+         if shape.is_a?(Integer)
+           shape.times.collect do
+             generator.call
+           end
+         elsif shape.size > 1
+           shape[0].times.collect do
+             generate_vector(shape[1..-1], generator: generator, dtype: dtype)
+           end
+         elsif shape.size == 1
+           shape[0].times.collect do
+             generator.call
+           end
+         elsif shape.size.zero?
+           generator.call
+         end
+       end
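+
+       # e.g. generate_vector([2, 2], generator: -> { 0.0 }) #=> [[0.0, 0.0], [0.0, 0.0]]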
747
+ end
748
+ end
749
+ end