tensor_stream 0.4.1 → 0.5.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (62) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +38 -17
  4. data/benchmark/benchmark.rb +16 -20
  5. data/lib/tensor_stream/control_flow.rb +3 -3
  6. data/lib/tensor_stream/debugging/debugging.rb +4 -4
  7. data/lib/tensor_stream/device.rb +5 -2
  8. data/lib/tensor_stream/evaluator/base_evaluator.rb +138 -0
  9. data/lib/tensor_stream/evaluator/buffer.rb +7 -2
  10. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_bool_operand.cl +3 -3
  11. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/_operand.cl +0 -0
  12. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/abs.cl +0 -0
  13. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/add.cl +1 -1
  14. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmax.cl +0 -0
  15. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/argmin.cl +0 -0
  16. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cast.cl +0 -0
  17. data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +6 -0
  18. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/cos.cl +0 -0
  19. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/div.cl.erb +1 -1
  20. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/exp.cl +0 -0
  21. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/gemm.cl +0 -0
  22. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log.cl +0 -0
  23. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/log1p.cl +0 -0
  24. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/max.cl +3 -3
  25. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/mul.cl +1 -1
  26. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/negate.cl +0 -0
  27. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/pow.cl +3 -3
  28. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/reciprocal.cl +0 -0
  29. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/round.cl +0 -0
  30. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid.cl +0 -0
  31. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sigmoid_grad.cl +3 -3
  32. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sign.cl +1 -1
  33. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sin.cl +0 -0
  34. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax.cl +0 -0
  35. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/softmax_grad.cl +0 -0
  36. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sqrt.cl +0 -0
  37. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/square.cl +0 -0
  38. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/sub.cl +1 -1
  39. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tan.cl +0 -0
  40. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh.cl +0 -0
  41. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/tanh_grad.cl +0 -0
  42. data/lib/tensor_stream/evaluator/{kernels → opencl/kernels}/where.cl +1 -1
  43. data/lib/tensor_stream/evaluator/{opencl_buffer.rb → opencl/opencl_buffer.rb} +1 -1
  44. data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +5 -0
  45. data/lib/tensor_stream/evaluator/{opencl_evaluator.rb → opencl/opencl_evaluator.rb} +404 -452
  46. data/lib/tensor_stream/evaluator/{opencl_template_helper.rb → opencl/opencl_template_helper.rb} +6 -6
  47. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +21 -21
  48. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +492 -398
  49. data/lib/tensor_stream/graph.rb +21 -1
  50. data/lib/tensor_stream/graph_serializers/graphml.rb +59 -59
  51. data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -1
  52. data/lib/tensor_stream/helpers/op_helper.rb +6 -2
  53. data/lib/tensor_stream/math_gradients.rb +7 -7
  54. data/lib/tensor_stream/operation.rb +100 -100
  55. data/lib/tensor_stream/session.rb +81 -8
  56. data/lib/tensor_stream/tensor.rb +7 -5
  57. data/lib/tensor_stream/utils.rb +32 -19
  58. data/lib/tensor_stream/version.rb +1 -1
  59. data/tensor_stream.gemspec +0 -1
  60. data/test_samples/raw_neural_net_sample.rb +7 -7
  61. metadata +41 -53
  62. data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +0 -5
@@ -14,7 +14,7 @@ class OpenclTemplateHelper
14
14
  ERB.new(@source, nil, '%').result(current_scope)
15
15
  end
16
16
 
17
- def is_floating_point?(dtype)
17
+ def floating_point?(dtype)
18
18
  TensorStream::Ops::FLOATING_POINT_TYPES.include?(dtype)
19
19
  end
20
20
 
@@ -22,14 +22,14 @@ class OpenclTemplateHelper
22
22
  filename = File.join(File.dirname(__FILE__), 'kernels', "_#{template}")
23
23
  source = File.read(filename)
24
24
  current_scope = binding
25
- locals.each do |k,v|
25
+ locals.each do |k, v|
26
26
  current_scope.local_variable_set(k.to_sym, v)
27
27
  end
28
28
  ERB.new(source, nil, '%').result(current_scope)
29
29
  end
30
30
 
31
31
  def dtype_to_c_type(dtype)
32
- case(dtype.to_s)
32
+ case dtype.to_s
33
33
  when 'float64'
34
34
  'double'
35
35
  when 'float32', 'float'
@@ -39,14 +39,14 @@ class OpenclTemplateHelper
39
39
  when 'int16'
40
40
  'short'
41
41
  when 'boolean'
42
- 'int'
42
+ 'short'
43
43
  else
44
44
  raise "unknown dtype #{dtype}"
45
45
  end
46
46
  end
47
47
 
48
48
  def min_value_for(dtype)
49
- case(dtype.to_s)
49
+ case dtype.to_s
50
50
  when 'float64'
51
51
  'DBL_MIN'
52
52
  when 'float32', 'float'
@@ -63,7 +63,7 @@ class OpenclTemplateHelper
63
63
  end
64
64
 
65
65
  def operator_to_c(op)
66
- case(op)
66
+ case op
67
67
  when 'less'
68
68
  '<'
69
69
  when 'less_equal'
@@ -6,11 +6,11 @@ module TensorStream
6
6
  start_index = start.shift
7
7
  dimen_size = start_index + size.shift
8
8
 
9
- input[start_index...dimen_size].collect do |item|
10
- if item.is_a?(Array)
11
- slice_tensor(item, start.dup, size.dup)
9
+ input[start_index...dimen_size].collect do |input|
10
+ if input.is_a?(Array)
11
+ slice_tensor(input, start.dup, size.dup)
12
12
  else
13
- item
13
+ input
14
14
  end
15
15
  end
16
16
  end
@@ -72,8 +72,8 @@ module TensorStream
72
72
  d = dims.shift
73
73
 
74
74
  if input.is_a?(Array) && (get_rank(input) - 1) == dims.size
75
- row_to_dup = input.collect do |item|
76
- broadcast_dimensions(item, dims.dup)
75
+ row_to_dup = input.collect do |input|
76
+ broadcast_dimensions(input, dims.dup)
77
77
  end
78
78
 
79
79
  row_to_dup + Array.new(d) { row_to_dup }.flatten(1)
@@ -95,8 +95,8 @@ module TensorStream
95
95
 
