tensor_stream 1.0.4 → 1.0.9

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/CHANGELOG.md +12 -2
  4. data/Dockerfile +1 -1
  5. data/USAGE_GUIDE.md +68 -0
  6. data/lib/tensor_stream.rb +1 -0
  7. data/lib/tensor_stream/evaluator/base_evaluator.rb +21 -1
  8. data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
  9. data/lib/tensor_stream/evaluator/evaluator_utils.rb +20 -0
  10. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +60 -0
  11. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +53 -1
  12. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +26 -0
  13. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +60 -5
  14. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +25 -29
  15. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +7 -11
  16. data/lib/tensor_stream/evaluator/ruby/storage_manager.rb +40 -0
  17. data/lib/tensor_stream/evaluator/ruby/variable_ops.rb +74 -0
  18. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +31 -77
  19. data/lib/tensor_stream/generated_stub/ops.rb +256 -166
  20. data/lib/tensor_stream/generated_stub/stub_file.erb +4 -4
  21. data/lib/tensor_stream/graph.rb +3 -3
  22. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +4 -6
  23. data/lib/tensor_stream/helpers/infer_shape.rb +1 -7
  24. data/lib/tensor_stream/helpers/tensor_mixins.rb +10 -1
  25. data/lib/tensor_stream/images.rb +4 -0
  26. data/lib/tensor_stream/math/math_ops.rb +22 -0
  27. data/lib/tensor_stream/math_gradients.rb +15 -1
  28. data/lib/tensor_stream/nn/embedding_lookup.rb +114 -0
  29. data/lib/tensor_stream/nn/nn_ops.rb +16 -0
  30. data/lib/tensor_stream/op_maker.rb +36 -3
  31. data/lib/tensor_stream/operation.rb +8 -20
  32. data/lib/tensor_stream/ops.rb +14 -11
  33. data/lib/tensor_stream/ops/bias_add.rb +16 -0
  34. data/lib/tensor_stream/ops/equal.rb +4 -0
  35. data/lib/tensor_stream/ops/greater.rb +4 -0
  36. data/lib/tensor_stream/ops/greater_equal.rb +4 -0
  37. data/lib/tensor_stream/ops/less.rb +19 -0
  38. data/lib/tensor_stream/ops/less_equal.rb +4 -0
  39. data/lib/tensor_stream/ops/not_equal.rb +19 -0
  40. data/lib/tensor_stream/ops/rsqrt.rb +11 -0
  41. data/lib/tensor_stream/ops/strided_slice.rb +24 -0
  42. data/lib/tensor_stream/ops/sum.rb +4 -2
  43. data/lib/tensor_stream/ops/top_k.rb +23 -0
  44. data/lib/tensor_stream/session.rb +6 -12
  45. data/lib/tensor_stream/tensor.rb +1 -0
  46. data/lib/tensor_stream/tensor_shape.rb +32 -1
  47. data/lib/tensor_stream/train/saver.rb +2 -3
  48. data/lib/tensor_stream/utils.rb +18 -13
  49. data/lib/tensor_stream/utils/freezer.rb +5 -1
  50. data/lib/tensor_stream/utils/py_ports.rb +11 -0
  51. data/lib/tensor_stream/variable.rb +9 -6
  52. data/lib/tensor_stream/version.rb +1 -1
  53. data/samples/word_embeddings/word_embedding_1.rb +192 -0
  54. data/samples/word_embeddings/word_embedding_2.rb +203 -0
  55. data/tensor_stream.gemspec +7 -2
  56. metadata +67 -10
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a5d2223b3321554a5529fc7be45f6be485f66d0024d8b3ff4fadb8adfb20ba2a
4
- data.tar.gz: 229532234896767058fc8d3dff0897650504345aaed1ce3d57b13d09f017f19e
3
+ metadata.gz: 8f7d54f45a96ee2ed86af5916339747701171476e2b1cd6197f4f78d3f7f2eb3
4
+ data.tar.gz: f8a1c615ebf5f67de35e0e6ac84ac531fc5306b62787ac0d2d952f474eb97bad
5
5
  SHA512:
