tensor_stream 0.9.2 → 0.9.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. checksums.yaml +4 -4
  2. data/lib/tensor_stream/evaluator/base_evaluator.rb +3 -0
  3. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +25 -0
  4. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +24 -24
  5. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +8 -0
  6. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +16 -18
  7. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +20 -4
  8. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +9 -5
  9. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +4 -4
  10. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +16 -61
  11. data/lib/tensor_stream/graph_builder.rb +1 -0
  12. data/lib/tensor_stream/graph_serializers/graphml.rb +1 -1
  13. data/lib/tensor_stream/graph_serializers/pbtext.rb +1 -0
  14. data/lib/tensor_stream/helpers/infer_shape.rb +182 -0
  15. data/lib/tensor_stream/helpers/op_helper.rb +2 -2
  16. data/lib/tensor_stream/images.rb +1 -1
  17. data/lib/tensor_stream/math_gradients.rb +1 -1
  18. data/lib/tensor_stream/monkey_patches/array.rb +15 -0
  19. data/lib/tensor_stream/monkey_patches/float.rb +3 -0
  20. data/lib/tensor_stream/monkey_patches/integer.rb +3 -0
  21. data/lib/tensor_stream/monkey_patches/patch.rb +70 -0
  22. data/lib/tensor_stream/nn/nn_ops.rb +43 -9
  23. data/lib/tensor_stream/operation.rb +2 -153
  24. data/lib/tensor_stream/ops.rb +71 -56
  25. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  26. data/lib/tensor_stream/tensor_shape.rb +9 -6
  27. data/lib/tensor_stream/train/adadelta_optimizer.rb +1 -1
  28. data/lib/tensor_stream/train/adagrad_optimizer.rb +1 -1
  29. data/lib/tensor_stream/train/adam_optimizer.rb +2 -2
  30. data/lib/tensor_stream/train/learning_rate_decay.rb +29 -0
  31. data/lib/tensor_stream/train/optimizer.rb +7 -6
  32. data/lib/tensor_stream/train/saver.rb +1 -0
  33. data/lib/tensor_stream/train/slot_creator.rb +2 -2
  34. data/lib/tensor_stream/trainer.rb +3 -0
  35. data/lib/tensor_stream/utils.rb +2 -2
  36. data/lib/tensor_stream/version.rb +1 -1
  37. data/lib/tensor_stream.rb +5 -1
  38. data/samples/rnn.rb +108 -0
  39. metadata +8 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9601653d86556739c89b591768e9d54d13d6335dd2f953fff7d91f22636e8c7b
4
- data.tar.gz: 3db4119c9e752df77cbbf7f8f753050a012b03fd2bde04dd6bbfd01f8903af62
3
+ metadata.gz: 51cb6686663dece94714073ff12b13ad7e57de1aadbee44506897c69a2d5fd67
4
+ data.tar.gz: 47152a908b2cc966cba6721da8f0c3170e3198a70f416fd224caee7d35a1fca2
5
5
  SHA512:
6
- metadata.gz: 3ca3eec8ce6cc7e73a1d5e38ccaaf2f65f7a37bb84e0118766de05ec374557847a88f2f14b1e35b896010d7e68a0234589fcf3de95a9417840e2e77ab4a4ea58
7
- data.tar.gz: 6b109a2ee9b59e286a2d06d1c3166fc5bbda9a888f0c26aa07bc4f994753ce5dbd7cb0aebb433a7446ddef653f3096ae1b60dbc4610e0dc18c9aaca8d3175f2f
6
+ metadata.gz: 7fff67042fd35c651409e04e890feafab25e618d88d4c3f69c97b39e839a943c64a4e05be388a4464e32328a777f6ee56dc5ede0a6c69c977cc94bcb7cd38bad
7
+ data.tar.gz: a30230994e81062626eaae5ca7a13ece36d2a5ea6210cfe1e3e465b5e5e24adad525547223f27e15991f74728d53d8b02e7cac0710fe16bf3c602402f5b65b5c
@@ -140,6 +140,8 @@ module TensorStream
140
140
  @context[:profile][:operations][tensor.name] = { op: tensor.operation,
141
141
  step: @context[:profile][:step],
142
142
  eval_time: end_time - start_time,
143
+ shape: tensor.shape ? tensor.shape.shape : nil,
144
+ data_type: tensor.data_type,
143
145
  tensor: tensor }