96
96
  return op.call(vector, vector2) unless vector.is_a?(Array)
97
97
 
98
- vector.each_with_index.collect do |item, index|
99
- next vector_op(item, vector2, op, switch) if item.is_a?(Array) && get_rank(vector) > get_rank(vector2)
98
+ vector.each_with_index.collect do |input, index|
99
+ next vector_op(input, vector2, op, switch) if input.is_a?(Array) && get_rank(vector) > get_rank(vector2)
100
100
 
101
101
  if safe && vector2.is_a?(Array)
102
102
  next nil if vector2.size != 1 && index >= vector2.size
@@ -113,10 +113,10 @@ module TensorStream
113
113
  vector2
114
114
  end
115
115
 
116
- if item.is_a?(Array)
117
- vector_op(item, z, op, switch)
116
+ if input.is_a?(Array)
117
+ vector_op(input, z, op, switch)
118
118
  else
119
- switch ? op.call(z, item) : op.call(item, z)
119
+ switch ? op.call(z, input) : op.call(input, z)
120
120
  end
121
121
  end.compact
122
122
  end
@@ -173,11 +173,11 @@ module TensorStream
173
173
  arr.map { |a| Math.exp(a - arr.max) }.reduce(:+)
174
174
  end
175
175
 
176
- arr.collect do |item|
177
- if item.is_a?(Array)
178
- softmax(item)
176
+ arr.collect do |input|
177
+ if input.is_a?(Array)
178
+ softmax(input)
179
179
  else
180
- Math.exp(item - arr.max) / sum
180
+ Math.exp(input - arr.max) / sum
181
181
  end
182
182
  end
183
183
  end
@@ -185,15 +185,15 @@ module TensorStream
185
185
  def softmax_grad(arr)
186
186
  return arr if arr.empty?
187
187
 
188
- arr.each_with_index.collect do |item, index|
189
- if item.is_a?(Array)
190
- softmax_grad(item)
188
+ arr.each_with_index.collect do |input, index|
189
+ if input.is_a?(Array)
190
+ softmax_grad(input)
191
191
  else
192
- arr.each_with_index.collect do |item2, index2|
192
+ arr.each_with_index.collect do |input2, index2|
193
193
  if index != index2
194
- -item * item2
194
+ -input * input2
195
195
  else
196
- item * (1.0 - item)
196
+ input * (1.0 - input)
197
197
  end
198
198
  end
199
199
  end
@@ -1,7 +1,7 @@
1
1
  require 'tensor_stream/evaluator/operation_helpers/random_gaussian'
2
2
  require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
3
3
  require 'tensor_stream/evaluator/operation_helpers/math_helper'
4
- require 'distribution'
4
+ require 'tensor_stream/evaluator/base_evaluator'
5
5
 
6
6
  module TensorStream
7
7
  module Evaluator
@@ -23,30 +23,18 @@ module TensorStream
23
23
  end
24
24
 
25
25
  ## PURE ruby evaluator used for testing and development
26
- class RubyEvaluator
26
+ class RubyEvaluator < BaseEvaluator
27
27
  attr_accessor :retain
28
28
 
29
29
  include TensorStream::OpHelper
30
30
  include TensorStream::ArrayOpsHelper
31
31
  include TensorStream::MathHelper
32
32
 
33
- def initialize(session, context, thread_pool: nil, log_intermediates: false)
34
- @session = session
35
- @context = context
36
- @log_intermediates = log_intermediates
37
- @retain = context[:retain] || []
38
- @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
39
-
40
- @context[:compute_history] = [] if log_intermediates
41
- end
42
-
43
33
  def run(tensor, execution_context)
44
34
  if tensor.is_a?(Array) && tensor.size > 0 && tensor[0].is_a?(Tensor)
45
35
  return tensor.map { |t| run(t, execution_context) }
46
36
  end
47
37
 
48
- return tensor if retain.include?(tensor) # if var is in retain don't eval to value
49
-
50
38
  tensor = tensor.call if tensor.is_a?(Proc)
51
39
 
52
40
  child_context = execution_context.dup
@@ -63,6 +51,13 @@ module TensorStream
63
51
  res
64
52
  end
65
53
 
54
+ def run_with_buffer(tensor, context, execution_context)
55
+ @context = context
56
+ @context[:_cache][:_cl_buffers] ||= {} if context[:_cache]
57
+ result = run(tensor, execution_context)
58
+ TensorStream::Buffer.new(data_type: tensor.data_type, buffer: result)
59
+ end
60
+
66
61
  def complete_eval(tensor, context)
67
62
  Kernel.loop do
68
63
  old_tensor = tensor
@@ -77,6 +72,18 @@ module TensorStream
77
72
 
78
73
  protected
79
74
 
75
+ def prepare_input(tensor, context, options = {})
76
+ return nil unless tensor
77
+ tensor = resolve_placeholder(tensor)
78
+ if options[:noop]
79
+ tensor
80
+ elsif options[:no_eval]
81
+ run(tensor, context)
82
+ else
83
+ complete_eval(tensor, context)
84
+ end
85
+ end
86
+
80
87
  def eval_variable(tensor, child_context)
81
88
  value = tensor.read_value
82
89
  if value.nil?
@@ -89,404 +96,486 @@ module TensorStream
89
96
  end
90
97
  end
91
98
 
