tensor_stream 0.9.2 → 0.9.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (39) hide show
  1. checksums.yaml +4 -4
  2. data/lib/tensor_stream/evaluator/base_evaluator.rb +3 -0
  3. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +25 -0
  4. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +24 -24
  5. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +8 -0
  6. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +16 -18
  7. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +20 -4
  8. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +9 -5
  9. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +4 -4
  10. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +16 -61
  11. data/lib/tensor_stream/graph_builder.rb +1 -0
  12. data/lib/tensor_stream/graph_serializers/graphml.rb +1 -1
  13. data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -0
  14. data/lib/tensor_stream/helpers/infer_shape.rb +182 -0
  15. data/lib/tensor_stream/helpers/op_helper.rb +2 -2
  16. data/lib/tensor_stream/images.rb +1 -1
  17. data/lib/tensor_stream/math_gradients.rb +1 -1
  18. data/lib/tensor_stream/monkey_patches/array.rb +15 -0
  19. data/lib/tensor_stream/monkey_patches/float.rb +3 -0
  20. data/lib/tensor_stream/monkey_patches/integer.rb +3 -0
  21. data/lib/tensor_stream/monkey_patches/patch.rb +70 -0
  22. data/lib/tensor_stream/nn/nn_ops.rb +43 -9
  23. data/lib/tensor_stream/operation.rb +2 -153
  24. data/lib/tensor_stream/ops.rb +71 -56
  25. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  26. data/lib/tensor_stream/tensor_shape.rb +9 -6
  27. data/lib/tensor_stream/train/adadelta_optimizer.rb +1 -1
  28. data/lib/tensor_stream/train/adagrad_optimizer.rb +1 -1
  29. data/lib/tensor_stream/train/adam_optimizer.rb +2 -2
  30. data/lib/tensor_stream/train/learning_rate_decay.rb +29 -0
  31. data/lib/tensor_stream/train/optimizer.rb +7 -6
  32. data/lib/tensor_stream/train/saver.rb +1 -0
  33. data/lib/tensor_stream/train/slot_creator.rb +2 -2
  34. data/lib/tensor_stream/trainer.rb +3 -0
  35. data/lib/tensor_stream/utils.rb +2 -2
  36. data/lib/tensor_stream/version.rb +1 -1
  37. data/lib/tensor_stream.rb +5 -1
  38. data/samples/rnn.rb +108 -0
  39. metadata +8 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9601653d86556739c89b591768e9d54d13d6335dd2f953fff7d91f22636e8c7b
4
- data.tar.gz: 3db4119c9e752df77cbbf7f8f753050a012b03fd2bde04dd6bbfd01f8903af62
3
+ metadata.gz: 51cb6686663dece94714073ff12b13ad7e57de1aadbee44506897c69a2d5fd67
4
+ data.tar.gz: 47152a908b2cc966cba6721da8f0c3170e3198a70f416fd224caee7d35a1fca2
5
5
  SHA512:
6
- metadata.gz: 3ca3eec8ce6cc7e73a1d5e38ccaaf2f65f7a37bb84e0118766de05ec374557847a88f2f14b1e35b896010d7e68a0234589fcf3de95a9417840e2e77ab4a4ea58
7
- data.tar.gz: 6b109a2ee9b59e286a2d06d1c3166fc5bbda9a888f0c26aa07bc4f994753ce5dbd7cb0aebb433a7446ddef653f3096ae1b60dbc4610e0dc18c9aaca8d3175f2f
6
+ metadata.gz: 7fff67042fd35c651409e04e890feafab25e618d88d4c3f69c97b39e839a943c64a4e05be388a4464e32328a777f6ee56dc5ede0a6c69c977cc94bcb7cd38bad
7
+ data.tar.gz: a30230994e81062626eaae5ca7a13ece36d2a5ea6210cfe1e3e465b5e5e24adad525547223f27e15991f74728d53d8b02e7cac0710fe16bf3c602402f5b65b5c
@@ -140,6 +140,8 @@ module TensorStream
140
140
  @context[:profile][:operations][tensor.name] = { op: tensor.operation,
141
141
  step: @context[:profile][:step],
142
142
  eval_time: end_time - start_time,
143
+ shape: tensor.shape ? tensor.shape.shape : nil,
144
+ data_type: tensor.data_type,
143
145
  tensor: tensor }
144
146
  end