6
- metadata.gz: '084f3d8fa7e74fccdbcbc42e004716520e9dc7a0cf7568ecbab8182e21ea0930ee58da9ca262b245e28e38df99e20f7dca029007c4e11030a3f6d62b4dd2e089'
7
- data.tar.gz: a0c9e9987f557ed3cb301114c8f7bb6a91e1344dfeb96d72fbb4c61d850568e1d73ae8b060cabb655954a1d7b36379579eb221abfb955204454e28eae94cdd03
6
+ metadata.gz: 4607a3c117c98f21594bcbf12b98a1e927dab88c2847d4516c4f3dd3502b821e8d4b2b8e17e085c3ed122eec1e339baced66e3822fed2c5c6e3c8e10d0121e08
7
+ data.tar.gz: dd2f7b6c971a25b90a4404319231de6386aef0b536a1df9eca84a3a881e6e9e2faa903fd7fb72b9a05ba8588ec5ffdd153be0072e6c850960fff9ecaecf7b6bc
data/.gitignore CHANGED
@@ -7,6 +7,7 @@
7
7
  /pkg/
8
8
  /spec/reports/
9
9
  /tmp/
10
+ /embeddings/
10
11
  *.gem
11
12
  samples/.ipynb_checkpoints/
12
13
 
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
4
4
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
5
5
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6
6
 
7
+ ## [1.0.7] - 2019-04-08
8
+ - [NEW] - Support for nn.embedding_lookup
9
+ - [NEW] - l2_normalize, dynamic_partition
10
+ - [NEW OP] - New Ops: rsqrt, top_k, strided_slice
11
+ - [NEW] - Support for ranges in tensors (e.g. t[0...2] via strided slice)
12
+ - [SAMPLES] - Add samples for handling word vectors
13
+
14
+ ## [1.0.5] - 2019-03-20
15
+ - [BUG FIX] - Fix not wrapping a stack op on some arrays. Should fix rnn sample
16
+
7
17
  ## [0.9.10] - 2019-01-02
8
18
  - [BUG FIX] - remove pry-byebug include (Thanks @samgooi4189)
9
19
  - Update Changelog for 0.9.9
@@ -19,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
19
29
  - [NEW OP] Convolutional networks - conv2d, conv2d_backprop_filter, conv2d_backprop_input
20
30
  - [IMAGE] Exposed image resampling options
21
31
  - [BUG FIX] fix argmin, argmax handling of NaN values
22
-
32
+
23
33
  ## [0.9.5] - 2018-11-05
24
34
  - [NEW OP] assert_equal, relu6
25
35
  - [TRAINING] learning_rate_decay, dropout
@@ -134,4 +144,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
134
144
  - reworked auto differentiation, fix a number of bugs related to auto differentiation, smaller derivative programs
135
145
  - alpha support for saving to pbtext format, added graphml generation
136
146
  - significant number of ops added
137
- - ops that support broadcasting now work better
147
+ - ops that support broadcasting now work better
data/Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM circleci/ruby:2.4.1-node-browsers
1
+ FROM circleci/ruby:2.6.1-node-browsers
2
2
  RUN sudo apt-get update -q && sudo apt-get install --no-install-recommends -yq alien wget unzip clinfo \
3
3
  && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
4
4
  RUN export DEVEL_URL="https://software.intel.com/file/531197/download" \
@@ -223,6 +223,74 @@ vars = graph.get_collection(TensorStream::GraphKeys::GLOBAL_VARIABLES)
223
223
  => [Variable(Variable:0 shape: TensorShape([]) data_type: float32)]