92
- def eval_operation(tensor, child_context)
93
- return @context[tensor.name] if @context.key?(tensor.name)
94
- a = resolve_placeholder(tensor.items[0], child_context) if tensor.items && tensor.items[0]
95
- b = resolve_placeholder(tensor.items[1], child_context) if tensor.items && tensor.items[1]
96
- # puts tensor.name
97
- case tensor.operation
98
- when :const
99
- complete_eval(a, child_context)
100
- when :argmax
101
- a = complete_eval(a, child_context)
102
- axis = tensor.options[:axis] || 0
103
-
104
- get_op_with_axis(a, axis, 0, tensor.data_type)
105
- when :argmin
106
- a = complete_eval(a, child_context)
107
- axis = tensor.options[:axis] || 0
108
-
109
- get_op_with_axis(a, axis, 0, tensor.data_type, ->(a, b) { a < b })
110
- when :cast
111
- a = complete_eval(a, child_context)
112
-
113
- call_op(:cast, a, child_context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
114
- when :sign
115
- a = complete_eval(a, child_context)
116
-
117
- func = lambda { |x, _b|
118
- if x.zero? || (x.is_a?(Float) && x.nan?)
119
- 0
120
- elsif x < 0
121
- -1
122
- elsif x > 0
123
- 1
124
- else
125
- raise 'assert: cannot be here'
126
- end
127
- }
128
-
129
- call_op(:sign, a, child_context, func)
130
- when :logical_and
131
- a = complete_eval(a, child_context)
132
- b = complete_eval(b, child_context)
133
-
134
- call_vector_op(:greater, a, b, child_context, ->(t, u) { t && u })
135
- when :equal
136
- a = complete_eval(a, child_context)
137
- b = complete_eval(b, child_context)
138
-
139
- call_vector_op(:greater, a, b, child_context, ->(t, u) { t == u })
140
- when :not_equal
141
- a = complete_eval(a, child_context)
142
- b = complete_eval(b, child_context)
143
-
144
- call_vector_op(:not_equal, a, b, child_context, ->(t, u) { t != u })
145
- when :index
146
- f = run(a, child_context)
147
- index = run(b, child_context)
148
-
149
- f[index]
150
- when :slice
151
- input = complete_eval(a, child_context)
152
- start = complete_eval(b, child_context)
153
- size = complete_eval(tensor.options[:size], child_context)
154
- raise "start index and size not of the same shape #{start.size} != #{size.size}" if start.size != size.size
155
- slice_tensor(input, start, size)
156
- when :negate
157
- call_vector_op(:negate, a, nil, child_context, ->(t, _u) { -t })
158
- when :add
159
- call_vector_op(:add, a, b, child_context, ->(t, u) { t + u })
160
- when :sub
161
- call_vector_op(:sub, a, b, child_context, ->(t, u) { t - u })
162
- when :mul
163
- call_vector_op(:mul, a, b, child_context, ->(t, u) { t * u })
164
- when :pow
165
- call_vector_op(:pow, a, b, child_context, ->(t, u) { t**u })
166
- when :concat
167
- values = complete_eval(a, child_context)
168
- concat_array(values, tensor.options[:axis])
169
- when :round
170
- call_op(:round, a, child_context, ->(t, _b) { t.round })
171
- when :abs
172
- call_op(:abs, a, child_context, ->(t, _b) { t.abs })
173
- when :tanh
174
- call_op(:tanh, a, child_context, ->(t, _b) { Math.tanh(t) })
175
- when :tan
176
- call_op(:tan, a, child_context, ->(t, _b) { Math.tan(t) })
177
- when :sec
178
- call_op(:sec, a, child_context, ->(t, _b) { Math.sec(t) })
179
- when :sin
180
- call_op(:sin, a, child_context, ->(t, _b) { Math.sin(t) })
181
- when :cos
182
- call_op(:cos, a, child_context, ->(t, _b) { Math.cos(t) })
183
- when :log1p
184
- call_op(:log1p, a, child_context, ->(t, _b) { Distribution::MathExtension::Log.log1p(t) })
185
- when :log
186
- call_op(:log, a, child_context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
187
- when :exp
188
- call_op(:exp, a, child_context, ->(t, _b) { Math.exp(t) })
189
- when :sigmoid
190
- call_op(:sigmoid, a, child_context, ->(t, _b) { sigmoid(t) })
191
- when :sigmoid_grad
192
- call_vector_op(:sigmoid_grad, a, b, child_context, ->(t, u) { u * sigmoid(t) * (1 - sigmoid(t)) })
193
- when :sqrt
194
- call_op(:exp, a, child_context, ->(t, _b) { Math.sqrt(t) })
195
- when :square
196
- call_op(:square, a, child_context, ->(t, _b) { t * t })
197
- when :reciprocal
198
- call_op(:square, a, child_context, ->(t, _b) { 1 / t })
199
- when :stop_gradient
200
- run(a, child_context)
201
- when :random_uniform
202
- maxval = tensor.options.fetch(:maxval, 1)
203
- minval = tensor.options.fetch(:minval, 0)
204
- seed = tensor.options[:seed]
205
-
206
- random = _get_randomizer(tensor, seed)
207
- generator = -> { random.rand * (maxval - minval) + minval }
208
- shape = tensor.options[:shape] || tensor.shape.shape
209
- generate_vector(shape, generator: generator)
210
- when :random_normal
211
- random = _get_randomizer(tensor, seed)
212
- r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
213
- random = _get_randomizer(tensor, seed)
214
- generator = -> { r.rand }
215
- shape = tensor.options[:shape] || tensor.shape.shape
216
- generate_vector(shape, generator: generator)
217
- when :glorot_uniform
218
- random = _get_randomizer(tensor, seed)
219
-
220
- shape = tensor.options[:shape] || tensor.shape.shape
221
- fan_in, fan_out = if shape.size.zero?
222
- [1, 1]
223
- elsif shape.size == 1
224
- [1, shape[0]]
225
- else
226
- [shape[0], shape.last]
227
- end
228
-
229
- limit = Math.sqrt(6.0 / (fan_in + fan_out))
230
-
231
- minval = -limit
232
- maxval = limit
233
-
234
- generator = -> { random.rand * (maxval - minval) + minval }
235
- generate_vector(shape, generator: generator)
236
- when :flow_group
237
- tensor.items.collect { |item| run(item, child_context) }
238
- when :assign
239
- assign = tensor.items[0] || tensor
240
- assign.value = complete_eval(tensor.items[1], child_context)
241
- assign.value
242
- when :assign_add
243
- tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t + u })
244
- tensor.items[0].value
245
- when :assign_sub
246
- tensor.items[0].value = process_vector_math_op(tensor.items[0], tensor.items[1], child_context, ->(t, u) { t - u })
247
- tensor.items[0].value
248
- when :mean
249
- c = fp_type?(tensor.data_type) ? 0.0 : 0
250
- func = lambda do |arr|
251
- return c if arr.nil?
252
-
253
- reduced_val = arr[0]
254
- arr[1..arr.size].each do |v|
255
- reduced_val = vector_op(reduced_val, v, ->(a, b) { a + b })
256
- end
99
+ register_op(:const) do |context, _tensor, inputs|
100
+ inputs[0]
101
+ end
257
102
 