145
147
  end
@@ -166,6 +168,7 @@ module TensorStream
166
168
  def global_eval(tensor, input, execution_context, op_options = {})
167
169
  return nil unless input
168
170
  return input unless input.is_a?(Tensor)
171
+
169
172
  # puts "global eval #{tensor.name}"
170
173
  @context[:_cache][:placement][input.name] = @session.assign_evaluator(input) if @context[:_cache][:placement][input.name].nil?
171
174
  if !on_same_device?(input) # tensor is on another device or evaluator
@@ -310,5 +310,30 @@ module TensorStream
310
310
 
311
311
  reduce_axis(0, axis, val, keep_dims, func)
312
312
  end
313
+
314
+ def arr_pad(arr, paddings, data_type = :float32, rank = 0)
315
+ raise "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2
316
+
317
+ before = paddings[rank][0]
318
+ after = paddings[rank][1]
319
+ pad_value = fp_type?(data_type) ? 0.0 : 0
320
+ if arr[0].is_a?(Array)
321
+ next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
322
+ padding = deep_dup_array(next_dim_elem[0], pad_value)
323
+ Array.new(before) { padding } + next_dim_elem + Array.new(after) { padding }
324
+ else
325
+ Array.new(before) { pad_value } + arr + Array.new(after) { pad_value }
326
+ end
327
+ end
328
+
329
+ def deep_dup_array(arr, value = nil)
330
+ if arr.is_a?(Array)
331
+ arr.dup.collect do |a|
332
+ deep_dup_array(a, value)
333
+ end
334
+ else
335
+ value.nil? ? arr : value
336
+ end
337
+ end
313
338
  end
314
339
  end
@@ -132,9 +132,9 @@ module TensorStream
132
132
  axis = !tensor.options[:axis].is_a?(Array) ? [tensor.options[:axis]] : tensor.options[:axis]
133
133
 
134
134
  if !axis.empty?
135
- axis.each do |axis|
136
- if shape[axis] == 1
137
- shape[axis] = nil
135
+ axis.each do |x|
136
+ if shape[x] == 1
137
+ shape[x] = nil
138
138
  else
139
139
  raise TensorStream::ValueError, "unable to squeeze dimension that does not have a size of 1"
140
140
  end
@@ -201,6 +201,7 @@ module TensorStream
201
201
  out = []
202
202
  input.each_with_index do |x, index|
203
203
  next if remove.include?(x)
204
+
204
205
  out << x
205
206
  idx << index
206
207
  end
@@ -225,6 +226,7 @@ module TensorStream
225
226
  break if start == limit
226
227
  break if (start < limit) && (cur_step >= limit)
227
228
  break if (start > limit) && (cur_step <= limit)
229
+
228
230
  r << cur_step
229
231
  cur_step += delta
230
232
  end
@@ -281,34 +283,34 @@ module TensorStream
281
283
  get_rank(inputs[0])
282
284
  end
283
285
 
284
- register_op :split do |context, tensor, inputs|
286
+ register_op :split do |_context, tensor, inputs|
285
287
  value, num_split, axis = inputs
286
288
 
287
289
  value_shape = shape_eval(value)
288
290
  res = if num_split.is_a?(Array)
289
- begin_index = 0
290
- num_split.collect do |num|
291
- end_index = begin_index + num
292
- arr = split_tensor(value, begin_index, end_index, axis)
293
- begin_index = end_index
294
- arr
295
- end
296
- else
297
- raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
298
- piece_sizes = value_shape[axis] / num_split
299
- Array.new(num_split) do |num|
300
- begin_index = num * piece_sizes
301
- end_index = begin_index + piece_sizes
302
- split_tensor(value, begin_index, end_index, axis)
303
- end
304
- end
291
+ begin_index = 0
292
+ num_split.collect do |num|
293
+ end_index = begin_index + num
294
+ arr = split_tensor(value, begin_index, end_index, axis)
295
+ begin_index = end_index
296
+ arr
297
+ end
298
+ else
299
+ raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
300
+
301
+ piece_sizes = value_shape[axis] / num_split
302
+ Array.new(num_split) do |num|
303
+ begin_index = num * piece_sizes
304
+ end_index = begin_index + piece_sizes
305
+ split_tensor(value, begin_index, end_index, axis)
306
+ end
307
+ end
305
308
  TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
306
309
  end