144
146
  end
145
147
  end
@@ -166,6 +168,7 @@ module TensorStream
166
168
  def global_eval(tensor, input, execution_context, op_options = {})
167
169
  return nil unless input
168
170
  return input unless input.is_a?(Tensor)
171
+
169
172
  # puts "global eval #{tensor.name}"
170
173
  @context[:_cache][:placement][input.name] = @session.assign_evaluator(input) if @context[:_cache][:placement][input.name].nil?
171
174
  if !on_same_device?(input) # tensor is on another device or evaluator
@@ -310,5 +310,30 @@ module TensorStream
310
310
 
311
311
  reduce_axis(0, axis, val, keep_dims, func)
312
312
  end
313
+
314
+ def arr_pad(arr, paddings, data_type = :float32, rank = 0)
315
+ raise "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2
316
+
317
+ before = paddings[rank][0]
318
+ after = paddings[rank][1]
319
+ pad_value = fp_type?(data_type) ? 0.0 : 0
320
+ if arr[0].is_a?(Array)
321
+ next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
322
+ padding = deep_dup_array(next_dim_elem[0], pad_value)
323
+ Array.new(before) { padding } + next_dim_elem + Array.new(after) { padding }
324
+ else
325
+ Array.new(before) { pad_value } + arr + Array.new(after) { pad_value }
326
+ end
327
+ end
328
+
329
+ def deep_dup_array(arr, value = nil)
330
+ if arr.is_a?(Array)
331
+ arr.dup.collect do |a|
332
+ deep_dup_array(a, value)
333
+ end
334
+ else
335
+ value.nil? ? arr : value
336
+ end
337
+ end
313
338
  end
314
339
  end
@@ -132,9 +132,9 @@ module TensorStream
132
132
  axis = !tensor.options[:axis].is_a?(Array) ? [tensor.options[:axis]] : tensor.options[:axis]
133
133
 
134
134
  if !axis.empty?
135
- axis.each do |axis|
136
- if shape[axis] == 1
137
- shape[axis] = nil
135
+ axis.each do |x|
136
+ if shape[x] == 1
137
+ shape[x] = nil
138
138
  else
139
139
  raise TensorStream::ValueError, "unable to squeeze dimension that does not have a size of 1"
140
140
  end
@@ -201,6 +201,7 @@ module TensorStream
201
201
  out = []
202
202
  input.each_with_index do |x, index|
203
203
  next if remove.include?(x)
204
+
204
205
  out << x
205
206
  idx << index
206
207
  end
@@ -225,6 +226,7 @@ module TensorStream
225
226
  break if start == limit
226
227
  break if (start < limit) && (cur_step >= limit)
227
228
  break if (start > limit) && (cur_step <= limit)
229
+
228
230
  r << cur_step
229
231
  cur_step += delta
230
232
  end
@@ -281,34 +283,34 @@ module TensorStream
281
283
  get_rank(inputs[0])
282
284
  end
283
285
 
284
- register_op :split do |context, tensor, inputs|
286
+ register_op :split do |_context, tensor, inputs|
285
287
  value, num_split, axis = inputs
286
288
 
287
289
  value_shape = shape_eval(value)
288
290
  res = if num_split.is_a?(Array)