258
- vector_op(reduced_val, nil, ->(a, _b) { a / arr.size })
259
- end
103
+ register_op(:argmax) do |context, tensor, inputs|
104
+ axis = tensor.options[:axis] || 0
105
+ get_op_with_axis(inputs[0], axis, 0, tensor.data_type)
106
+ end
260
107
 
261
- reduction(child_context, tensor, func)
262
- when :sum
263
- c = fp_type?(tensor.data_type) ? 0.0 : 0
264
- func = lambda do |arr|
265
- reduced_val = arr[0]
266
- arr[1..arr.size].each do |v|
267
- reduced_val = vector_op(reduced_val, v, ->(t, u) { t + u })
268
- end
269
- reduced_val
108
+ register_op(:argmin) do |context, tensor, inputs|
109
+ axis = tensor.options[:axis] || 0
110
+ get_op_with_axis(inputs[0], axis, 0, tensor.data_type, ->(a, b) { a < b })
111
+ end
112
+
113
+ register_op(:cast) do |context, tensor, inputs|
114
+ call_op(:cast, inputs[0], context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
115
+ end
116
+
117
+ register_op(:sign) do |context, tensor, inputs|
118
+ func = lambda { |x, _b|
119
+ if x.zero? || (x.is_a?(Float) && x.nan?)
120
+ 0
121
+ elsif x < 0
122
+ -1
123
+ elsif x > 0
124
+ 1
125
+ else
126
+ raise 'assert: cannot be here'
270
127
  end
128
+ }
271
129
 
272
- reduction(child_context, tensor, func)
273
- when :tanh_grad
274
- x = complete_eval(a, child_context)
275
- call_op(:tanh_grad, x, child_context, ->(t, _b) { 1 - Math.tanh(t) * Math.tanh(t) })
276
- when :prod
277
- c = fp_type?(tensor.data_type) ? 1.0 : 1
278
- func = lambda do |arr|
279
- return c if arr.nil?
280
-
281
- reduced_val = arr[0]
282
- arr[1..arr.size].each do |v|
283
- reduced_val = vector_op(reduced_val, v, ->(a, b) { a * b })
284
- end
285
- reduced_val
130
+ call_op(:sign, inputs[0], context, func)
131
+ end
132
+
133
+ register_op(:logical_and) do |context, tensor, inputs|
134
+ call_vector_op(:logical_and, inputs[0], inputs[1], context, ->(t, u) { t && u })
135
+ end
136
+
137
+ register_op(:equal) do |context, tensor, inputs|
138
+ call_vector_op(:equal, inputs[0], inputs[1], context, ->(t, u) { t == u })
139
+ end
140
+
141
+ register_op(:not_equal) do |context, tensor, inputs|
142
+ call_vector_op(:not_equal, inputs[0], inputs[1], context, ->(t, u) { t != u })
143
+ end
144
+
145
+ register_op :index, no_eval: true do |context, tensor, inputs|
146
+ f = inputs[0]
147
+ index = inputs[1]
148
+ f[index]
149
+ end
150
+
151
+ register_op :slice do |context, tensor, inputs|
152
+ input = inputs[0]
153
+ start = inputs[1]
154
+ size = complete_eval(tensor.options[:size], context)
155
+ raise "start index and size not of the same shape #{start.size} != #{size.size}" if start.size != size.size
156
+ slice_tensor(input, start, size)
157
+ end
158
+
159
+ register_op :negate, no_eval: true do |context, _tensor, inputs|
160
+ call_vector_op(:negate, inputs[0], nil, context, ->(t, _u) { -t })
161
+ end
162
+
163
+ register_op :add, no_eval: true do |context, _tensor, inputs|
164
+ a, b = inputs
165
+ call_vector_op(:add, a, b, context, ->(t, u) { t + u })
166
+ end
167
+
168
+ register_op :sub, no_eval: true do |context, _tensor, inputs|
169
+ a, b = inputs
170
+ call_vector_op(:sub, a, b, context, ->(t, u) { t - u })
171
+ end
172
+
173
+ register_op :mul, no_eval: true do |context, _tensor, inputs|
174
+ a, b = inputs
175
+ call_vector_op(:mul, a, b, context, ->(t, u) { t * u })
176
+ end
177
+
178
+ register_op :pow, no_eval: true do |context, _tensor, inputs|
179
+ a, b = inputs
180
+ call_vector_op(:pow, a, b, context, ->(t, u) { t**u })
181
+ end
182
+
183
+ register_op :concat do |_context, tensor, inputs|
184
+ concat_array(inputs[0], tensor.options[:axis])
185
+ end
186
+
187
+ register_op :round, no_eval: true do |context, tensor, inputs|
188
+ call_op(:round, inputs[0], context, ->(t, _b) { t.round })
189
+ end
190
+
191
+ register_op :abs, no_eval: true do |context, tensor, inputs|
192
+ call_op(:abs, inputs[0], context, ->(t, _b) { t.abs })
193
+ end
194
+
195
+ register_op :tanh, no_eval: true do |context, tensor, inputs|
196
+ call_op(:tanh, inputs[0], context, ->(t, _b) { Math.tanh(t) })
197
+ end
198
+
199
+ register_op :tan, no_eval: true do |context, tensor, inputs|
200
+ call_op(:tan, inputs[0], context, ->(t, _b) { Math.tan(t) })
201
+ end
202
+
203
+ register_op :sec, no_eval: true do |context, tensor, inputs|
204
+ call_op(:sec, inputs[0], context, ->(t, _b) { Math.sec(t) })
205
+ end
206
+
207
+ register_op :sin, no_eval: true do |context, tensor, inputs|
208
+ call_op(:sin, inputs[0], context, ->(t, _b) { Math.sin(t) })
209
+ end
210
+
211
+ register_op :cos, no_eval: true do |context, tensor, inputs|
212
+ call_op(:cos, inputs[0], context, ->(t, _b) { Math.cos(t) })
213
+ end
214
+
215
+ register_op :log1p, no_eval: true do |context, tensor, inputs|
216
+ call_op(:log1p, inputs[0], context, ->(t, _b) { Math.log(1 + t) })
217
+ end
218
+
219
+ register_op :log, no_eval: true do |context, tensor, inputs|
220
+ call_op(:log, inputs[0], context, ->(t, _b) { t < 0 ? Float::NAN : Math.log(t) })
221
+ end
222
+
223
+ register_op :exp, no_eval: true do |context, tensor, inputs|
224
+ call_op(:exp, inputs[0], context, ->(t, _b) { Math.exp(t) })
225
+ end
226
+
227
+ register_op :sigmoid, no_eval: true do |context, tensor, inputs|
228
+ call_op(:sigmoid, inputs[0], context, ->(t, _b) { sigmoid(t) })
229
+ end
230
+
231
+ register_op :sqrt, no_eval: true do |context, tensor, inputs|
232
+ call_op(:sqrt, inputs[0], context, ->(t, _b) { Math.sqrt(t) })
233
+ end
234
+
235
+ register_op :square, no_eval: true do |context, tensor, inputs|
236
+ call_op(:square, inputs[0], context, ->(t, _b) { t * t })
237
+ end
238
+
239
+ register_op :reciprocal, no_eval: true do |context, tensor, inputs|
240
+ call_op(:reciprocal, inputs[0], context, ->(t, _b) { 1 / t })
241
+ end
242
+
243
+ register_op :stop_gradient, no_eval: true do |_context, _tensor, inputs|
244
+ inputs[0]
245
+ end
246
+
247
+ register_op :sigmoid_grad, no_eval: true do |context, _tensor, inputs|
248
+ a, b = inputs
249
+ call_vector_op(:sigmoid_grad, a, b, context, ->(t, u) { u * sigmoid(t) * (1 - sigmoid(t))} )
250
+ end
251
+
252
+ register_op :random_uniform, no_eval: true do |_context, tensor, _inputs|
253
+ maxval = tensor.options.fetch(:maxval, 1)
254
+ minval = tensor.options.fetch(:minval, 0)
255
+ seed = tensor.options[:seed]
256
+
257
+ random = _get_randomizer(tensor, seed)
258
+ generator = -> { random.rand * (maxval - minval) + minval }
259
+ shape = tensor.options[:shape] || tensor.shape.shape
260
+ generate_vector(shape, generator: generator)
261
+ end
262
+
263
+ register_op :random_normal, no_eval: true do |_context, tensor, _inputs|
264
+ seed = tensor.options[:seed]
265
+ random = _get_randomizer(tensor, seed)
266
+ r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
267
+ random = _get_randomizer(tensor, seed)
268
+ generator = -> { r.rand }
269
+ shape = tensor.options[:shape] || tensor.shape.shape
270
+ generate_vector(shape, generator: generator)
271
+ end
272
+
273
+ register_op :glorot_uniform, no_eval: true do |_context, tensor, _inputs|
274
+ seed = tensor.options[:seed]
275
+ random = _get_randomizer(tensor, seed)
276
+
277
+ shape = tensor.options[:shape] || tensor.shape.shape
278
+ fan_in, fan_out = if shape.size.zero?
279
+ [1, 1]
280
+ elsif shape.size == 1
281
+ [1, shape[0]]
282
+ else
283
+ [shape[0], shape.last]
284
+ end
285
+
286
+ limit = Math.sqrt(6.0 / (fan_in + fan_out))
287
+
288
+ minval = -limit
289
+ maxval = limit
290
+
291
+ generator = -> { random.rand * (maxval - minval) + minval }
292
+ generate_vector(shape, generator: generator)
293
+ end
294
+
295
+ register_op :assign, noop: true do |context, tensor, inputs|
296
+ assign = tensor.inputs[0] || tensor
297
+ assign.value = complete_eval(tensor.inputs[1], context)
298
+ assign.value
299
+ end
300
+
301
+ register_op :assign_add, noop: true do |context, tensor, inputs|
302
+ tensor.inputs[0].value = process_vector_math_op(tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t + u })
303
+ tensor.inputs[0].value
304
+ end
305
+
306
+ register_op :assign_sub, noop: true do |context, tensor, inputs|
307
+ tensor.inputs[0].value = process_vector_math_op(tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t - u })
308
+ tensor.inputs[0].value
309
+ end
310
+
311
+ register_op :mean, noop: true do |context, tensor, _inputs|
312
+ c = fp_type?(tensor.data_type) ? 0.0 : 0
313
+ func = lambda do |arr|
314
+ return c if arr.nil?
315
+
316
+ reduced_val = arr[0]
317
+ arr[1..arr.size].each do |v|
318
+ reduced_val = vector_op(reduced_val, v, ->(a, b) { a + b })
286
319
  end