307
310
 
308
311
  register_op :reshape do |_context, _tensor, inputs|
309
312
  arr, new_shape = inputs
310
313
  arr = [arr] unless arr.is_a?(Array)
311
-
312
314
  flat_arr = arr.flatten
313
315
  if new_shape.size.zero? && flat_arr.size == 1
314
316
  flat_arr[0]
@@ -318,9 +320,7 @@ module TensorStream
318
320
  end
319
321
 
320
322
  register_op :pad do |context, tensor, inputs|
321
- p = complete_eval(tensor.options[:paddings], context)
322
-
323
- arr_pad(inputs[0], p, tensor.data_type)
323
+ arr_pad(inputs[0], inputs[1], tensor.data_type)
324
324
  end
325
325
 
326
326
  register_op :tile do |_context, _tensor, inputs|
@@ -2,7 +2,15 @@ module TensorStream
2
2
  module CheckOps
3
3
  def CheckOps.included(klass)
4
4
  klass.class_eval do
5
+ register_op :assert_equal do |context, tensor, inputs|
6
+ result = call_vector_op(tensor, :equal, inputs[0], inputs[1], context, ->(t, u) { t == u })
5
7
 
8
+ result = result.is_a?(Array) ? result.flatten.uniq : [result]
9
+ prefix = tensor.options[:message] || ""
10
+ raise TensorStream::InvalidArgumentError, "#{prefix} #{tensor.inputs[0].name} != #{tensor.inputs[1].name}" if result != [true]
11
+
12
+ nil
13
+ end
6
14
  end
7
15
  end
8
16
  end
@@ -14,30 +14,28 @@ module TensorStream
14
14
  image.grayscale! if channels == 1
15
15
  image_data = image.pixels.collect do |pixel|
16
16
  color_values = if channels == 4
17
- [ ChunkyPNG::Color.r(pixel),
18
- ChunkyPNG::Color.g(pixel),
19
- ChunkyPNG::Color.b(pixel),
20
- ChunkyPNG::Color.a(pixel) ]
21
- elsif channels == 3
22
- [ ChunkyPNG::Color.r(pixel),
23
- ChunkyPNG::Color.g(pixel),
24
- ChunkyPNG::Color.b(pixel)]
25
- elsif channels == 1
26
- [ ChunkyPNG::Color.r(pixel) ]
27
- else
28
- raise "Invalid channel value #{channels}"
29
- end
30
-
31
- if fp_type?(tensor.data_type)
32
- color_values.map! { |v| v.to_f }
33
- end
17
+ [ChunkyPNG::Color.r(pixel),
18
+ ChunkyPNG::Color.g(pixel),
19
+ ChunkyPNG::Color.b(pixel),
20
+ ChunkyPNG::Color.a(pixel)]
21
+ elsif channels == 3
22
+ [ChunkyPNG::Color.r(pixel),
23
+ ChunkyPNG::Color.g(pixel),
24
+ ChunkyPNG::Color.b(pixel)]
25
+ elsif channels == 1
26
+ [ChunkyPNG::Color.r(pixel)]
27
+ else
28
+ raise "Invalid channel value #{channels}"
29
+ end
30
+
31
+ color_values.map!(&:to_f) if fp_type?(tensor.data_type)
34
32
 
35
33
  color_values
36
34
  end
37
35
  TensorShape.reshape(image_data, [image.height, image.width, channels])
38
36
  end
39
37
 
40
- register_op :encode_png do |_context, tensor, inputs|
38
+ register_op :encode_png do |_context, _tensor, inputs|
41
39
  image_data = inputs[0]
42
40
  height, width, channels = shape_eval(image_data)
43
41
 
@@ -140,17 +140,33 @@ module TensorStream
140
140
  end
141
141
 
142
142
  register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
143
- axis = tensor.options[:axis] || 0
143
+ axis = inputs[1] || 0
144
144
  rank = get_rank(inputs[0])
145
145
  raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
146
- get_op_with_axis(inputs[0], axis, 0, tensor.data_type)
146
+
147
+ new_shape = shape_eval(inputs[0])
148
+ ns = new_shape.each_with_index.collect do |shape, index|
149
+ next nil if index == axis
150
+
151
+ shape
152
+ end.compact
153
+
154
+ Tensor.cast_dtype(TensorShape.reshape(get_op_with_axis(inputs[0], axis, 0, :max), ns), tensor.data_type)
147
155
  end