289
- begin_index = 0
290
- num_split.collect do |num|
291
- end_index = begin_index + num
292
- arr = split_tensor(value, begin_index, end_index, axis)
293
- begin_index = end_index
294
- arr
295
- end
296
- else
297
- raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
298
- piece_sizes = value_shape[axis] / num_split
299
- Array.new(num_split) do |num|
300
- begin_index = num * piece_sizes
301
- end_index = begin_index + piece_sizes
302
- split_tensor(value, begin_index, end_index, axis)
303
- end
304
- end
291
+ begin_index = 0
292
+ num_split.collect do |num|
293
+ end_index = begin_index + num
294
+ arr = split_tensor(value, begin_index, end_index, axis)
295
+ begin_index = end_index
296
+ arr
297
+ end
298
+ else
299
+ raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
300
+
301
+ piece_sizes = value_shape[axis] / num_split
302
+ Array.new(num_split) do |num|
303
+ begin_index = num * piece_sizes
304
+ end_index = begin_index + piece_sizes
305
+ split_tensor(value, begin_index, end_index, axis)
306
+ end
307
+ end
305
308
  TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
306
309
  end
307
310
 
308
311
  register_op :reshape do |_context, _tensor, inputs|
309
312
  arr, new_shape = inputs
310
313
  arr = [arr] unless arr.is_a?(Array)
311
-
312
314
  flat_arr = arr.flatten
313
315
  if new_shape.size.zero? && flat_arr.size == 1
314
316
  flat_arr[0]
@@ -318,9 +320,7 @@ module TensorStream
318
320
  end
319
321
 
320
322
  register_op :pad do |context, tensor, inputs|
321
- p = complete_eval(tensor.options[:paddings], context)
322
-
323
- arr_pad(inputs[0], p, tensor.data_type)
323
+ arr_pad(inputs[0], inputs[1], tensor.data_type)
324
324
  end
325
325
 
326
326
  register_op :tile do |_context, _tensor, inputs|
@@ -2,7 +2,15 @@ module TensorStream
2
2
  module CheckOps
3
3
  def CheckOps.included(klass)
4
4
  klass.class_eval do
5
+ register_op :assert_equal do |context, tensor, inputs|
6
+ result = call_vector_op(tensor, :equal, inputs[0], inputs[1], context, ->(t, u) { t == u })
5
7
 
8
+ result = result.is_a?(Array) ? result.flatten.uniq : [result]
9
+ prefix = tensor.options[:message] || ""
10
+ raise TensorStream::InvalidArgumentError, "#{prefix} #{tensor.inputs[0].name} != #{tensor.inputs[1].name}" if result != [true]
11
+
12
+ nil
13
+ end
6
14
  end
7
15
  end
8
16
  end
@@ -14,30 +14,28 @@ module TensorStream
14
14
  image.grayscale! if channels == 1
15
15
  image_data = image.pixels.collect do |pixel|
16
16
  color_values = if channels == 4
17
- [ ChunkyPNG::Color.r(pixel),
18
- ChunkyPNG::Color.g(pixel),
19
- ChunkyPNG::Color.b(pixel),
20
- ChunkyPNG::Color.a(pixel) ]
21
- elsif channels == 3
22
- [ ChunkyPNG::Color.r(pixel),
23
- ChunkyPNG::Color.g(pixel),
24
- ChunkyPNG::Color.b(pixel)]
25
- elsif channels == 1
26
- [ ChunkyPNG::Color.r(pixel) ]
27
- else
28
- raise "Invalid channel value #{channels}"
29
- end
30
-
31
- if fp_type?(tensor.data_type)
32
- color_values.map! { |v| v.to_f }
33
- end
17
+ [ChunkyPNG::Color.r(pixel),
18
+ ChunkyPNG::Color.g(pixel),
19
+ ChunkyPNG::Color.b(pixel),
20
+ ChunkyPNG::Color.a(pixel)]
21
+ elsif channels == 3
22
+ [ChunkyPNG::Color.r(pixel),
23
+ ChunkyPNG::Color.g(pixel),
24
+ ChunkyPNG::Color.b(pixel)]
25
+ elsif channels == 1
26
+ [ChunkyPNG::Color.r(pixel)]
27
+ else
28
+ raise "Invalid channel value #{channels}"
29
+ end
30
+
31
+ color_values.map!(&:to_f) if fp_type?(tensor.data_type)
34
32
 