287
320
 
288
- reduction(child_context, tensor, func)
289
- when :transpose
290
- matrix_a = complete_eval(a, child_context)
291
- matrix_a.transpose
292
- when :eye
293
- rows = complete_eval(a, child_context)
294
- columns = complete_eval(b, child_context)
295
-
296
- Array.new(rows) do |i|
297
- Array.new(columns) do |col|
298
- if fp_type?(tensor.data_type)
299
- i == col ? 1.0 : 0.0
300
- else
301
- i == col ? 1 : 0
302
- end
303
- end
321
+ vector_op(reduced_val, nil, ->(a, _b) { a / arr.size })
322
+ end
323
+
324
+ reduction(context, tensor, func)
325
+ end
326
+
327
+ register_op :sum, noop: true do |context, tensor, _inputs|
328
+ c = fp_type?(tensor.data_type) ? 0.0 : 0
329
+ func = lambda do |arr|
330
+ reduced_val = arr[0]
331
+ arr[1..arr.size].each do |v|
332
+ reduced_val = vector_op(reduced_val, v, ->(t, u) { t + u })
304
333
  end
305
- when :cond
306
- pred = complete_eval(tensor.options[:pred], child_context)
334
+ reduced_val
335
+ end
307
336
 
308
- if all_true?(pred)
309
- complete_eval(a, child_context)
310
- else
311
- complete_eval(b, child_context)
337
+ reduction(context, tensor, func)
338
+ end
339
+
340
+ register_op :prod, noop: true do |context, tensor, _inputs|
341
+ c = fp_type?(tensor.data_type) ? 1.0 : 1
342
+ func = lambda do |arr|
343
+ return c if arr.nil?
344
+
345
+ reduced_val = arr[0]
346
+ arr[1..arr.size].each do |v|
347
+ reduced_val = vector_op(reduced_val, v, ->(a, b) { a * b })
312
348
  end