148
156
 
149
157
  register_op(%i[argmin arg_min]) do |_context, tensor, inputs|
150
- axis = tensor.options[:axis] || 0
158
+ axis = inputs[1] || 0
151
159
  rank = get_rank(inputs[0])
152
160
  raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
153
- get_op_with_axis(inputs[0], axis, 0, tensor.data_type, ->(a, b) { a < b })
161
+
162
+ new_shape = shape_eval(inputs[0])
163
+ ns = new_shape.each_with_index.collect do |shape, index|
164
+ next nil if index == axis
165
+
166
+ shape
167
+ end.compact
168
+
169
+ Tensor.cast_dtype(TensorShape.reshape(get_op_with_axis(inputs[0], axis, 0, :min), ns), tensor.data_type)
154
170
  end
155
171
 
156
172
  register_op :cumprod do |context, tensor, inputs|
@@ -16,11 +16,11 @@ module TensorStream
16
16
  assign = tensor.inputs[0] || tensor
17
17
  assign_acc = tensor.inputs[1]
18
18
  assign_acc.value = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
19
- if tensor.options[:use_nesterov]
20
- assign.value = multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
- else
22
- assign.value = multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
- end
19
+ assign.value = if tensor.options[:use_nesterov]
20
+ multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
+ else
22
+ multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
+ end
24
24
  assign.value
25
25
  end
26
26
 
@@ -207,6 +207,10 @@ module TensorStream
207
207
  TensorShape.reshape(arr, input_shape)
208
208
  end
209
209
  end
210
+
211
+ register_op :relu6 do |context, tensor, inputs|
212
+ call_vector_op(tensor, :relu6, inputs[0], inputs[1], context, ->(t, u) { [[t, 0].max, 6].min })
213
+ end
210
214
  end
211
215
  end
212
216
  end
@@ -25,24 +25,24 @@ module TensorStream
25
25
  generate_vector(shape, generator: generator)
26
26
  end
27
27
 
28
- register_op :random_uniform, no_eval: true do |_context, tensor, _inputs|
28
+ register_op :random_uniform, no_eval: true do |_context, tensor, inputs|
29
29
  maxval = tensor.options.fetch(:maxval, 1)
30
30
  minval = tensor.options.fetch(:minval, 0)
31
31
  seed = tensor.options[:seed]
32
32
 
33
33
  random = _get_randomizer(tensor, seed)
34
34
  generator = -> { random.rand * (maxval - minval) + minval }
35
- shape = tensor.options[:shape] || tensor.shape.shape
35
+ shape = inputs[0] || tensor.shape.shape
36
36
  generate_vector(shape, generator: generator)
37
37
  end
38
38
 
39
- register_op :random_standard_normal, no_eval: true do |_context, tensor, _inputs|
39
+ register_op :random_standard_normal, no_eval: true do |_context, tensor, inputs|
40
40
  seed = tensor.options[:seed]
41
41
  random = _get_randomizer(tensor, seed)
42
42
  r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
43
43
  random = _get_randomizer(tensor, seed)
44
44
  generator = -> { r.rand }
45
- shape = tensor.options[:shape] || tensor.shape.shape
45
+ shape = inputs[0] || tensor.shape.shape
46
46
  generate_vector(shape, generator: generator)
47
47
  end
48
48
  end
@@ -289,7 +289,7 @@ module TensorStream
289
289
  # assertions to make sure inferred shapes == actual evaluated shapes
290
290
  if tensor.shape.known? && (result.is_a?(Array) || result.is_a?(Float) || result.is_a?(Integer))
291
291
  if shape_eval(result) != tensor.shape.shape
292
- raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
292
+ # raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
293
293
  end
294
294
  end
295
295
 
@@ -313,21 +313,20 @@ module TensorStream
313
313
  @context[tensor.name] = result
314
314
  end
315
315
  rescue EvaluatorExcecutionException => e
316
- raise e, "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true, 1)} defined at #{tensor.source}"
316
+ raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
317
317
  rescue TensorStreamError => e
318
- raise e, "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true, 1)} defined at #{tensor.source}"
318
+ raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
319
319
  rescue StandardError => e
320
-
321
- a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
322
- b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
320
+ # a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
321
+ # b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
323
322
  puts e.message