35
33
  color_values
36
34
  end
37
35
  TensorShape.reshape(image_data, [image.height, image.width, channels])
38
36
  end
39
37
 
40
- register_op :encode_png do |_context, tensor, inputs|
38
+ register_op :encode_png do |_context, _tensor, inputs|
41
39
  image_data = inputs[0]
42
40
  height, width, channels = shape_eval(image_data)
43
41
 
@@ -140,17 +140,33 @@ module TensorStream
140
140
  end
141
141
 
142
142
  register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
143
- axis = tensor.options[:axis] || 0
143
+ axis = inputs[1] || 0
144
144
  rank = get_rank(inputs[0])
145
145
  raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
146
- get_op_with_axis(inputs[0], axis, 0, tensor.data_type)
146
+
147
+ new_shape = shape_eval(inputs[0])
148
+ ns = new_shape.each_with_index.collect do |shape, index|
149
+ next nil if index == axis
150
+
151
+ shape
152
+ end.compact
153
+
154
+ Tensor.cast_dtype(TensorShape.reshape(get_op_with_axis(inputs[0], axis, 0, :max), ns), tensor.data_type)
147
155
  end
148
156
 
149
157
  register_op(%i[argmin arg_min]) do |_context, tensor, inputs|
150
- axis = tensor.options[:axis] || 0
158
+ axis = inputs[1] || 0
151
159
  rank = get_rank(inputs[0])
152
160
  raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
153
- get_op_with_axis(inputs[0], axis, 0, tensor.data_type, ->(a, b) { a < b })
161
+
162
+ new_shape = shape_eval(inputs[0])
163
+ ns = new_shape.each_with_index.collect do |shape, index|
164
+ next nil if index == axis
165
+
166
+ shape
167
+ end.compact
168
+
169
+ Tensor.cast_dtype(TensorShape.reshape(get_op_with_axis(inputs[0], axis, 0, :min), ns), tensor.data_type)
154
170
  end
155
171
 
156
172
  register_op :cumprod do |context, tensor, inputs|
@@ -16,11 +16,11 @@ module TensorStream
16
16
  assign = tensor.inputs[0] || tensor
17
17
  assign_acc = tensor.inputs[1]
18
18
  assign_acc.value = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
19
- if tensor.options[:use_nesterov]
20
- assign.value = multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
- else
22
- assign.value = multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
- end
19
+ assign.value = if tensor.options[:use_nesterov]
20
+ multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
+ else
22
+ multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
+ end
24
24
  assign.value
25
25
  end
26
26
 
@@ -207,6 +207,10 @@ module TensorStream
207
207
  TensorShape.reshape(arr, input_shape)
208
208
  end
209
209
  end
210
+
211
+ register_op :relu6 do |context, tensor, inputs|
212
+ call_vector_op(tensor, :relu6, inputs[0], inputs[1], context, ->(t, u) { [[t, 0].max, 6].min })
213
+ end
210
214
  end
211
215
  end
212
216
  end
@@ -25,24 +25,24 @@ module TensorStream
25
25
  generate_vector(shape, generator: generator)
26
26
  end
27
27
 
28
- register_op :random_uniform, no_eval: true do |_context, tensor, _inputs|
28
+ register_op :random_uniform, no_eval: true do |_context, tensor, inputs|
29
29
  maxval = tensor.options.fetch(:maxval, 1)
30
30
  minval = tensor.options.fetch(:minval, 0)
31
31
  seed = tensor.options[:seed]
32
32
 
33
33
  random = _get_randomizer(tensor, seed)
34
34
  generator = -> { random.rand * (maxval - minval) + minval }