224
224
  ```
225
225
 
226
+ High Performance Computing
227
+ --------------------------
228
+
229
+ TensorStream has been designed from the ground up to support multiple execution backends.
230
+
231
+ What this means is you can build your models once and then be able to execute them later on specialized hardware when available like GPUs.
232
+
233
+ An OpenCL backend is available that you can use for compute intensive tasks like machine learning, especially those that use convolutional networks.
234
+
235
+ Using OpenCL is as simple as installing the tensorstream-opencl gem
236
+
237
+ ```
238
+ gem install tensor_stream-opencl
239
+ ```
240
+
241
+ You can then require the library in your programs and it will get used automatically (assuming you also installed OpenCL drivers for your system)
242
+
243
+ ```ruby
244
+ require 'tensor_stream'
245
+
246
+ # enable OpenCL
247
+ require 'tensor_stream/opencl'
248
+
249
+ tf = TensorStream
250
+
251
+ srand(5)
252
+ seed = 5
253
+ tf.set_random_seed(seed)
254
+
255
+ SHAPES = [32, 32]
256
+ tf = TensorStream
257
+ sess = tf.session
258
+ large_tensor = tf.constant(sess.run(tf.random_uniform([256, 256])))
259
+
260
+ sum_axis_1 = tf.reduce_sum(large_tensor, 1)
261
+ sess.run(sum_axis_1)
262
+ ```
263
+
264
+ Using OpenCL can improve performance dramatically in scenarios involving large tensors:
265
+
266
+ ```
267
+ Linux 4.15.0-46-generic #49-Ubuntu SMP
268
+ model name : AMD Ryzen 3 1300X Quad-Core Processor
269
+ OpenCL device NVIDIA CUDA GeForce GTX 1060 6GB
270
+ ruby 2.6.2p47 (2019-03-13 revision 67232) [x86_64-linux]
271
+
272
+ user system total real
273
+ pure ruby softmax : 0.024724 0.000000 0.024724 ( 0.024731)
274
+ opencl softmax : 0.006237 0.003945 0.010182 ( 0.009005)
275
+ pure ruby matmul : 0.679538 0.000000 0.679538 ( 0.680048)
276
+ opencl matmul : 0.003456 0.007965 0.011421 ( 0.008568)
277
+ pure ruby sum : 3.210619 0.000000 3.210619 ( 3.210064)
278
+ opencl sum : 0.002431 0.008030 0.010461 ( 0.007522)
279
+ pure ruby sum axis 1 : 3.208789 0.000000 3.208789 ( 3.208125)
280
+ opencl sum axis 1 : 0.006075 0.003963 0.010038 ( 0.007679)
281
+ pure ruby conv2d_backprop : 3.738167 0.000000 3.738167 ( 3.737946)
282
+ opencl conv2d_backprop : 0.031267 0.003958 0.035225 ( 0.030381)
283
+ pure ruby conv2d : 0.794182 0.000000 0.794182 ( 0.794100)
284
+ opencl conv2d : 0.015865 0.004020 0.019885 ( 0.016878)
285
+ ```
286
+
287
+ A quick glance shows not a marginal increase but an order of magnitude performance increase in most operations.
288
+ In fact we are looking at almost a 200x faster compute on operations like matmul and softmax (essential operations in machine learning). This is not a surprise because of the "embarrassingly" parallel nature of machine learning computation. Because of this, GPUs are basically a requirement in most machine learning tasks.
289
+
290
+ The code containing these benchmarks can be found at:
291
+
292
+ tensor_stream-opencl/benchmark/benchmark.rb
293
+
226
294
  Limitations
227
295
  -----------
228
296
 
@@ -23,6 +23,7 @@ require "tensor_stream/operation"
23
23
  require "tensor_stream/placeholder"
24
24
  require "tensor_stream/control_flow"
25
25
  require "tensor_stream/dynamic_stitch"
26
+ require "tensor_stream/math/math_ops"
26
27
  require "tensor_stream/nn/nn_ops"
27
28
  require "tensor_stream/evaluator/evaluator"
28
29
  require "tensor_stream/graph_serializers/packer"
@@ -2,11 +2,17 @@ module TensorStream
2
2
  # Evaluator base module
3
3
  module Evaluator
4
4
  class OutputGroup
5
+ include Enumerable
6
+
5
7
  attr_accessor :outputs, :data_types
6
8
  def initialize(outputs = [], data_types = [])
7
9
  @outputs = outputs
8
10
  @data_types = data_types
9
11
  end
12
+
13
# Yields every output in insertion order.
#
# Returns an Enumerator when called without a block so that Enumerable
# chaining works (the class includes Enumerable, which requires a
# well-behaved #each); the old body used @outputs.map and raised
# LocalJumpError when no block was given.
def each(&block)
  return to_enum(:each) unless block_given?

  @outputs.each(&block)
end
10
16
  end
11
17
 
12
18
  class UnsupportedOp < RuntimeError
@@ -131,7 +137,7 @@ module TensorStream
131
137
  time.to_i * (10**9) + time.nsec
132
138
  end
133
139
 
134
- instance_exec(execution_context, tensor, resolved_inputs, &op[:block]).tap do
140
+ instance_exec(execution_context, tensor, resolved_inputs, &op[:block]).tap do |result|
135
141
  if profile_enabled?
136
142
  time = Time.now
137
143
  end_time = time.to_i * (10**9) + time.nsec
@@ -222,11 +228,25 @@ module TensorStream
222
228
 
223
229
  def self.register_evaluator(klass, name, index = 0)
224
230
  @evaluators ||= {}
231
+ @storage_managers ||= {}
225
232
  @evaluators[name] = {name: name, class: klass, index: index}
233
+ @storage_managers[klass] = klass.get_storage_manager
226
234
  end
227
235
 
228
236
  def self.default_evaluators
229
237
  evaluators.values.sort { |v| v[:index] }.reverse.map { |v| v[:class] }
230
238
  end
239
+
240
+ def self.clear_storages(graph)
241
+ @storage_managers.values.each { |manager| manager.clear_variables(graph) }
242
+ end
243
+
244
+ def self.read_variable(graph, name)
245
+ @storage_managers.values.each do |manager|
246
+ return manager.read_value(graph, name) if manager.exists?(graph, name)
247
+ end
248
+
249
+ nil
250
+ end
231
251
  end
232
252
  end
@@ -1,5 +1,6 @@
1
1
  require "tensor_stream/evaluator/ruby_evaluator"
2
2
  require "tensor_stream/evaluator/buffer"
3
+ require "tensor_stream/evaluator/evaluator_utils"
3
4
 
4
5
  module TensorStream
5
6
  module Evaluator
@@ -0,0 +1,20 @@
1
module TensorStream
  # Helpers for resolving evaluator classes from user-supplied specifiers.
  class EvaluatorUtils
    extend TensorStream::StringHelper

    # Resolves +evaluators+ to an array of evaluator classes.
    #
    # evaluators - nil (use defaults), a String/Symbol evaluator name, or an
    #              Array of names ([] also means "use defaults")
    #
    # Returns an Array of evaluator classes.
    #
    # Results are memoized per argument. The previous implementation used a
    # single `@evaluator_classes ||= ...`, which cached only the first call's
    # resolution — a later call with a different argument silently got the
    # first call's classes.
    def self.get_evaluator_classes(evaluators)
      @evaluator_classes ||= {}
      @evaluator_classes[evaluators] ||= if evaluators.is_a?(Array)
        if evaluators.empty?
          TensorStream::Evaluator.default_evaluators
        else
          evaluators.collect { |name| Object.const_get("TensorStream::Evaluator::#{camelize(name.to_s)}") }
        end
      elsif evaluators.nil?
        TensorStream::Evaluator.default_evaluators
      else
        [Object.const_get("TensorStream::Evaluator::#{camelize(evaluators.to_s)}")]
      end
    end
  end
end
@@ -30,6 +30,16 @@ module TensorStream
30
30
  end
31
31
  end
32
32
 
33
# Destructively copies +value+ into +input+ element-wise, recursing into
# nested sub-arrays so that +input+ ends up holding +value+'s contents
# while keeping its own (identical) nested structure.
#
# input - Array (possibly nested) overwritten in place
# value - Array of the same shape supplying the new contents
#
# Returns input (mutated).
def array_set!(input, value)
  input.each_with_index do |element, index|
    if element.is_a?(Array)
      # Recurse with the MATCHING sub-array. The previous code called the
      # non-bang `array_set(element, value)` — wrong method name and the
      # whole value instead of value[index] — so nested writes were wrong.
      array_set!(element, value[index])
    else
      input[index] = value[index]
    end
  end
end
42
+
33
43
  def truncate(input, target_shape)
34
44
  rank = get_rank(input)
35
45
  return input if rank.zero?
@@ -331,5 +341,55 @@ module TensorStream
331
341
  value.nil? ? arr : value
332
342
  end
333
343
  end
344
+
345
# Recursively extracts a strided slice from a (possibly nested) array.
#
# value  - the array (or scalar, once slices are exhausted) to slice
# slices - one [begin, end, stride] triple per dimension; negative begin/end
#          values count from the end of that axis, a negative stride walks
#          the axis backwards
#
# Returns a new nested array containing the selected elements.
def strided_slice(value, slices = [])
  remaining = slices.dup
  spec = remaining.shift
  # No more dimensions to slice: return the subtree as-is.
  return value if spec.nil?

  from, to, step = spec
  from += value.size if from < 0
  to += value.size + 1 if to < 0

  picked =
    if step < 0
      # Walk backwards from `from` down to `to`, keeping every |step|-th entry.
      from.downto(to).select.with_index { |_item, pos| (pos % step.abs).zero? }
    else
      (from...to).step(step)
    end

  picked.map { |pos| strided_slice(value[pos], remaining) }
end
365
+
366
# Scatters +grad+ back into +value+ (a zero-filled array shaped like the
# forward input) at the coordinates selected by the strided-slice +slices+.
# This is the gradient of strided_slice. Mutates +value+ in place.
#
# value  - nested array receiving the gradient
# grad   - incoming gradient, shaped like the strided_slice output
# x      - remaining shape of value; one dimension is consumed per level
# slices - remaining [begin, end, stride] triples, one per dimension
def strided_slice_grad(value, grad, x, slices)
  current_slice = slices.dup
  selection = current_slice.shift
  x.shift # consume this level's dimension; only the mutation of x matters

  # Slices exhausted: the whole remaining subtree receives the gradient.
  # Must RETURN here — the previous code fell through and crashed on
  # `b < 0` with b == nil.
  return array_set!(value, grad) if selection.nil?

  b, e, stride = selection

  b = value.size + b if b < 0
  e = value.size + e + 1 if e < 0

  indexes = if stride < 0
    b.downto(e).select.with_index { |elem, index| (index % stride.abs) == 0 }
  else
    (b...e).step(stride)
  end

  indexes.each_with_index do |index, grad_index|
    if value[index].is_a?(Array)
      strided_slice_grad(value[index], grad[grad_index], x.dup, current_slice.dup)
    else
      value[index] = grad[grad_index]
    end
  end
end
334
394
  end
335
395
  end
@@ -22,8 +22,9 @@ module TensorStream
22
22
  merged
23
23
  end
24
24
 
25
- register_op :gather do |_context, _tensor, inputs|
25
+ register_op :gather do |_context, tensor, inputs|
26
26
  params, indexes = inputs
27
+ raise "axis !=0 not supported" if tensor.options[:axis] != 0
27
28
  gather(params, indexes)
28
29
  end
29
30
 
@@ -216,7 +217,14 @@ module TensorStream
216
217
 
217
218
  register_op :range do |_context, _tensor, inputs|
218
219
  start, limit, delta = inputs
220
+
219
221
  raise " delta !=0 " if delta.zero?
222
+
223
+ if limit.zero?
224
+ limit = start
225
+ start = 0
226
+ end
227
+
220
228
  raise " Requires start <= limit when delta > 0" if (start > limit) && delta > 0
221
229
  raise " Requires start >= limit when delta < 0" if (start < limit) && delta < 0
222
230
 
@@ -399,6 +407,50 @@ module TensorStream
399
407
  end
400
408
  end
401
409
 
410
+ register_op :dynamic_partition do |context, tensor, inputs|
411
+ data, partitions = inputs
412
+ num_partitions = tensor.options[:num_partitions]
413
+ output_arr = Array.new(num_partitions) { [] }
414
+
415
+ partitions.each_with_index do |part, index|
416
+ output_arr[part] << data[index]
417
+ end
418
+ TensorStream::Evaluator::OutputGroup.new(output_arr, num_partitions.times.map { tensor.data_type })
419
+ end
420
+
421
+ register_op :gather_grad do |context, tensor, inputs|
422
+ grad, indexes, input_shape = inputs
423
+ output = Array.new(input_shape.reduce(:*)) { fp_type?(tensor.data_type) ? 0.0 : 0 }
424
+ indexes.each_with_index.map do |x, index|
425
+ output[x] += grad[index]
426
+ end
427
+ TensorShape.reshape(output, input_shape)
428
+ end
429
+
430
+ register_op :strided_slice do |_context, _tensor, inputs|
431
+ value, b_index, e_index, stride = inputs
432
+ slices = b_index.zip(e_index).zip(stride).map do |params|
433
+ selection, stride = params
434
+ s, e = selection
435
+ [s, e, stride]
436
+ end
437
+ strided_slice(value, slices)
438
+ end
439
+
440
+ register_op :strided_slice_grad do |_context, tensor, inputs|
441
+ x, b_index, e_index, stride, grad = inputs
442
+ slices = b_index.zip(e_index).zip(stride).map do |params|
443
+ selection, stride = params
444
+ s, e = selection
445
+ [s, e, stride]
446
+ end
447
+
448
+ target_val = generate_vector(x, generator: ->() { fp_type?(tensor.data_type) ? 0.0 : 0 })
449
+
450
+ strided_slice_grad(target_val, grad, x.dup, slices.dup)
451
+ target_val
452
+ end
453
+
402
454
  def merge_dynamic_stitch(merged, indexes, data, context)
403
455
  indexes.each_with_index do |ind, m|
404
456
  if ind.is_a?(Array)
@@ -1,5 +1,6 @@
1
1
  require "chunky_png"
2
2
 
3
+
3
4
  module TensorStream
4
5
  module ImagesOps
5
6
  def self.included(klass)
@@ -49,6 +50,31 @@ module TensorStream
49
50
  TensorShape.reshape(image_data, [image.height, image.width, channels])
50
51
  end
51
52
 
53
# Decodes a JPEG byte buffer into a [height, width, channels] nested array.
# channels == 0 in the op options means "use the default" (3). A grayscale
# source can be expanded to 3 channels; color -> grayscale is unsupported.
register_op :decode_jpg do |_context, tensor, inputs|
  require "jpeg"

  content = inputs[0]
  channels = tensor.options[:channels]
  channels = 3 if channels.zero?

  image = Jpeg::Image.open_buffer(content)
  source_channels = image.color_info == :gray ? 1 : 3

  image_data = image.raw_data.map do |pixel|
    if source_channels == channels
      pixel
    # The two branches below used `source_channels = 1 && ...` /
    # `source_channels = 3 && ...` — accidental assignments that clobbered
    # source_channels on every pixel. Rewritten as the intended comparisons.
    elsif source_channels == 1 && channels == 3
      [pixel, pixel, pixel]
    elsif source_channels == 3 && channels == 1
      raise TensorStream::ValueError, "color to grayscale not supported for jpg"
    end
  end.flatten

  image_data.map!(&:to_f) if fp_type?(tensor.data_type)

  TensorShape.reshape(image_data, [image.height, image.width, channels])
end
77
+
52
78
  register_op :encode_png do |_context, tensor, inputs|
53
79
  image_data = inputs[0]
54
80
  height, width, channels = shape_eval(image_data)
@@ -37,6 +37,24 @@ module TensorStream
37
37
  end
38
38
  end
39
39
 
40
+ register_op :bias_add do |_context, _tensor, inputs|
41
+ value, bias = inputs
42
+ arr = value.flatten.each_slice(bias.size).map do |slice|
43
+ slice.each_with_index.map { |elem, index| elem + bias[index] }
44
+ end
45
+ TensorShape.reshape(arr, shape_eval(value))
46
+ end
47
+
48
+ register_op :bias_add_grad do |_context, _tensor, inputs|
49
+ received_grad = inputs[0]
50
+ bias_size = shape_eval(received_grad).last
51
+ grad_sum = Array.new(bias_size) { 0.0 }
52
+ received_grad.flatten.each_slice(bias_size) do |slice|
53
+ slice.each_with_index.map { |elem, index| grad_sum[index] += elem }
54
+ end
55
+ grad_sum
56
+ end
57
+
40
58
  register_op :sub, no_eval: true do |context, tensor, inputs|
41
59
  a, b = inputs
42
60
  call_vector_op(tensor, :sub, a, b, context) { |t, u| t - u }
@@ -111,6 +129,15 @@ module TensorStream
111
129
  call_op(inputs[0], context) { |t, _b| Math.sqrt(t) }
112
130
  end
113
131
 
132
+ register_op :rsqrt, no_eval: true do |context, _tensor, inputs|
133
+ call_op(inputs[0], context) { |t, _b| 1 / Math.sqrt(t) }
134
+ end
135
+
136
# Gradient of rsqrt. With y = x**-0.5, dy/dx = -0.5 * x**-1.5 = -0.5 * y**3,
# so the backpropagated value is -0.5 * g * y**3 (matching TensorFlow's
# RsqrtGrad). The previous coefficient (+0.5) had the sign flipped — rsqrt
# is strictly decreasing, so its derivative must be negative.
register_op :rsqrt_grad, no_eval: true do |context, tensor, inputs|
  y, grad = inputs
  call_vector_op(tensor, :rsqrt_grad, y, grad, context) { |_y, g| -0.5 * g * (_y ** 3) }
end
140
+
114
141
  register_op :floor, no_eval: true do |context, _tensor, inputs|
115
142
  call_op(inputs[0], context) { |t, _b| t.floor }
116
143
  end
@@ -135,6 +162,25 @@ module TensorStream
135
162
  call_op(inputs[0], context) { |t, _b| 1 - Math.tanh(t) * Math.tanh(t) }
136
163
  end
137
164
 
165
+ register_op :top_k do |context, tensor, inputs|
166
+ values, k = inputs
167
+ v_shape = shape_eval(values)
168
+
169
+ sorted = tensor.options[:sorted]
170
+ work_values = TensorShape.reshape(values, [-1, v_shape.last])
171
+ work_values.map! do |row|
172
+ last_k = row.map.with_index { |r, index| [r, index] }.sort! { |a,b| a[0] <=> b[0] }.last(k)
173
+ last_k.reverse! if sorted
174
+ last_k
175
+ end
176
+
177
+ top_k = work_values.map { |row| row.map { |r| r[0] } }
178
+ top_indices = work_values.map { |row| row.map { |r| r[1] } }
179
+ v_shape[-1] = k
180
+
181
+ TensorStream::Evaluator::OutputGroup.new([TensorShape.reshape(top_k, v_shape), TensorShape.reshape(top_indices, v_shape)], [tensor.inputs[0].data_type, :int32])
182
+ end
183
+
138
184
  register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
139
185
  axis = inputs[1] || 0
140
186
  rank = get_rank(inputs[0])
@@ -241,13 +287,22 @@ module TensorStream
241
287
  raise "#{tensor.inputs[0].name} rank must be greater than 1" if rank_a < 2
242
288
  raise "#{tensor.inputs[1].name} rank must be greater than 1" if rank_b < 2
243
289
 
244
- matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
245
- matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
246
-
247
290
  # check matrix dimensions
248
- raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size
291
+ if rank_a >= 3
292
+ matrix_a.zip(matrix_b).map do |m_a, m_b|
293
+ matmul(m_a, m_b, tensor)
294
+ end
295
+ else
296
+ matmul(matrix_a, matrix_b, tensor)
297
+ end
298
+ end
299
+
300
# Multiplies two rank-2 nested arrays using the stdlib Matrix class,
# honoring the transpose_a / transpose_b flags carried in the op options.
# Raises TensorStream::ValueError when the inner dimensions disagree.
def matmul(m_a, m_b, tensor)
  opts = tensor.options
  m_a = m_a.transpose if opts[:transpose_a]
  m_b = m_b.transpose if opts[:transpose_b]

  # Inner dimensions must line up: columns of m_a == rows of m_b.
  unless m_a[0].size == m_b.size
    raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{m_a[0].size} != #{m_b.size}) #{shape_eval(m_a)} vs #{shape_eval(m_b)}"
  end

  product = Matrix[*m_a] * Matrix[*m_b]
  product.to_a
end
252
307
 
253
308
  register_op %i[max maximum], noop: true do |context, tensor, inputs|