324
323
  puts e.backtrace.join("\n")
325
324
  # shape_a = a.shape.shape if a
326
325
  # shape_b = b.shape.shape if b
327
326
  # dtype_a = a.data_type if a
328
327
  # dtype_b = b.data_type if b
329
- a = complete_eval(a, child_context)
330
- b = complete_eval(b, child_context)
328
+ # a = complete_eval(a, child_context)
329
+ # b = complete_eval(b, child_context)
331
330
  # puts "name: #{tensor.given_name}"
332
331
  # # puts "op: #{tensor.to_math(true, 1)}"
333
332
  # puts "A #{shape_a} #{dtype_a}: #{a}" if a
@@ -363,36 +362,17 @@ module TensorStream
363
362
 
364
363
  private
365
364
 
366
- def get_op_with_axis(a, target_axis, current_axis, output_type, op = ->(t, u) { t > u })
367
- if target_axis == current_axis
368
- if a[0].is_a?(Array)
369
- (0...a[0].size).each.collect do |column_index|
370
- max = nil
371
- max_index = 0
372
- a.each_with_index do |row, row_index|
373
- if max.nil? || op.call(row[column_index], max)
374
- max = row[column_index]
375
- max_index = row_index
376
- end
377
- end
365
+ def get_op_with_axis(a, target_axis, current_axis, op)
366
+ rank = get_rank(a)
367
+ return a.index(a.send(:"#{op}")) if rank == 1
378
368
 
379
- Tensor.cast_dtype(max_index, output_type)
380
- end
381
- else
382
- max = nil
383
- max_index = 0
384
- a.each_with_index do |x, index|
385
- if max.nil? || op.call(x, max)
386
- max = x
387
- max_index = index
388
- end
389
- end
390
- Tensor.cast_dtype(max_index, output_type)
391
- end
369
+ if current_axis == target_axis
370
+ compare_items = a.collect(&:flatten).transpose
371
+ compare_items.map { |item| item.index(item.send(:"#{op}")) }
372
+ elsif a[0].is_a?(Array)
373
+ a.map { |item| get_op_with_axis(item, target_axis, current_axis + 1, op) }
392
374
  else
393
- a.collect do |row|
394
- get_op_with_axis(row, target_axis, current_axis + 1, output_type, op)
395
- end
375
+ return a.index(a.send(:"#{op}"))
396
376
  end
397
377
  end
398
378
 
@@ -403,31 +383,6 @@ module TensorStream
403
383
  reduce(val, axis, keep_dims, func)
404
384
  end
405
385
 
406
- def arr_pad(arr, paddings, data_type = :float32, rank = 0)
407
- raise "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2
408
-
409
- before = paddings[rank][0]
410
- after = paddings[rank][1]
411
- pad_value = fp_type?(data_type) ? 0.0 : 0
412
- if arr[0].is_a?(Array)
413
- next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
414
- padding = deep_dup_array(next_dim_elem[0], pad_value)
415
- Array.new(before) { padding } + next_dim_elem + Array.new(after) { padding }
416
- else
417
- Array.new(before) { pad_value } + arr + Array.new(after) { pad_value }
418
- end
419
- end
420
-
421
- def deep_dup_array(arr, value = nil)
422
- if arr.is_a?(Array)
423
- arr.dup.collect do |a|
424
- deep_dup_array(a, value)
425
- end
426
- else
427
- value.nil? ? arr : value
428
- end
429
- end
430
-
431
386
  def call_op(op, a, child_context, func)
432
387
  a = complete_eval(a, child_context)
433
388
  process_function_op(a, func)
@@ -14,6 +14,7 @@ module TensorStream
14
14
  parsed_tree = protobuf.load_from_string(buffer)
15
15
  parsed_tree.each do |node|
16
16
  next unless node['type'] == 'node'
17
+
17
18
  # puts "build #{node['name']}"
18
19
  options = protobuf.options_evaluator(node)
19
20
  options[:name] = node['name']
@@ -34,7 +34,7 @@ module TensorStream
34
34
  arr_buf << "</node>"
35
35
 
36
36
  to_graph_ml(tensor, arr_buf, {}, groups)
37
- #dump groups
37
+ # dump groups
38
38
  groups.each do |k, g|
39
39
  arr_buf << create_group(k, k, g)
40
40
  end