35
- shape = tensor.options[:shape] || tensor.shape.shape
35
+ shape = inputs[0] || tensor.shape.shape
36
36
  generate_vector(shape, generator: generator)
37
37
  end
38
38
 
39
- register_op :random_standard_normal, no_eval: true do |_context, tensor, _inputs|
39
+ register_op :random_standard_normal, no_eval: true do |_context, tensor, inputs|
40
40
  seed = tensor.options[:seed]
41
41
  random = _get_randomizer(tensor, seed)
42
42
  r = RandomGaussian.new(tensor.options.fetch(:mean), tensor.options.fetch(:stddev), -> { random.rand })
43
43
  random = _get_randomizer(tensor, seed)
44
44
  generator = -> { r.rand }
45
- shape = tensor.options[:shape] || tensor.shape.shape
45
+ shape = inputs[0] || tensor.shape.shape
46
46
  generate_vector(shape, generator: generator)
47
47
  end
48
48
  end
@@ -289,7 +289,7 @@ module TensorStream
289
289
  # assertions to make sure inferred shapes == actual evaluated shapes
290
290
  if tensor.shape.known? && (result.is_a?(Array) || result.is_a?(Float) || result.is_a?(Integer))
291
291
  if shape_eval(result) != tensor.shape.shape
292
- raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
292
+ # raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
293
293
  end
294
294
  end
295
295
 
@@ -313,21 +313,20 @@ module TensorStream
313
313
  @context[tensor.name] = result
314
314
  end
315
315
  rescue EvaluatorExcecutionException => e
316
- raise e, "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true, 1)} defined at #{tensor.source}"
316
+ raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
317
317
  rescue TensorStreamError => e
318
- raise e, "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true, 1)} defined at #{tensor.source}"
318
+ raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
319
319
  rescue StandardError => e
320
-
321
- a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
322
- b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
320
+ # a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
321
+ # b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
323
322
  puts e.message
324
323
  puts e.backtrace.join("\n")
325
324
  # shape_a = a.shape.shape if a
326
325
  # shape_b = b.shape.shape if b
327
326
  # dtype_a = a.data_type if a
328
327
  # dtype_b = b.data_type if b
329
- a = complete_eval(a, child_context)
330
- b = complete_eval(b, child_context)
328
+ # a = complete_eval(a, child_context)
329
+ # b = complete_eval(b, child_context)
331
330
  # puts "name: #{tensor.given_name}"
332
331
  # # puts "op: #{tensor.to_math(true, 1)}"
333
332
  # puts "A #{shape_a} #{dtype_a}: #{a}" if a
@@ -363,36 +362,17 @@ module TensorStream
363
362
 
364
363
  private
365
364
 
366
- def get_op_with_axis(a, target_axis, current_axis, output_type, op = ->(t, u) { t > u })
367
- if target_axis == current_axis
368
- if a[0].is_a?(Array)
369
- (0...a[0].size).each.collect do |column_index|
370
- max = nil
371
- max_index = 0
372
- a.each_with_index do |row, row_index|
373
- if max.nil? || op.call(row[column_index], max)
374
- max = row[column_index]
375
- max_index = row_index
376
- end
377
- end
365
+ def get_op_with_axis(a, target_axis, current_axis, op)
366
+ rank = get_rank(a)
367
+ return a.index(a.send(:"#{op}")) if rank == 1
378
368
 
379
- Tensor.cast_dtype(max_index, output_type)
380
- end
381
- else
382
- max = nil
383
- max_index = 0
384
- a.each_with_index do |x, index|
385
- if max.nil? || op.call(x, max)
386
- max = x
387
- max_index = index
388
- end
389
- end
390
- Tensor.cast_dtype(max_index, output_type)
391
- end
369
+ if current_axis == target_axis
370
+ compare_items = a.collect(&:flatten).transpose
371
+ compare_items.map { |item| item.index(item.send(:"#{op}")) }
372
+ elsif a[0].is_a?(Array)
373
+ a.map { |item| get_op_with_axis(item, target_axis, current_axis + 1, op) }
392
374
  else