313
- when :where
314
- pred = complete_eval(tensor.options[:pred], child_context)
315
- a = complete_eval(a, child_context)
316
- b = complete_eval(b, child_context)
317
-
318
- call_3way_vector_op(pred, a, b, child_context, ->(t, u, v) { t ? u : v })
319
- when :less
320
- a = complete_eval(a, child_context)
321
- b = complete_eval(b, child_context)
322
-
323
- call_vector_op(:greater, a, b, child_context, ->(t, u) { t < u })
324
- when :greater
325
- a = complete_eval(a, child_context)
326
- b = complete_eval(b, child_context)
327
-
328
- call_vector_op(:greater, a, b, child_context, ->(t, u) { t > u })
329
- when :greater_equal
330
- a = complete_eval(a, child_context)
331
- b = complete_eval(b, child_context)
332
-
333
- call_vector_op(:greater_equal, a, b, child_context, ->(t, u) { t >= u })
334
- when :less_equal
335
- a = complete_eval(a, child_context)
336
- b = complete_eval(b, child_context)
337
-
338
- call_vector_op(:less_equal, a, b, child_context, ->(t, u) { t <= u })
339
- when :zeros, :ones, :zeros_like, :ones_like
340
-
341
- shape = if %i[zeros_like ones_like].include?(tensor.operation)
342
- a = complete_eval(a, child_context)
343
- shape_eval(a)
344
- else
345
- complete_eval(a, child_context) || tensor.shape.shape
346
- end
347
-
348
- func = if %i[zeros zeros_like].include?(tensor.operation)
349
- -> { tensor.data_type == :int32 ? 0 : 0.0 }
350
- else
351
- -> { tensor.data_type == :int32 ? 1 : 1.0 }
352
- end
353
-
354
- if shape.is_a?(Array) && shape.size.zero?
355
- func.call
356
- else
357
- shape = [shape.to_i] unless shape.is_a?(Array)
349
+ reduced_val
350
+ end
351
+
352
+ reduction(context, tensor, func)
353
+ end
354
+
355
+ register_op :tanh_grad, no_eval: true do |context, _tensor, inputs|
356
+ call_op(:tanh_grad, inputs[0], context, ->(t, _b) { 1 - Math.tanh(t) * Math.tanh(t) })
357
+ end
358
358
 
359
- cache_key = "#{tensor.operation}_#{shape.to_s}"
360
- if @context[:_cache].key?(cache_key)
361
- return @context[:_cache][cache_key]
359
+ register_op :transpose do |_context, _tensor, inputs|
360
+ inputs[0].transpose
361
+ end
362
+
363
+ register_op :eye do |_context, tensor, inputs|
364
+ rows, columns = inputs
365
+
366
+ Array.new(rows) do |i|
367
+ Array.new(columns) do |col|
368
+ if fp_type?(tensor.data_type)
369
+ i == col ? 1.0 : 0.0
362
370
  else
363
- generate_vector(shape, generator: func).tap do |v|
364
- @context[:_cache][cache_key] = v
365
- end
371
+ i == col ? 1 : 0
366
372
  end
367
373
  end