@@ -15,6 +15,7 @@ module TensorStream
15
15
  @lines << " op: #{camelize(node.operation.to_s).to_json}"
16
16
  node.inputs.each do |input|
17
17
  next unless input
18
+
18
19
  @lines << " input: #{input.name.to_json}"
19
20
  end
20
21
  # type
@@ -0,0 +1,182 @@
1
+ require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
2
+ module TensorStream
3
+ ##
4
+ # Convenience class for guessing the shape of a tensor
5
+ #
6
+ class InferShape
7
+ extend TensorStream::ArrayOpsHelper
8
+ extend TensorStream::OpHelper
9
+
10
+ def self.infer_shape(tensor)
11
+ case tensor.operation
12
+ when :assign
13
+ possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
14
+ tensor.inputs[0].shape.shape
15
+ else
16
+ tensor.inputs[1].shape.shape
17
+ end
18
+
19
+ possible_shape
20
+ when :index
21
+ return nil unless tensor.inputs[0].is_a?(Tensor)
22
+ return nil unless tensor.inputs[0].const_value
23
+
24
+ input_shape = tensor.inputs[0].shape
25
+ return nil unless input_shape.known?
26
+
27
+ s = input_shape.shape.dup
28
+ s.shift
29
+ s
30
+ when :arg_min, :argmax, :argmin
31
+ return nil unless tensor.inputs[0].shape.known?
32
+ return nil if tensor.inputs[1] && tensor.inputs[1].value.nil?
33
+
34
+ axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].value
35
+ new_shape = tensor.inputs[0].shape.shape
36
+ new_shape.each_with_index.collect do |shape, index|
37
+ next nil if index == axis
38
+
39
+ shape
40
+ end.compact
41
+ when :mean, :prod, :sum, :arg_max
42
+ return [] if tensor.inputs[1].nil?
43
+ return nil if tensor.inputs[0].nil?
44
+ return nil unless tensor.inputs[0].shape.known?
45
+
46
+ input_shape = tensor.inputs[0].shape.shape
47
+ rank = input_shape.size
48
+
49
+ axis = tensor.inputs[1].const_value
50
+ return nil if axis.nil?
51
+
52
+ axis = [axis] unless axis.is_a?(Array)
53
+ axis = axis.map { |a| a < 0 ? rank - a.abs : a }
54
+
55
+ input_shape.each_with_index.map do |item, index|
56
+ if axis.include?(index)
57
+ next 1 if tensor.options[:keepdims]
58
+
59
+ next nil
60
+ end
61
+ item
62
+ end.compact
63
+ when :reshape
64
+ new_shape = tensor.inputs[1] && tensor.inputs[1].value ? tensor.inputs[1].value : nil
65
+ return nil if new_shape.nil?
66
+ return nil if tensor.inputs[0].shape.nil?
67
+
68
+ input_shape = tensor.inputs[0].shape.shape
69
+ return new_shape if input_shape.nil?
70
+ return nil if input_shape.include?(nil)
71
+ TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
72
+ when :flow_group
73
+ []
74
+ when :zeros, :ones, :fill, :random_standard_normal, :random_uniform
75
+ a_shape = tensor.inputs[0] ? tensor.inputs[0].const_value : tensor.options[:shape]
76
+ return nil if a_shape.nil?
77
+ a_shape.is_a?(Array) ? a_shape : [a_shape]
78
+ when :zeros_like, :ones_like
79
+ tensor.inputs[0].shape.shape
80
+ when :shape
81
+ tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
82
+ when :pad
83
+ return nil unless tensor.inputs[0].shape.known?
84
+ return nil unless tensor.inputs[1].value
85
+
86
+ size = tensor.inputs[0].shape.shape.reduce(:*) || 1
87
+ dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
88
+ shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].value))
89
+ when :mat_mul
90
+ return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
91
+ return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
92
+ return nil if tensor.inputs[0].shape.shape.size != 2 || tensor.inputs[1].shape.shape.size != 2
93
+
94
+ shape1, m = if tensor.options[:transpose_a]
95
+ [tensor.inputs[0].shape.shape[0], tensor.inputs[0].shape.shape[1]]
96
+ else
97
+ [tensor.inputs[0].shape.shape[1], tensor.inputs[0].shape.shape[0]]
98
+ end
99
+
100
+ shape2, n = if tensor.options[:transpose_b]
101
+ [tensor.inputs[1].shape.shape[1], tensor.inputs[1].shape.shape[0]]
102
+ else
103
+ [tensor.inputs[1].shape.shape[0], tensor.inputs[1].shape.shape[1]]
104
+ end
105
+
106
+ return nil if shape1.nil? || shape2.nil? || shape1 < 0 || shape2 < 0
107
+
108
+ raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{shape1} != #{shape2}) #{tensor.inputs[0].shape.shape} vs #{tensor.inputs[1].shape.shape}" if shape1 != shape2
109
+
110
+ [m, n]
111
+ when :transpose
112
+ return nil unless shape_full_specified(tensor.inputs[0])
113
+ return nil if tensor.inputs[1].is_a?(Tensor)
114
+
115
+ rank = tensor.inputs[0].shape.shape.size
116
+ perm = tensor.inputs[1] || (0...rank).to_a.reverse
117
+ perm.map { |p| tensor.inputs[0].shape.shape[p] }
118
+ when :stack
119
+ return nil unless shape_full_specified(tensor.inputs[0])
120
+
121
+ axis = tensor.options[:axis] || 0
122
+ new_shape = [tensor.inputs.size]
123
+ tensor.inputs[0].shape.shape.inject(new_shape) { |ns, i| ns << i }
124
+ rank = tensor.inputs[0].shape.shape.size + 1
125
+ axis = rank + axis if axis < 0
126
+ rotated_shape = Array.new(axis + 1) { new_shape.shift }
127
+ rotated_shape.rotate! + new_shape
128
+ when :concat
129
+ return nil if tensor.inputs[0].value.nil?
130
+
131
+ axis = tensor.inputs[0].value # get axis
132
+
133
+ axis_size = 0
134
+
135
+ tensor.inputs[1..tensor.inputs.size].each do |input_item|
136
+ return nil if input_item.shape.shape.nil?
137
+ return nil if input_item.shape.shape[axis].nil?
138
+
139
+ axis_size += input_item.shape.shape[axis]
140
+ end
141
+
142
+ new_shape = tensor.inputs[1].shape.shape.dup
143
+ new_shape[axis] = axis_size
144
+ new_shape
145
+ when :slice, :squeeze
146
+ nil
147
+ when :tile
148
+ nil
149
+ when :expand_dims
150
+ nil
151
+ when :broadcast_gradient_args
152
+ nil
153
+ when :no_op
154
+ nil
155
+ when :softmax_cross_entropy_with_logits_v2, :sparse_softmax_cross_entropy_with_logits
156
+ nil
157
+ when :decode_png, :flow_dynamic_stitch, :dynamic_stitch, :gather
158
+ nil
159
+ when :eye
160
+ return [tensor.inputs[0].const_value, tensor.inputs[1].const_value] if tensor.inputs[0].const_value && tensor.inputs[1].const_value
161
+
162
+ nil
163
+ when :size
164
+ []
165
+ when :unstack
166
+ return nil unless tensor.inputs[0].shape.known?
167
+
168
+ new_shape = tensor.inputs[0].shape.shape.dup
169
+ rank = new_shape.size - 1
170
+ axis = tensor.options[:axis] || 0
171
+ axis = rank + axis if axis < 0
172
+ rotated_shape = Array.new(axis + 1) { new_shape.shift }
173
+ rotated_shape.rotate!(-1) + new_shape
174
+ else
175
+ return nil if tensor.inputs[0].nil?
176
+ return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
177
+
178
+ TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
179
+ end
180
+ end
181
+ end
182
+ end
@@ -56,8 +56,8 @@ module TensorStream
56
56
 
57
57
  def format_source(trace)
58
58
  grad_source = trace.select { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }.first
59
- source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
60
- [grad_source, source].compact.join("\n")
59
+ # source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
60
+ [grad_source, trace].compact.join("\n")
61
61
  end
62
62
 
63
63
  def shapes_fully_specified_and_equal(x, y)
@@ -8,7 +8,7 @@ module TensorStream
8
8
  end
9
9
 
10
10
  def self.encode_png(contents, compression: -1, name: nil)
11
- check_allowed_types(contents, [:uint8, :uint16])
11
+ check_allowed_types(contents, %i[uint8 uint16])
12
12
  contents = convert_to_tensor(contents, dtype: :uint16)
13
13
  _op(:encode_png, contents, compression: compression, name: name)
14
14
  end