393
- a.collect do |row|
394
- get_op_with_axis(row, target_axis, current_axis + 1, output_type, op)
395
- end
375
+ return a.index(a.send(:"#{op}"))
396
376
  end
397
377
  end
398
378
 
@@ -403,31 +383,6 @@ module TensorStream
403
383
  reduce(val, axis, keep_dims, func)
404
384
  end
405
385
 
406
- def arr_pad(arr, paddings, data_type = :float32, rank = 0)
407
- raise "padding #{paddings[rank]} needs to have to elements [before, after]" if paddings[rank].size != 2
408
-
409
- before = paddings[rank][0]
410
- after = paddings[rank][1]
411
- pad_value = fp_type?(data_type) ? 0.0 : 0
412
- if arr[0].is_a?(Array)
413
- next_dim_elem = arr.collect { |a| arr_pad(a, paddings, data_type, rank + 1) }
414
- padding = deep_dup_array(next_dim_elem[0], pad_value)
415
- Array.new(before) { padding } + next_dim_elem + Array.new(after) { padding }
416
- else
417
- Array.new(before) { pad_value } + arr + Array.new(after) { pad_value }
418
- end
419
- end
420
-
421
- def deep_dup_array(arr, value = nil)
422
- if arr.is_a?(Array)
423
- arr.dup.collect do |a|
424
- deep_dup_array(a, value)
425
- end
426
- else
427
- value.nil? ? arr : value
428
- end
429
- end
430
-
431
386
  def call_op(op, a, child_context, func)
432
387
  a = complete_eval(a, child_context)
433
388
  process_function_op(a, func)
@@ -14,6 +14,7 @@ module TensorStream
14
14
  parsed_tree = protobuf.load_from_string(buffer)
15
15
  parsed_tree.each do |node|
16
16
  next unless node['type'] == 'node'
17
+
17
18
  # puts "build #{node['name']}"
18
19
  options = protobuf.options_evaluator(node)
19
20
  options[:name] = node['name']
@@ -34,7 +34,7 @@ module TensorStream
34
34
  arr_buf << "</node>"
35
35
 
36
36
  to_graph_ml(tensor, arr_buf, {}, groups)
37
- #dump groups
37
+ # dump groups
38
38
  groups.each do |k, g|
39
39
  arr_buf << create_group(k, k, g)
40
40
  end
@@ -15,6 +15,7 @@ module TensorStream
15
15
  @lines << " op: #{camelize(node.operation.to_s).to_json}"
16
16
  node.inputs.each do |input|
17
17
  next unless input
18
+
18
19
  @lines << " input: #{input.name.to_json}"
19
20
  end
20
21
  # type