368
- when :shape
369
- input = complete_eval(a, child_context)
370
- shape_eval(input, tensor.options[:out_type])
371
- when :matmul
372
- matrix_a = complete_eval(a, child_context)
373
- matrix_b = complete_eval(b, child_context)
374
-
375
- rank_a = get_rank(matrix_a)
376
- rank_b = get_rank(matrix_b)
377
-
378
- raise "#{tensor.items[0].name} rank must be greater than 1" if rank_a < 2
379
- raise "#{tensor.items[1].name} rank must be greater than 1" if rank_b < 2
380
-
381
- matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
382
- matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
383
-
384
- # handle matrix multiplication with constants like 1 or 0
385
- matrix_a = matmul_const_transform(matrix_a, matrix_b, tensor)
386
- matrix_b = matmul_const_transform(matrix_b, matrix_a, tensor)
387
-
388
- # check matrix dimensions
389
- raise "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size
390
-
391
- (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
392
- when :gradients
393
- raise 'not implemented in evaluator' # see TensorStream.gradients instead.
394
- when :broadcast_transform
395
- a = complete_eval(a, child_context)
396
- b = complete_eval(b, child_context)
397
- broadcast(a, b)
398
- when :truncate
399
- a = complete_eval(a, child_context)
400
- b = complete_eval(b, child_context)
401
- truncate(a, b)
402
- when :identity
403
- complete_eval(a, child_context)
404
- when :print
405
- a = complete_eval(a, child_context)
406
- b = complete_eval(b, child_context)
407
- puts "#{tensor.options.fetch(:message, '')} #{b}"
408
- a
409
- when :rank
410
- a = complete_eval(a, child_context)
411
- get_rank(a)
412
- when :div
413
- process_vector_math_op(a, b, child_context, ->(t, u) { t / u })
414
- when :reshape
415
- arr = complete_eval(a, child_context)
416
- new_shape = complete_eval(b, child_context)
417
-
418
- arr = [arr] unless arr.is_a?(Array)
419
-
420
- flat_arr = arr.flatten
421
- return flat_arr[0] if new_shape.size.zero? && flat_arr.size == 1
374
+ end
375
+ end
422
376
 
423
- new_shape = TensorShape.fix_inferred_elements(new_shape, flat_arr.size)
377
+ register_op :cond do |context, tensor, inputs|
378
+ pred = complete_eval(tensor.options[:pred], context)
379
+ if all_true?(pred)
380
+ inputs[0]
381
+ else
382
+ inputs[1]
383
+ end
384
+ end
424
385
 
425
- TensorShape.reshape(flat_arr, new_shape)
426
- when :pad
427
- a = complete_eval(a, child_context)
428
- p = complete_eval(tensor.options[:paddings], child_context)
429
-
430
- arr_pad(a, p, tensor.data_type)
431
- when :max
432
- a = complete_eval(a, child_context)
433
- b = complete_eval(b, child_context)
434
-
435
- call_vector_op(:max, a, b, child_context, ->(t, u) { [t, u].max })
436
- when :broadcast_gradient_args
437
- a = complete_eval(a, child_context)
438
- b = complete_eval(b, child_context)
439
-
440
- get_broadcast_gradient_args(a, b)
441
- when :reduced_shape
442
- input_shape = complete_eval(a, child_context)
443
- axes = complete_eval(b, child_context)
444
-
445
- return [] if axes.nil? # reduce to scalar
446
- axes = [ axes ] unless axes.is_a?(Array)
447
- return input_shape if axes.empty?
448
-
449
- axes.each do |dimen|
450
- input_shape[dimen] = 1
451
- end
452
- input_shape
453
- when :tile
454
- input = complete_eval(a, child_context)
455
- multiples = complete_eval(b, child_context)
456
-
457
- rank = get_rank(input)
458
- raise '1D or higher tensor required' if rank.zero?
459
- raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
460
-
461
- tile = tile_arr(input, 0, multiples)
462
- tile.nil? ? [] : tile
463
- when :softmax
464
- input = complete_eval(a, child_context)
465
- softmax(input)
466
- when :softmax_grad
467
- input = complete_eval(a, child_context)
468
- grad = complete_eval(b, child_context)
469
- softmax_input = softmax(input)
470
- f_grad = softmax_grad(softmax_input)
471
- f_grad.transpose.each_with_index.collect do |row, index|
472
- sum = 0.0
473
- row.each_with_index do |r, g_index|
474
- sum += r * grad[g_index]
386
+ register_op :where do |context, tensor, inputs|
387
+ pred = complete_eval(tensor.options[:pred], context)
388
+ call_3way_vector_op(pred, inputs[0], inputs[1], context, ->(t, u, v) { t ? u : v })
389
+ end
390
+
391
+ register_op :less do |context, _tensor, inputs|
392
+ a, b = inputs
393
+ call_vector_op(:less, a, b, context, ->(t, u) { t < u })
394
+ end
395
+
396
+ register_op :greater do |context, _tensor, inputs|
397
+ a, b = inputs
398
+ call_vector_op(:greater, a, b, context, ->(t, u) { t > u })
399
+ end
400
+
401
+ register_op :greater_equal do |context, _tensor, inputs|
402
+ a, b = inputs
403
+ call_vector_op(:greater_equal, a, b, context, ->(t, u) { t >= u })
404
+ end
405
+
406
+ register_op :less_equal do |context, _tensor, inputs|
407
+ a, b = inputs
408
+ call_vector_op(:greater_equal, a, b, context, ->(t, u) { t <= u })
409
+ end
410
+
411
+ register_op %i[zeros ones zeros_like ones_like] do |_context, tensor, inputs|
412
+ shape = if %i[zeros_like ones_like].include?(tensor.operation)
413
+ shape_eval(inputs[0])
414
+ else
415
+ inputs[0] || tensor.shape.shape
416
+ end
417
+
418
+ func = if %i[zeros zeros_like].include?(tensor.operation)
419
+ -> { tensor.data_type == :int32 ? 0 : 0.0 }
420
+ else
421
+ -> { tensor.data_type == :int32 ? 1 : 1.0 }
422
+ end
423
+
424
+ if shape.is_a?(Array) && shape.size.zero?
425
+ func.call
426
+ else
427
+ shape = [shape.to_i] unless shape.is_a?(Array)
428
+
429
+ cache_key = "#{tensor.operation}_#{shape.to_s}"
430
+ if @context[:_cache].key?(cache_key)
431
+ @context[:_cache][cache_key]
432
+ else
433
+ generate_vector(shape, generator: func).tap do |v|
434
+ @context[:_cache][cache_key] = v
475
435
  end
476
- sum
477
436
  end
478
- when :check_numerics
479
- a = complete_eval(a, child_context)
480
- message = tensor.options[:message]
481
- f = ->(t, _b) { raise "#{message} Invalid argument" if t.nan? || t.infinite?; t }
482
- call_op(:check_numerics, a, child_context, f)
437
+ end
438
+ end
439
+
440
+ register_op :shape do |_context, tensor, inputs|
441
+ shape_eval(inputs[0], tensor.options[:out_type])
442
+ end
443
+
444
+ register_op :matmul do |_context, tensor, inputs|
445
+ matrix_a, matrix_b = inputs
446
+ rank_a = get_rank(matrix_a)
447
+ rank_b = get_rank(matrix_b)
448
+
449
+ raise "#{tensor.inputs[0].name} rank must be greater than 1" if rank_a < 2
450
+ raise "#{tensor.inputs[1].name} rank must be greater than 1" if rank_b < 2
451
+
452
+ matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
453
+ matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
454
+
455
+ # handle matrix multiplication with constants like 1 or 0
456
+ matrix_a = matmul_const_transform(matrix_a, matrix_b, tensor)
457
+ matrix_b = matmul_const_transform(matrix_b, matrix_a, tensor)
458
+
459
+ # check matrix dimensions
460
+ raise "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size
461
+
462
+ (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
463
+ end
464
+
465
+ register_op :broadcast_transform do |_context, _tensor, inputs|
466
+ broadcast(inputs[0], inputs[1])
467
+ end
468
+
469
+ register_op :truncate do |_context, _tensor, inputs|
470
+ truncate(inputs[0], inputs[1])
471
+ end
472
+
473
+ register_op :identity do |_context, _tensor, inputs|
474
+ inputs[0]
475
+ end
476
+
477
+ register_op :print do |_context, tensor, inputs|
478
+ puts "#{tensor.options.fetch(:message, '')} #{inputs[1]}"
479
+ inputs[0]
480
+ end
481
+
482
+ register_op :rank do |_context, _tensor, inputs|
483
+ get_rank(inputs[0])
484
+ end
485
+
486
+ register_op :div, noop: true do |context, _tensor, inputs|
487
+ process_vector_math_op(inputs[0], inputs[1], context, ->(t, u) { t / u })
488
+ end
489
+
490
+ register_op :reshape do |_context, _tensor, inputs|
491
+ arr, new_shape = inputs
492
+
493
+ arr = [arr] unless arr.is_a?(Array)
494
+
495
+ flat_arr = arr.flatten
496
+ if new_shape.size.zero? && flat_arr.size == 1
497
+ flat_arr[0]
483
498
  else
484
- raise "unknown op #{tensor.operation}"
485
- end.tap do |result|
499
+ new_shape = TensorShape.fix_inferred_elements(new_shape, flat_arr.size)
500
+ TensorShape.reshape(flat_arr, new_shape)
501
+ end
502
+ end
503
+
504
+ register_op :pad do |context, tensor, inputs|
505
+ p = complete_eval(tensor.options[:paddings], context)
506
+
507
+ arr_pad(inputs[0], p, tensor.data_type)
508
+ end
509
+
510
+ register_op :max, noop: true do |context, _tensor, inputs|
511
+ call_vector_op(:max, inputs[0], inputs[1], context, ->(t, u) { [t, u].max })
512
+ end
513
+
514
+ register_op :broadcast_gradient_args do |_context, _tensor, inputs|
515
+ get_broadcast_gradient_args(inputs[0], inputs[1])
516
+ end
517
+
518
+ register_op :reduced_shape do |context, _tensor, inputs|
519
+ input_shape, axes = inputs
520
+
521
+ return [] if axes.nil? # reduce to scalar
522
+ axes = [ axes ] unless axes.is_a?(Array)
523
+ return input_shape if axes.empty?
524
+
525
+ axes.each do |dimen|
526
+ input_shape[dimen] = 1
527
+ end
528
+ input_shape
529
+ end
530
+
531
+ register_op :tile do |context, _tensor, inputs|
532
+ input, multiples = inputs
533
+ rank = get_rank(input)
534
+ raise '1D or higher tensor required' if rank.zero?
535
+ raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
536
+
537
+ tile = tile_arr(input, 0, multiples)
538
+ tile.nil? ? [] : tile
539
+ end
540
+
541
+ register_op :flow_group, noop: true do |context, _tensor, inputs|
542
+ inputs.collect { |input| run(input, context) }
543
+ end
544
+
545
+ register_op :softmax do |context, _tensor, inputs|
546
+ softmax(inputs[0])
547
+ end
548
+
549
+ register_op :softmax_grad do |_context, _tensor, inputs|
550
+ input, grad = inputs
551
+
552
+ softmax_input = softmax(input)
553
+ f_grad = softmax_grad(softmax_input)
554
+ f_grad.transpose.each_with_index.collect do |row, index|
555
+ sum = 0.0
556
+ row.each_with_index do |r, g_index|
557
+ sum += r * grad[g_index]
558
+ end
559
+ sum
560
+ end
561
+ end
562
+
563
+ register_op :check_numerics do |context, tensor, inputs|
564
+ message = tensor.options[:message]
565
+ f = ->(t, _b) { raise "#{message} Invalid argument" if t.nan? || t.infinite?; t }
566
+ call_op(:check_numerics, inputs[0], context, f)
567
+ end
568
+
569
+ def eval_operation(tensor, child_context)
570
+ return @context[tensor.name] if @context.key?(tensor.name)
571
+
572
+ # puts tensor.name
573
+ invoke(tensor, child_context).tap do |result|
486
574
  if tensor.breakpoint
575
+ a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
576
+ b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
487
577
  a = complete_eval(a, child_context)
488
578
  b = complete_eval(b, child_context)
489
-
490
579
  tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
491
580
  end
492
581
  if @log_intermediates
@@ -504,16 +593,16 @@ module TensorStream
504
593
  rescue EvaluatorExcecutionException => e
505
594
  raise e
506
595
  rescue StandardError => e
596
+ a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
597
+ b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
507
598
  puts e.message
508
599
  puts e.backtrace.join("\n")
509
-
510
600
  # shape_a = a.shape.shape if a
511
601
  # shape_b = b.shape.shape if b
512
602
  # dtype_a = a.data_type if a
513
603
  # dtype_b = b.data_type if b
514
604
  a = complete_eval(a, child_context)
515
605
  b = complete_eval(b, child_context)
516
-
517
606
  # puts "name: #{tensor.given_name}"
518
607
  # # puts "op: #{tensor.to_math(true, 1)}"
519
608
  # puts "A #{shape_a} #{dtype_a}: #{a}" if a
@@ -532,8 +621,8 @@ module TensorStream
532
621
  return @context[:_cache][cache_key] if @context[:_cache] && @context[:_cache].key?(tensor.name)
533
622
 
534
623
  if tensor.value.is_a?(Array)
535
- tensor.value.collect do |item|
536
- item.is_a?(Tensor) ? run(item, child_context) : item
624
+ tensor.value.collect do |input|
625
+ input.is_a?(Tensor) ? run(input, child_context) : input
537
626
  end
538
627
  else
539
628
  tensor.value.is_a?(Tensor) ? run(tensor.value, child_context) : tensor.value
@@ -543,6 +632,10 @@ module TensorStream
543
632
  end
544
633
  end
545
634
 
635
+ def convert_from_buffer(tensor, result) # unwraps result via #buffer; tensor is not used here
636
+ result.buffer
637
+ end
638
+
546
639
  private
547
640
 
548
641
  def get_op_with_axis(a, target_axis, current_axis, output_type, op = ->(t, u) { t > u })
@@ -579,8 +672,8 @@ module TensorStream
579
672
  end
580
673
 
581
674
  def reduction(child_context, tensor, func)
582
- val = complete_eval(tensor.items[0], child_context)
583
- axis = complete_eval(tensor.items[1], child_context)
675
+ val = complete_eval(tensor.inputs[0], child_context)
676
+ axis = complete_eval(tensor.inputs[1], child_context)
584
677
  keep_dims = complete_eval(tensor.options[:keepdims], child_context)
585
678
  rank = get_rank(val)
586
679
  return val if axis && axis.is_a?(Array) && axis.empty?
@@ -720,7 +813,6 @@ module TensorStream
720
813
 
721
814
  def resolve_placeholder(placeholder, _execution_context = {})
722
815
  return nil if placeholder.nil?
723
- return placeholder if retain.include?(placeholder)
724
816
 
725
817
  var = if placeholder.is_a?(Placeholder)
726
818
  @context[placeholder.name.to_sym].tap do |c|
@@ -834,3 +926,5 @@ module TensorStream
834
926
  end
835
927
  end
836
928
  end
929
+
930
+ TensorStream::Evaluator.register_evaluator(TensorStream::Evaluator::RubyEvaluator, 'ruby') # registers RubyEvaluator under the name 'ruby'