tensor_stream 1.0.4 → 1.0.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/CHANGELOG.md +12 -2
  4. data/Dockerfile +1 -1
  5. data/USAGE_GUIDE.md +68 -0
  6. data/lib/tensor_stream.rb +1 -0
  7. data/lib/tensor_stream/evaluator/base_evaluator.rb +21 -1
  8. data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
  9. data/lib/tensor_stream/evaluator/evaluator_utils.rb +20 -0
  10. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +60 -0
  11. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +53 -1
  12. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +26 -0
  13. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +60 -5
  14. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +25 -29
  15. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +7 -11
  16. data/lib/tensor_stream/evaluator/ruby/storage_manager.rb +40 -0
  17. data/lib/tensor_stream/evaluator/ruby/variable_ops.rb +74 -0
  18. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +31 -77
  19. data/lib/tensor_stream/generated_stub/ops.rb +256 -166
  20. data/lib/tensor_stream/generated_stub/stub_file.erb +4 -4
  21. data/lib/tensor_stream/graph.rb +3 -3
  22. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +4 -6
  23. data/lib/tensor_stream/helpers/infer_shape.rb +1 -7
  24. data/lib/tensor_stream/helpers/tensor_mixins.rb +10 -1
  25. data/lib/tensor_stream/images.rb +4 -0
  26. data/lib/tensor_stream/math/math_ops.rb +22 -0
  27. data/lib/tensor_stream/math_gradients.rb +15 -1
  28. data/lib/tensor_stream/nn/embedding_lookup.rb +114 -0
  29. data/lib/tensor_stream/nn/nn_ops.rb +16 -0
  30. data/lib/tensor_stream/op_maker.rb +36 -3
  31. data/lib/tensor_stream/operation.rb +8 -20
  32. data/lib/tensor_stream/ops.rb +14 -11
  33. data/lib/tensor_stream/ops/bias_add.rb +16 -0
  34. data/lib/tensor_stream/ops/equal.rb +4 -0
  35. data/lib/tensor_stream/ops/greater.rb +4 -0
  36. data/lib/tensor_stream/ops/greater_equal.rb +4 -0
  37. data/lib/tensor_stream/ops/less.rb +19 -0
  38. data/lib/tensor_stream/ops/less_equal.rb +4 -0
  39. data/lib/tensor_stream/ops/not_equal.rb +19 -0
  40. data/lib/tensor_stream/ops/rsqrt.rb +11 -0
  41. data/lib/tensor_stream/ops/strided_slice.rb +24 -0
  42. data/lib/tensor_stream/ops/sum.rb +4 -2
  43. data/lib/tensor_stream/ops/top_k.rb +23 -0
  44. data/lib/tensor_stream/session.rb +6 -12
  45. data/lib/tensor_stream/tensor.rb +1 -0
  46. data/lib/tensor_stream/tensor_shape.rb +32 -1
  47. data/lib/tensor_stream/train/saver.rb +2 -3
  48. data/lib/tensor_stream/utils.rb +18 -13
  49. data/lib/tensor_stream/utils/freezer.rb +5 -1
  50. data/lib/tensor_stream/utils/py_ports.rb +11 -0
  51. data/lib/tensor_stream/variable.rb +9 -6
  52. data/lib/tensor_stream/version.rb +1 -1
  53. data/samples/word_embeddings/word_embedding_1.rb +192 -0
  54. data/samples/word_embeddings/word_embedding_2.rb +203 -0
  55. data/tensor_stream.gemspec +7 -2
  56. metadata +67 -10
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a5d2223b3321554a5529fc7be45f6be485f66d0024d8b3ff4fadb8adfb20ba2a
4
- data.tar.gz: 229532234896767058fc8d3dff0897650504345aaed1ce3d57b13d09f017f19e
3
+ metadata.gz: 8f7d54f45a96ee2ed86af5916339747701171476e2b1cd6197f4f78d3f7f2eb3
4
+ data.tar.gz: f8a1c615ebf5f67de35e0e6ac84ac531fc5306b62787ac0d2d952f474eb97bad
5
5
  SHA512:
6
- metadata.gz: '084f3d8fa7e74fccdbcbc42e004716520e9dc7a0cf7568ecbab8182e21ea0930ee58da9ca262b245e28e38df99e20f7dca029007c4e11030a3f6d62b4dd2e089'
7
- data.tar.gz: a0c9e9987f557ed3cb301114c8f7bb6a91e1344dfeb96d72fbb4c61d850568e1d73ae8b060cabb655954a1d7b36379579eb221abfb955204454e28eae94cdd03
6
+ metadata.gz: 4607a3c117c98f21594bcbf12b98a1e927dab88c2847d4516c4f3dd3502b821e8d4b2b8e17e085c3ed122eec1e339baced66e3822fed2c5c6e3c8e10d0121e08
7
+ data.tar.gz: dd2f7b6c971a25b90a4404319231de6386aef0b536a1df9eca84a3a881e6e9e2faa903fd7fb72b9a05ba8588ec5ffdd153be0072e6c850960fff9ecaecf7b6bc
data/.gitignore CHANGED
@@ -7,6 +7,7 @@
7
7
  /pkg/
8
8
  /spec/reports/
9
9
  /tmp/
10
+ /embeddings/
10
11
  *.gem
11
12
  samples/.ipynb_checkpoints/
12
13
 
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
4
4
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
5
5
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6
6
 
7
+ ## [1.0.7] - 2019-04-08
8
+ - [NEW] - Support for nn.embedding_lookup
9
+ - [NEW] - l2_normalize, dynamic_partition
10
+ - [NEW OP] - New Ops: rsqrt, top_k, strided_slice
11
+ - [NEW] - Support for ranges in tensors (e.g. t[0...2] via strided slice)
12
+ - [SAMPLES] - Add samples for handling word vectors
13
+
14
+ ## [1.0.5] - 2019-03-20
15
+ - [BUG FIX] - Fix not wrapping a stack op on some arrays. Should fix rnn sample
16
+
7
17
  ## [0.9.10] - 2019-01-02
8
18
  - [BUG FIX] - remove pry-byebug include (Thanks @samgooi4189)
9
19
  - Update Changelog for 0.9.9
@@ -19,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
19
29
  - [NEW OP] Convolutional networks - conv2d, conv2d_backprop_filter, conv2d_backprop_input
20
30
  - [IMAGE] Exposed image resampling options
21
31
  - [BUG FIX] fix argmin, argmax handling of NaN values
22
-
32
+
23
33
  ## [0.9.5] - 2018-11-05
24
34
  - [NEW OP] assert_equal, relu6
25
35
  - [TRAINING] learning_rate_decay, dropout
@@ -134,4 +144,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
134
144
  - reworked auto differentiation, fix a number of bugs related to auto differentiation, smaller derivative programs
135
145
  - alpha support for saving to pbtext format, added graphml generation
136
146
  - significant number of ops added
137
- - ops that support broadcasting now work better
147
+ - ops that support broadcasting now work better
data/Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM circleci/ruby:2.4.1-node-browsers
1
+ FROM circleci/ruby:2.6.1-node-browsers
2
2
  RUN sudo apt-get update -q && sudo apt-get install --no-install-recommends -yq alien wget unzip clinfo \
3
3
  && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
4
4
  RUN export DEVEL_URL="https://software.intel.com/file/531197/download" \
@@ -223,6 +223,74 @@ vars = graph.get_collection(TensorStream::GraphKeys::GLOBAL_VARIABLES)
223
223
  => [Variable(Variable:0 shape: TensorShape([]) data_type: float32)]