@@ -0,0 +1,182 @@
1
+ require 'tensor_stream/evaluator/operation_helpers/array_ops_helper'
2
+ module TensorStream
3
+ ##
4
+ # Convenience class for guessing the shape of a tensor
5
+ #
6
+ class InferShape
7
+ extend TensorStream::ArrayOpsHelper
8
+ extend TensorStream::OpHelper
9
+
10
+ def self.infer_shape(tensor)
11
+ case tensor.operation
12
+ when :assign
13
+ possible_shape = if tensor.inputs[0] && tensor.inputs[0].shape.shape
14
+ tensor.inputs[0].shape.shape
15
+ else
16
+ tensor.inputs[1].shape.shape
17
+ end
18
+
19
+ possible_shape
20
+ when :index
21
+ return nil unless tensor.inputs[0].is_a?(Tensor)
22
+ return nil unless tensor.inputs[0].const_value
23
+
24
+ input_shape = tensor.inputs[0].shape
25
+ return nil unless input_shape.known?
26
+
27
+ s = input_shape.shape.dup
28
+ s.shift
29
+ s
30
+ when :arg_min, :argmax, :argmin
31
+ return nil unless tensor.inputs[0].shape.known?
32
+ return nil if tensor.inputs[1] && tensor.inputs[1].value.nil?
33
+
34
+ axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].value
35
+ new_shape = tensor.inputs[0].shape.shape
36
+ new_shape.each_with_index.collect do |shape, index|
37
+ next nil if index == axis
38
+
39
+ shape
40
+ end.compact
41
+ when :mean, :prod, :sum, :arg_max
42
+ return [] if tensor.inputs[1].nil?
43
+ return nil if tensor.inputs[0].nil?
44
+ return nil unless tensor.inputs[0].shape.known?
45
+
46
+ input_shape = tensor.inputs[0].shape.shape
47
+ rank = input_shape.size
48
+
49
+ axis = tensor.inputs[1].const_value
50
+ return nil if axis.nil?
51
+
52
+ axis = [axis] unless axis.is_a?(Array)
53
+ axis = axis.map { |a| a < 0 ? rank - a.abs : a }
54
+
55
+ input_shape.each_with_index.map do |item, index|
56
+ if axis.include?(index)
57
+ next 1 if tensor.options[:keepdims]
58
+
59
+ next nil
60
+ end
61
+ item
62
+ end.compact
63
+ when :reshape
64
+ new_shape = tensor.inputs[1] && tensor.inputs[1].value ? tensor.inputs[1].value : nil
65
+ return nil if new_shape.nil?
66
+ return nil if tensor.inputs[0].shape.nil?
67
+
68
+ input_shape = tensor.inputs[0].shape.shape
69
+ return new_shape if input_shape.nil?
70
+ return nil if input_shape.include?(nil)
71
+ TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
72
+ when :flow_group
73
+ []
74
+ when :zeros, :ones, :fill, :random_standard_normal, :random_uniform
75
+ a_shape = tensor.inputs[0] ? tensor.inputs[0].const_value : tensor.options[:shape]
76
+ return nil if a_shape.nil?
77
+ a_shape.is_a?(Array) ? a_shape : [a_shape]
78
+ when :zeros_like, :ones_like
79
+ tensor.inputs[0].shape.shape
80
+ when :shape
81
+ tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil
82
+ when :pad
83
+ return nil unless tensor.inputs[0].shape.known?
84
+ return nil unless tensor.inputs[1].value
85
+
86
+ size = tensor.inputs[0].shape.shape.reduce(:*) || 1
87
+ dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape)
88
+ shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].value))
89
+ when :mat_mul
90
+ return nil if tensor.inputs[0].shape.shape.nil? || tensor.inputs[1].shape.shape.nil?
91
+ return [] if tensor.inputs[0].shape.shape.empty? || tensor.inputs[1].shape.shape.empty?
92
+ return nil if tensor.inputs[0].shape.shape.size != 2 || tensor.inputs[1].shape.shape.size != 2
93
+
94
+ shape1, m = if tensor.options[:transpose_a]
95
+ [tensor.inputs[0].shape.shape[0], tensor.inputs[0].shape.shape[1]]
96
+ else
97
+ [tensor.inputs[0].shape.shape[1], tensor.inputs[0].shape.shape[0]]
98
+ end
99
+
100
+ shape2, n = if tensor.options[:transpose_b]
101
+ [tensor.inputs[1].shape.shape[1], tensor.inputs[1].shape.shape[0]]
102
+ else
103
+ [tensor.inputs[1].shape.shape[0], tensor.inputs[1].shape.shape[1]]
104
+ end
105
+
106
+ return nil if shape1.nil? || shape2.nil? || shape1 < 0 || shape2 < 0
107
+
108
+ raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{shape1} != #{shape2}) #{tensor.inputs[0].shape.shape} vs #{tensor.inputs[1].shape.shape}" if shape1 != shape2
109
+
110
+ [m, n]
111
+ when :transpose
112
+ return nil unless shape_full_specified(tensor.inputs[0])
113
+ return nil if tensor.inputs[1].is_a?(Tensor)
114
+
115
+ rank = tensor.inputs[0].shape.shape.size
116
+ perm = tensor.inputs[1] || (0...rank).to_a.reverse
117
+ perm.map { |p| tensor.inputs[0].shape.shape[p] }
118
+ when :stack
119
+ return nil unless shape_full_specified(tensor.inputs[0])
120
+
121
+ axis = tensor.options[:axis] || 0
122
+ new_shape = [tensor.inputs.size]
123
+ tensor.inputs[0].shape.shape.inject(new_shape) { |ns, i| ns << i }
124
+ rank = tensor.inputs[0].shape.shape.size + 1
125
+ axis = rank + axis if axis < 0
126
+ rotated_shape = Array.new(axis + 1) { new_shape.shift }
127
+ rotated_shape.rotate! + new_shape
128
+ when :concat
129
+ return nil if tensor.inputs[0].value.nil?
130
+
131
+ axis = tensor.inputs[0].value # get axis
132
+
133
+ axis_size = 0
134
+
135
+ tensor.inputs[1..tensor.inputs.size].each do |input_item|
136
+ return nil if input_item.shape.shape.nil?
137
+ return nil if input_item.shape.shape[axis].nil?
138
+
139
+ axis_size += input_item.shape.shape[axis]
140
+ end
141
+
142
+ new_shape = tensor.inputs[1].shape.shape.dup
143
+ new_shape[axis] = axis_size
144
+ new_shape
145
+ when :slice, :squeeze
146
+ nil
147
+ when :tile
148
+ nil
149
+ when :expand_dims
150
+ nil
151
+ when :broadcast_gradient_args
152
+ nil
153
+ when :no_op
154
+ nil
155
+ when :softmax_cross_entropy_with_logits_v2, :sparse_softmax_cross_entropy_with_logits
156
+ nil
157
+ when :decode_png, :flow_dynamic_stitch, :dynamic_stitch, :gather
158
+ nil
159
+ when :eye
160
+ return [tensor.inputs[0].const_value, tensor.inputs[1].const_value] if tensor.inputs[0].const_value && tensor.inputs[1].const_value
161
+
162
+ nil
163
+ when :size
164
+ []
165
+ when :unstack
166
+ return nil unless tensor.inputs[0].shape.known?
167
+
168
+ new_shape = tensor.inputs[0].shape.shape.dup
169
+ rank = new_shape.size - 1
170
+ axis = tensor.options[:axis] || 0
171
+ axis = rank + axis if axis < 0
172
+ rotated_shape = Array.new(axis + 1) { new_shape.shift }
173
+ rotated_shape.rotate!(-1) + new_shape
174
+ else
175
+ return nil if tensor.inputs[0].nil?
176
+ return tensor.inputs[0].shape.shape if tensor.inputs.size == 1
177
+
178
+ TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
179
+ end
180
+ end
181
+ end
182
+ end
@@ -56,8 +56,8 @@ module TensorStream
56
56
 
57
57
  def format_source(trace)
58
58
  grad_source = trace.select { |c| c.to_s.include?(File.join('lib', 'tensor_stream', 'math_gradients')) }.first
59
- source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
60
- [grad_source, source].compact.join("\n")
59
+ # source = trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
60
+ [grad_source, trace].compact.join("\n")
61
61
  end
62
62
 
63
63
  def shapes_fully_specified_and_equal(x, y)
@@ -8,7 +8,7 @@ module TensorStream
8
8
  end
9
9
 
10
10
  def self.encode_png(contents, compression: -1, name: nil)
11
- check_allowed_types(contents, [:uint8, :uint16])
11
+ check_allowed_types(contents, %i[uint8 uint16])
12
12
  contents = convert_to_tensor(contents, dtype: :uint16)
13
13
  _op(:encode_png, contents, compression: compression, name: name)
14
14
  end