224
224
  ```
225
225
 
226
+ High Performance Computing
227
+ --------------------------
228
+
229
+ TensorStream has been designed from the ground up to support multiple execution backends.
230
+
231
+ What this means is you can build your models once and then be able to execute them later on specialized hardware when available like GPUs.
232
+
233
+ An OpenCL backend is available that you can use for compute-intensive tasks like machine learning, especially those that use convolutional networks.
234
+
235
+ Using OpenCL is as simple as installing the tensorstream-opencl gem
236
+
237
+ ```
238
+ gem install tensor_stream-opencl
239
+ ```
240
+
241
+ You can then require the library in your programs and it will get used automatically (assuming you also installed OpenCL drivers for your system)
242
+
243
+ ```ruby
244
+ require 'tensor_stream'
245
+
246
+ # enable OpenCL
247
+ require 'tensor_stream/opencl'
248
+
249
+ tf = TensorStream
250
+
251
+ srand(5)
252
+ seed = 5
253
+ tf.set_random_seed(seed)
254
+
255
+ SHAPES = [32, 32]
256
+ tf = TensorStream
257
+ sess = tf.session
258
+ large_tensor = tf.constant(sess.run(tf.random_uniform([256, 256])))
259
+
260
+ sum_axis_1 = tf.reduce_sum(large_tensor, 1)
261
+ sess.run(sum_axis_1)
262
+ ```
263
+
264
+ Using OpenCL can improve performance dramatically in scenarios involving large tensors:
265
+
266
+ ```
267
+ Linux 4.15.0-46-generic #49-Ubuntu SMP
268
+ model name : AMD Ryzen 3 1300X Quad-Core Processor
269
+ OpenCL device NVIDIA CUDA GeForce GTX 1060 6GB
270
+ ruby 2.6.2p47 (2019-03-13 revision 67232) [x86_64-linux]
271
+
272
+ user system total real
273
+ pure ruby softmax : 0.024724 0.000000 0.024724 ( 0.024731)
274
+ opencl softmax : 0.006237 0.003945 0.010182 ( 0.009005)
275
+ pure ruby matmul : 0.679538 0.000000 0.679538 ( 0.680048)
276
+ opencl matmul : 0.003456 0.007965 0.011421 ( 0.008568)
277
+ pure ruby sum : 3.210619 0.000000 3.210619 ( 3.210064)
278
+ opencl sum : 0.002431 0.008030 0.010461 ( 0.007522)
279
+ pure ruby sum axis 1 : 3.208789 0.000000 3.208789 ( 3.208125)
280
+ opencl sum axis 1 : 0.006075 0.003963 0.010038 ( 0.007679)
281
+ pure ruby conv2d_backprop : 3.738167 0.000000 3.738167 ( 3.737946)
282
+ opencl conv2d_backprop : 0.031267 0.003958 0.035225 ( 0.030381)
283
+ pure ruby conv2d : 0.794182 0.000000 0.794182 ( 0.794100)
284
+ opencl conv2d : 0.015865 0.004020 0.019885 ( 0.016878)
285
+ ```
286
+
287
+ A quick glance shows not a marginal increase but an order of magnitude performance increase in most operations.
288
+ In fact we are looking at almost a 200x faster compute on operations like matmul and softmax (essential operations in machine learning). This is not a surprise because of the "embarrassingly" parallel nature of machine learning computation. Because of this, GPUs are basically a requirement in most machine learning tasks.
289
+
290
+ The code containing these benchmarks can be found at:
291
+
292
+ tensor_stream-opencl/benchmark/benchmark.rb
293
+
226
294
  Limitations
227
295
  -----------
228
296
 
@@ -23,6 +23,7 @@ require "tensor_stream/operation"
23
23
  require "tensor_stream/placeholder"
24
24
  require "tensor_stream/control_flow"
25
25
  require "tensor_stream/dynamic_stitch"
26
+ require "tensor_stream/math/math_ops"
26
27
  require "tensor_stream/nn/nn_ops"
27
28
  require "tensor_stream/evaluator/evaluator"
28
29
  require "tensor_stream/graph_serializers/packer"
@@ -2,11 +2,17 @@ module TensorStream
2
2
  # Evaluator base module
3
3
  module Evaluator
4
4
  class OutputGroup
5
+ include Enumerable
6
+
5
7
  attr_accessor :outputs, :data_types
6
8
  def initialize(outputs = [], data_types = [])
7
9
  @outputs = outputs
8
10
  @data_types = data_types
9
11
  end
12
+
13
+ def each
14
+ @outputs.map { |output| yield output }
15
+ end
10
16
  end
11
17
 
12
18
  class UnsupportedOp < RuntimeError
@@ -131,7 +137,7 @@ module TensorStream
131
137
  time.to_i * (10**9) + time.nsec
132
138
  end
133
139
 
134
- instance_exec(execution_context, tensor, resolved_inputs, &op[:block]).tap do
140
+ instance_exec(execution_context, tensor, resolved_inputs, &op[:block]).tap do |result|
135
141
  if profile_enabled?
136
142
  time = Time.now
137
143
  end_time = time.to_i * (10**9) + time.nsec
@@ -222,11 +228,25 @@ module TensorStream
222
228
 
223
229
  def self.register_evaluator(klass, name, index = 0)
224
230
  @evaluators ||= {}
231
+ @storage_managers ||= {}
225
232
  @evaluators[name] = {name: name, class: klass, index: index}
233
+ @storage_managers[klass] = klass.get_storage_manager
226
234
  end
227
235
 
228
236
  def self.default_evaluators
229
237
  evaluators.values.sort { |v| v[:index] }.reverse.map { |v| v[:class] }
230
238
  end
239
+
240
+ def self.clear_storages(graph)
241
+ @storage_managers.values.each { |manager| manager.clear_variables(graph) }
242
+ end
243
+
244
+ def self.read_variable(graph, name)
245
+ @storage_managers.values.each do |manager|
246
+ return manager.read_value(graph, name) if manager.exists?(graph, name)
247
+ end
248
+
249
+ nil
250
+ end
231
251
  end
232
252
  end
@@ -1,5 +1,6 @@
1
1
  require "tensor_stream/evaluator/ruby_evaluator"
2
2
  require "tensor_stream/evaluator/buffer"
3
+ require "tensor_stream/evaluator/evaluator_utils"
3
4
 
4
5
  module TensorStream
5
6
  module Evaluator
@@ -0,0 +1,20 @@
1
+ module TensorStream
2
+ class EvaluatorUtils
3
+ extend TensorStream::StringHelper
4
+
5
+ def self.get_evaluator_classes(evaluators)
6
+ @evaluator_classes ||= if evaluators.is_a?(Array)
7
+ if evaluators.empty?
8
+ TensorStream::Evaluator.default_evaluators
9
+ else
10
+ evaluators.collect { |name| Object.const_get("TensorStream::Evaluator::#{camelize(name.to_s)}") }
11
+ end
12
+ elsif evaluators.nil?
13
+ TensorStream::Evaluator.default_evaluators
14
+ else
15
+ [Object.const_get("TensorStream::Evaluator::#{camelize(evaluators.to_s)}")]
16
+ end
17
+ @evaluator_classes
18
+ end
19
+ end
20
+ end
@@ -30,6 +30,16 @@ module TensorStream
30
30
  end
31
31
  end
32
32
 
33
+ def array_set!(input, value)
34
+ input.each_with_index do |element, index|
35
+ if element.is_a?(Array)
36
+ array_set(element, value)
37
+ else
38
+ input[index] = value[index]
39
+ end
40
+ end
41
+ end
42
+
33
43
  def truncate(input, target_shape)
34
44
  rank = get_rank(input)
35
45
  return input if rank.zero?
@@ -331,5 +341,55 @@ module TensorStream
331
341
  value.nil? ? arr : value
332
342
  end
333
343
  end
344
+
345
+ def strided_slice(value, slices = [])
346
+ current_slice = slices.dup
347
+ selection = current_slice.shift
348
+ return value if selection.nil?
349
+
350
+ b, e, stride = selection
351
+
352
+ b = value.size + b if b < 0
353
+ e = value.size + e + 1 if e < 0
354
+
355
+ indexes = if stride < 0
356
+ b.downto(e).select.with_index { |elem, index| (index % stride.abs) == 0 }
357
+ else
358
+ (b...e).step(stride)
359
+ end
360
+
361
+ indexes.map do |index|
362
+ strided_slice(value[index], current_slice)
363
+ end
364
+ end
365
+
366
+ def strided_slice_grad(value, grad, x, slices)
367
+ current_slice = slices.dup
368
+ selection = current_slice.shift
369
+ current_shape = x.shift
370
+
371
+ if selection.nil?
372
+ array_set!(value, grad)
373
+ end
374
+
375
+ b, e, stride = selection
376
+
377
+ b = value.size + b if b < 0
378
+ e = value.size + e + 1 if e < 0
379
+
380
+ indexes = if stride < 0
381
+ b.downto(e).select.with_index { |elem, index| (index % stride.abs) == 0 }
382
+ else
383
+ (b...e).step(stride)
384
+ end
385
+
386
+ indexes.each_with_index do |index, grad_index|
387
+ if (value[index].is_a?(Array))
388
+ strided_slice_grad(value[index], grad[grad_index], x.dup, current_slice.dup)
389
+ else
390
+ value[index] = grad[grad_index]
391
+ end
392
+ end
393
+ end
334
394
  end
335
395
  end
@@ -22,8 +22,9 @@ module TensorStream
22
22
  merged
23
23
  end
24
24
 
25
- register_op :gather do |_context, _tensor, inputs|
25
+ register_op :gather do |_context, tensor, inputs|
26
26
  params, indexes = inputs
27
+ raise "axis !=0 not supported" if tensor.options[:axis] != 0
27
28
  gather(params, indexes)
28
29
  end
29
30
 
@@ -216,7 +217,14 @@ module TensorStream
216
217
 
217
218
  register_op :range do |_context, _tensor, inputs|
218
219
  start, limit, delta = inputs
220
+
219
221
  raise " delta !=0 " if delta.zero?
222
+
223
+ if limit.zero?
224
+ limit = start
225
+ start = 0
226
+ end
227
+
220
228
  raise " Requires start <= limit when delta > 0" if (start > limit) && delta > 0
221
229
  raise " Requires start >= limit when delta < 0" if (start < limit) && delta < 0
222
230
 
@@ -399,6 +407,50 @@ module TensorStream
399
407
  end
400
408
  end
401
409
 
410
+ register_op :dynamic_partition do |context, tensor, inputs|
411
+ data, partitions = inputs
412
+ num_partitions = tensor.options[:num_partitions]
413
+ output_arr = Array.new(num_partitions) { [] }
414
+
415
+ partitions.each_with_index do |part, index|
416
+ output_arr[part] << data[index]
417
+ end
418
+ TensorStream::Evaluator::OutputGroup.new(output_arr, num_partitions.times.map { tensor.data_type })
419
+ end
420
+
421
+ register_op :gather_grad do |context, tensor, inputs|
422
+ grad, indexes, input_shape = inputs
423
+ output = Array.new(input_shape.reduce(:*)) { fp_type?(tensor.data_type) ? 0.0 : 0 }
424
+ indexes.each_with_index.map do |x, index|
425
+ output[x] += grad[index]
426
+ end
427
+ TensorShape.reshape(output, input_shape)
428
+ end
429
+
430
+ register_op :strided_slice do |_context, _tensor, inputs|
431
+ value, b_index, e_index, stride = inputs
432
+ slices = b_index.zip(e_index).zip(stride).map do |params|
433
+ selection, stride = params
434
+ s, e = selection
435
+ [s, e, stride]
436
+ end
437
+ strided_slice(value, slices)
438
+ end
439
+
440
+ register_op :strided_slice_grad do |_context, tensor, inputs|
441
+ x, b_index, e_index, stride, grad = inputs
442
+ slices = b_index.zip(e_index).zip(stride).map do |params|
443
+ selection, stride = params
444
+ s, e = selection
445
+ [s, e, stride]
446
+ end
447
+
448
+ target_val = generate_vector(x, generator: ->() { fp_type?(tensor.data_type) ? 0.0 : 0 })
449
+
450
+ strided_slice_grad(target_val, grad, x.dup, slices.dup)
451
+ target_val
452
+ end
453
+
402
454
  def merge_dynamic_stitch(merged, indexes, data, context)
403
455
  indexes.each_with_index do |ind, m|
404
456
  if ind.is_a?(Array)
@@ -1,5 +1,6 @@
1
1
  require "chunky_png"
2
2
 
3
+
3
4
  module TensorStream
4
5
  module ImagesOps
5
6
  def self.included(klass)
@@ -49,6 +50,31 @@ module TensorStream
49
50
  TensorShape.reshape(image_data, [image.height, image.width, channels])
50
51
  end
51
52
 
53
+ register_op :decode_jpg do |_context, tensor, inputs|
54
+ require "jpeg"
55
+
56
+ content = inputs[0]
57
+ channels = tensor.options[:channels]
58
+ channels = 3 if channels.zero?
59
+
60
+ image = Jpeg::Image.open_buffer(content)
61
+ source_channels = image.color_info == :gray ? 1 : 3
62
+
63
+ image_data = image.raw_data.map do |pixel|
64
+ if source_channels == channels
65
+ pixel
66
+ elsif source_channels = 1 && channels == 3
67
+ [pixel, pixel, pixel]
68
+ elsif source_channels = 3 && channels == 1
69
+ raise TensorStream::ValueError, "color to grayscale not supported for jpg"
70
+ end
71
+ end.flatten
72
+
73
+ image_data.map!(&:to_f) if fp_type?(tensor.data_type)
74
+
75
+ TensorShape.reshape(image_data, [image.height, image.width, channels])
76
+ end
77
+
52
78
  register_op :encode_png do |_context, tensor, inputs|
53
79
  image_data = inputs[0]
54
80
  height, width, channels = shape_eval(image_data)
@@ -37,6 +37,24 @@ module TensorStream
37
37
  end
38
38
  end
39
39
 
40
+ register_op :bias_add do |_context, _tensor, inputs|
41
+ value, bias = inputs
42
+ arr = value.flatten.each_slice(bias.size).map do |slice|
43
+ slice.each_with_index.map { |elem, index| elem + bias[index] }
44
+ end
45
+ TensorShape.reshape(arr, shape_eval(value))
46
+ end
47
+
48
+ register_op :bias_add_grad do |_context, _tensor, inputs|
49
+ received_grad = inputs[0]
50
+ bias_size = shape_eval(received_grad).last
51
+ grad_sum = Array.new(bias_size) { 0.0 }
52
+ received_grad.flatten.each_slice(bias_size) do |slice|
53
+ slice.each_with_index.map { |elem, index| grad_sum[index] += elem }
54
+ end
55
+ grad_sum
56
+ end
57
+
40
58
  register_op :sub, no_eval: true do |context, tensor, inputs|
41
59
  a, b = inputs
42
60
  call_vector_op(tensor, :sub, a, b, context) { |t, u| t - u }
@@ -111,6 +129,15 @@ module TensorStream
111
129
  call_op(inputs[0], context) { |t, _b| Math.sqrt(t) }
112
130
  end
113
131
 
132
+ register_op :rsqrt, no_eval: true do |context, _tensor, inputs|
133
+ call_op(inputs[0], context) { |t, _b| 1 / Math.sqrt(t) }
134
+ end
135
+
136
+ register_op :rsqrt_grad, no_eval: true do |context, tensor, inputs|
137
+ y, grad = inputs
138
+ call_vector_op(tensor, :rsqrt_grad, y, grad, context) { |_y, g| 0.5 * g * (_y ** 3) }
139
+ end
140
+
114
141
  register_op :floor, no_eval: true do |context, _tensor, inputs|
115
142
  call_op(inputs[0], context) { |t, _b| t.floor }
116
143
  end
@@ -135,6 +162,25 @@ module TensorStream
135
162
  call_op(inputs[0], context) { |t, _b| 1 - Math.tanh(t) * Math.tanh(t) }
136
163
  end
137
164
 
165
+ register_op :top_k do |context, tensor, inputs|
166
+ values, k = inputs
167
+ v_shape = shape_eval(values)
168
+
169
+ sorted = tensor.options[:sorted]
170
+ work_values = TensorShape.reshape(values, [-1, v_shape.last])
171
+ work_values.map! do |row|
172
+ last_k = row.map.with_index { |r, index| [r, index] }.sort! { |a,b| a[0] <=> b[0] }.last(k)
173
+ last_k.reverse! if sorted
174
+ last_k
175
+ end
176
+
177
+ top_k = work_values.map { |row| row.map { |r| r[0] } }
178
+ top_indices = work_values.map { |row| row.map { |r| r[1] } }
179
+ v_shape[-1] = k
180
+
181
+ TensorStream::Evaluator::OutputGroup.new([TensorShape.reshape(top_k, v_shape), TensorShape.reshape(top_indices, v_shape)], [tensor.inputs[0].data_type, :int32])
182
+ end
183
+
138
184
  register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
139
185
  axis = inputs[1] || 0
140
186
  rank = get_rank(inputs[0])
@@ -241,13 +287,22 @@ module TensorStream
241
287
  raise "#{tensor.inputs[0].name} rank must be greater than 1" if rank_a < 2
242
288
  raise "#{tensor.inputs[1].name} rank must be greater than 1" if rank_b < 2
243
289
 
244
- matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
245
- matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
246
-
247
290
  # check matrix dimensions
248
- raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size
291
+ if rank_a >= 3
292
+ matrix_a.zip(matrix_b).map do |m_a, m_b|
293
+ matmul(m_a, m_b, tensor)
294
+ end
295
+ else
296
+ matmul(matrix_a, matrix_b, tensor)
297
+ end
298
+ end
299
+
300
+ def matmul(m_a, m_b, tensor)
301
+ m_a = m_a.transpose if tensor.options[:transpose_a]
302
+ m_b = m_b.transpose if tensor.options[:transpose_b]
303
+ raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{m_a[0].size} != #{m_b.size}) #{shape_eval(m_a)} vs #{shape_eval(m_b)}" if m_a[0].size != m_b.size
249
304
 
250
- (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
305
+ (Matrix[*m_a] * Matrix[*m_b]).to_a
251
306
  end
252
307
 
253
308
  register_op %i[max maximum], noop: true do |context, tensor, inputs|