tensor_stream 0.8.1 → 0.8.5

Files changed (84)
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/CHANGELOG.md +8 -0
  4. data/README.md +12 -6
  5. data/lib/tensor_stream.rb +1 -0
  6. data/lib/tensor_stream/evaluator/base_evaluator.rb +1 -1
  7. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +282 -0
  8. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +61 -0
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +111 -0
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +48 -9
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +51 -0
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +20 -433
  13. data/lib/tensor_stream/images.rb +16 -0
  14. data/lib/tensor_stream/ops.rb +5 -1
  15. data/lib/tensor_stream/session.rb +15 -15
  16. data/lib/tensor_stream/tensor.rb +1 -1
  17. data/lib/tensor_stream/train/adadelta_optimizer.rb +52 -0
  18. data/lib/tensor_stream/train/adam_optimizer.rb +17 -2
  19. data/lib/tensor_stream/train/gradient_descent_optimizer.rb +7 -1
  20. data/lib/tensor_stream/trainer.rb +1 -0
  21. data/lib/tensor_stream/types.rb +4 -0
  22. data/lib/tensor_stream/utils.rb +4 -0
  23. data/lib/tensor_stream/variable_scope.rb +1 -0
  24. data/lib/tensor_stream/version.rb +1 -1
  25. data/samples/linear_regression.rb +4 -1
  26. data/samples/mnist_data.rb +64 -0
  27. data/samples/nearest_neighbor.rb +1 -2
  28. data/samples/raw_neural_net_sample.rb +1 -1
  29. data/tensor_stream.gemspec +1 -0
  30. metadata +23 -57
  31. data/lib/tensor_stream/evaluator/opencl/kernels/_bool_operand.cl +0 -45
  32. data/lib/tensor_stream/evaluator/opencl/kernels/_operand.cl +0 -45
  33. data/lib/tensor_stream/evaluator/opencl/kernels/abs.cl +0 -20
  34. data/lib/tensor_stream/evaluator/opencl/kernels/acos.cl +0 -8
  35. data/lib/tensor_stream/evaluator/opencl/kernels/add.cl +0 -3
  36. data/lib/tensor_stream/evaluator/opencl/kernels/apply_adam.cl +0 -23
  37. data/lib/tensor_stream/evaluator/opencl/kernels/apply_gradient.cl +0 -9
  38. data/lib/tensor_stream/evaluator/opencl/kernels/apply_momentum.cl +0 -16
  39. data/lib/tensor_stream/evaluator/opencl/kernels/argmax.cl +0 -8
  40. data/lib/tensor_stream/evaluator/opencl/kernels/argmin.cl +0 -8
  41. data/lib/tensor_stream/evaluator/opencl/kernels/asin.cl +0 -9
  42. data/lib/tensor_stream/evaluator/opencl/kernels/cast.cl +0 -10
  43. data/lib/tensor_stream/evaluator/opencl/kernels/ceil.cl +0 -8
  44. data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +0 -6
  45. data/lib/tensor_stream/evaluator/opencl/kernels/cos.cl +0 -8
  46. data/lib/tensor_stream/evaluator/opencl/kernels/div.cl.erb +0 -3
  47. data/lib/tensor_stream/evaluator/opencl/kernels/exp.cl +0 -8
  48. data/lib/tensor_stream/evaluator/opencl/kernels/floor.cl +0 -8
  49. data/lib/tensor_stream/evaluator/opencl/kernels/floor_div.cl +0 -48
  50. data/lib/tensor_stream/evaluator/opencl/kernels/floor_mod.cl +0 -3
  51. data/lib/tensor_stream/evaluator/opencl/kernels/gemm.cl +0 -32
  52. data/lib/tensor_stream/evaluator/opencl/kernels/log.cl +0 -8
  53. data/lib/tensor_stream/evaluator/opencl/kernels/log1p.cl +0 -8
  54. data/lib/tensor_stream/evaluator/opencl/kernels/log_softmax.cl +0 -26
  55. data/lib/tensor_stream/evaluator/opencl/kernels/max.cl +0 -46
  56. data/lib/tensor_stream/evaluator/opencl/kernels/min.cl +0 -46
  57. data/lib/tensor_stream/evaluator/opencl/kernels/mod.cl +0 -3
  58. data/lib/tensor_stream/evaluator/opencl/kernels/mul.cl +0 -3
  59. data/lib/tensor_stream/evaluator/opencl/kernels/negate.cl +0 -8
  60. data/lib/tensor_stream/evaluator/opencl/kernels/pack.cl +0 -24
  61. data/lib/tensor_stream/evaluator/opencl/kernels/pow.cl +0 -46
  62. data/lib/tensor_stream/evaluator/opencl/kernels/real_div.cl +0 -3
  63. data/lib/tensor_stream/evaluator/opencl/kernels/reciprocal.cl +0 -8
  64. data/lib/tensor_stream/evaluator/opencl/kernels/round.cl +0 -8
  65. data/lib/tensor_stream/evaluator/opencl/kernels/sigmoid.cl +0 -9
  66. data/lib/tensor_stream/evaluator/opencl/kernels/sigmoid_grad.cl +0 -55
  67. data/lib/tensor_stream/evaluator/opencl/kernels/sign.cl +0 -21
  68. data/lib/tensor_stream/evaluator/opencl/kernels/sin.cl +0 -9
  69. data/lib/tensor_stream/evaluator/opencl/kernels/softmax.cl +0 -26
  70. data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +0 -32
  71. data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross_grad.cl +0 -28
  72. data/lib/tensor_stream/evaluator/opencl/kernels/softmax_grad.cl +0 -46
  73. data/lib/tensor_stream/evaluator/opencl/kernels/sqrt.cl +0 -9
  74. data/lib/tensor_stream/evaluator/opencl/kernels/square.cl +0 -9
  75. data/lib/tensor_stream/evaluator/opencl/kernels/squared_difference.cl +0 -53
  76. data/lib/tensor_stream/evaluator/opencl/kernels/sub.cl +0 -3
  77. data/lib/tensor_stream/evaluator/opencl/kernels/tan.cl +0 -8
  78. data/lib/tensor_stream/evaluator/opencl/kernels/tanh.cl +0 -8
  79. data/lib/tensor_stream/evaluator/opencl/kernels/tanh_grad.cl +0 -7
  80. data/lib/tensor_stream/evaluator/opencl/kernels/where.cl +0 -8
  81. data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +0 -35
  82. data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +0 -5
  83. data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +0 -1230
  84. data/lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb +0 -95
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: a3c7d0a810a79ceedc0237379b105d7a9b598ba2513ef2d59ba3cec78d7b0da0
-  data.tar.gz: 7c0d90b27e548b72a86e88e7181f3d3b131a5fa3c6800000c743dd6e47d47b3b
+  metadata.gz: 1d6e9e482de719b709bfe554085718cb01a3bbfd089983eef51ace418b1b7d2d
+  data.tar.gz: 287e63d7a7269e7143ef1016677297c44235a4ac6afb03234f72bb2b6774d348
 SHA512:
-  metadata.gz: a7a8a5607883d868da3ceaa9f870ac5c6d6809d45992a9ae0e4f26d5648bda2c4a26ea3abec506abd92362dcf3b63935b3f6acbcc70b605600307313f8c69f49
-  data.tar.gz: db8917bee53f91e1017b5fdb8b9ece8b50a1096d249cee40eed8c335f597f32e5ce81056230b8ce1b9acff4df728497fa13e896212cbdaccc023dbdb7ed3591e
+  metadata.gz: 3ec2af3376d4cc671eb7bfc549290145d3e14167fe3b472f1c4dd1a05a6ab78ddeddc608d6c1c26a93ddc542ce605341172212e7d5d2f5beeb2542ef3790b03f
+  data.tar.gz: 961e74c0acac0179affca04974f7933fb1162a1a88b3bcdde6c57f679db6d5e23c2e92dd4fff8454bda7ba87dc2eaad7f9424458440348151ea4ac90258a8acc
data/.gitignore CHANGED
@@ -11,3 +11,4 @@
 
 # rspec failure tracking
 .rspec_status
+.DS_Store
data/CHANGELOG.md CHANGED
@@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.8.5] - 2018-09-06
+
+### Added
+- [TRAINING] Added AdadeltaOptimizer
+- [NEW OP] squeeze, encode_png, decode_png
+
+### Others
+- The OpenCL evaluator has been decoupled and is now in its own gem (tensor_stream-opencl)
 
 ## [0.8.1] - 2018-08-30
 - [TRAINING] Added AdamOptimizer
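
For context, a minimal sketch of how the new AdadeltaOptimizer might be wired into a training loop, assuming it exposes the same `minimize` interface as the gem's existing GradientDescentOptimizer and AdamOptimizer (the constructor argument and hyperparameters are assumptions modeled on TensorFlow's Adadelta):

```ruby
require 'tensor_stream'

ts = TensorStream
x = ts.placeholder(:float32)
y = ts.placeholder(:float32)
w = ts.variable(0.0, name: 'weight')
b = ts.variable(0.0, name: 'bias')

pred = w * x + b
cost = ts.reduce_mean(ts.square(pred - y))

# learning rate argument assumed; rho/epsilon left at their defaults
optimizer = TensorStream::Train::AdadeltaOptimizer.new(1.0).minimize(cost)

sess = ts.session
sess.run(ts.global_variables_initializer)
5.times { sess.run(optimizer, feed_dict: { x => 1.0, y => 2.0 }) }
```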
data/README.md CHANGED
@@ -22,10 +22,16 @@ The goal of this gem is to have a high performance machine learning and compute
 TensorStream comes with a pure ruby and OpenCL implementation out of the box. The pure ruby implementation
 is known to work with most ruby implementations including TruffleRuby, JRuby as well as jit enabled versions of mri (ruby-2.6.0).
 
-OpenCL is supported only on mri implementations of ruby. This can be enabled by including the OpenCL evaluator (Make sure you have OpenCL drivers installed correctly on your system):
+OpenCL is supported only on mri implementations of ruby. This can be enabled by adding the OpenCL evaluator gem (make sure you have OpenCL drivers installed correctly on your system):
+
+```Gemfile
+gem 'tensor_stream-opencl'
+```
+
+and then (without bundler)
 
 ```ruby
-require 'tensor_stream/evaluator/opencl/opencl_evaluator'
+require 'tensor_stream-opencl'
 ```
 
 OpenCL is basically a requirement for deep learning and image processing tasks as the ruby implementation is too slow even with jit speedups using latest ruby implementations.
@@ -199,14 +205,14 @@ Also OpenCL only supports ruby-mri at the moment.
 
 Also include the following gem in your project:
 
-```
-gem 'opencl_ruby_ffi'
+```Gemfile
+gem 'tensor_stream-opencl'
 ```
 
-To use the opencl evaluator instead of the ruby evaluator simply add require it.
+To use the opencl evaluator instead of the ruby evaluator, simply require it (if using Rails, this should be loaded automatically).
 
 ```ruby
-require 'tensor_stream/evaluator/opencl/opencl_evaluator'
+require 'tensor_stream-opencl'
 ```
 
 Adding the OpenCL evaluator should expose additional devices available to tensor_stream
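
As a quick sanity check after the gem swap, a minimal session sketch (standard tensor_stream API; with the OpenCL gem required, supported ops should be placed on an OpenCL device automatically):

```ruby
require 'tensor_stream'
require 'tensor_stream-opencl' # requires working OpenCL drivers

ts = TensorStream
a = ts.constant([[1.0, 2.0], [3.0, 4.0]])
b = ts.constant([[5.0, 6.0], [7.0, 8.0]])

sess = ts.session
sess.run(ts.matmul(a, b)) # => [[19.0, 22.0], [43.0, 50.0]]
```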
data/lib/tensor_stream.rb CHANGED
@@ -30,6 +30,7 @@ require 'tensor_stream/math_gradients'
 require "tensor_stream/debugging/debugging"
 require 'tensor_stream/utils'
 require 'tensor_stream/train/utils'
+require 'tensor_stream/images'
 require 'tensor_stream/trainer'
 
 # require 'tensor_stream/libraries/layers'
data/lib/tensor_stream/evaluator/base_evaluator.rb CHANGED
@@ -129,7 +129,7 @@ module TensorStream
     def global_eval(tensor, input, execution_context, op_options = {})
       return nil unless input
       return input unless input.is_a?(Tensor)
-
+      @context[:_cache][:placement][input.name] = @session.assign_evaluator(input) if @context[:_cache][:placement][input.name].nil?
       if object_id != @context[:_cache][:placement][input.name][1].object_id # tensor is on another device or evaluator
         cache_key = "#{tensor.graph.object_id}_#{input.name}:#{object_id}"
         return @context[:_cache][cache_key] if @context[:_cache].key?(cache_key)
data/lib/tensor_stream/evaluator/ruby/array_ops.rb ADDED
@@ -0,0 +1,282 @@
+module TensorStream
+  module ArrayOps
+    def ArrayOps.included(klass)
+      klass.class_eval do
+        register_op :slice do |context, tensor, inputs|
+          input = inputs[0]
+          start = inputs[1]
+          size = complete_eval(tensor.options[:size], context)
+          raise "start index and size not of the same shape #{start.size} != #{size.size}" if start.size != size.size
+          slice_tensor(input, start, size)
+        end
+
+        register_op %i[flow_dynamic_stitch dynamic_stitch] do |_context, _tensor, inputs|
+          indexes, data = inputs
+          merged = []
+          merge_dynamic_stitch(merged, indexes, data)
+          merged
+        end
+
+        register_op :gather do |_context, _tensor, inputs|
+          params, indexes = inputs
+          gather(params, indexes)
+        end
+
+        register_op %i[concat concat_v2] do |_context, tensor, inputs|
+          concat_array(inputs, tensor.options[:axis])
+        end
+
+        register_op :stack do |_context, tensor, inputs|
+          axis = tensor.options[:axis] || 0
+          shape = shape_eval(inputs[0])
+          rank = shape.size + 1
+          elem_size = shape.empty? ? 1 : shape.reduce(:*)
+          output_buffer = Array.new(inputs.size * elem_size) { 0 }
+          new_shape = [inputs.size]
+          shape.inject(new_shape) { |ns, s| ns << s }
+
+          divisors = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
+            a << s * a.last
+          end.reverse
+
+          axis = rank + axis if axis < 0
+          rotated_shape = Array.new(axis + 1) { new_shape.shift }
+          new_shape = rotated_shape.rotate! + new_shape
+
+          multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
+            a << s * a.last
+          end.reverse
+
+          inputs.each_with_index do |input, index|
+            raw_input = input.is_a?(Array) ? input.flatten : [input]
+            start = index * divisors.first
+
+            raw_input.each_with_index do |x, index2|
+              index_map = []
+              ptr = start + index2
+              divisors.each_with_object(index_map) do |div, a|
+                a << (ptr / div.to_f).floor
+                ptr = ptr % div
+              end
+
+              rotated_index = Array.new(axis + 1) { index_map.shift }
+              index_map = rotated_index.rotate! + index_map
+
+              ptr2 = 0
+              multipliers.each_with_index do |m, idx|
+                ptr2 += index_map[idx] * m
+              end
+
+              output_buffer[ptr2] = x
+            end
+          end
+
+          TensorShape.reshape(output_buffer, new_shape)
+        end
+
+        register_op :squeeze do |_context, tensor, inputs|
+          val = inputs[0]
+          shape = shape_eval(val)
+
+          axis = !tensor.options[:axis].is_a?(Array) ? [tensor.options[:axis]] : tensor.options[:axis]
+
+          if !axis.empty?
+            axis.each do |axis|
+              if shape[axis] == 1
+                shape[axis] = nil
+              else
+                raise TensorStream::ValueError, "unable to squeeze dimension that does not have a size of 1"
+              end
+            end
+          else
+            shape = shape.map { |s| s == 1 ? nil : s }
+          end
+
+          TensorShape.reshape(val.flatten, shape.compact)
+        end
+
+        register_op :expand_dims do |_context, _tensor, inputs|
+          val, axis = inputs
+          axis = axis.nil? ? 0 : axis
+
+          shape = shape_eval(val)
+          axis = -axis if axis == shape.size
+
+          new_shape = shape.dup.insert(axis, 1).compact
+
+          TensorShape.reshape([val].flatten, new_shape)
+        end
+
+        register_op :fill do |_context, _tensor, inputs|
+          shape = inputs[0]
+          value = inputs[1]
+
+          func = -> { value }
+
+          if shape.is_a?(Array) && shape.size.zero?
+            func.call
+          else
+            shape = [shape.to_i] unless shape.is_a?(Array)
+            generate_vector(shape, generator: func)
+          end
+        end
+
+        register_op :invert_permutation do |_context, _tensor, inputs|
+          input = inputs[0]
+          output = input.dup
+
+          unless input.nil?
+            input.size.times.each do |index|
+              output[input[index]] = index
+            end
+          end
+
+          output
+        end
+
+        register_op :index, no_eval: true do |_context, _tensor, inputs|
+          f = inputs[0]
+          index = inputs[1]
+          if f.is_a?(TensorStream::Evaluator::OutputGroup)
+            f.outputs[index]
+          else
+            f[index]
+          end
+        end
+
+        register_op :setdiff1d do |_context, tensor, inputs|
+          input, remove = inputs
+          idx = []
+          out = []
+          input.each_with_index do |x, index|
+            next if remove.include?(x)
+            out << x
+            idx << index
+          end
+          idx = idx.map { |i| Tensor.cast_dtype(i, tensor.options[:index_dtype]) } unless tensor.options[:index_dtype] == :int32
+          TensorStream::Evaluator::OutputGroup.new([out, idx], tensor.inputs.map(&:data_type))
+        end
+
+        register_op :size do |_context, tensor, inputs|
+          input = inputs[0]
+          Tensor.cast_dtype(input.flatten.size, tensor.options[:out_type])
+        end
+
+        register_op :range do |_context, _tensor, inputs|
+          start, limit, delta = inputs
+          raise " delta !=0 " if delta.zero?
+          raise " Requires start <= limit when delta > 0" if (start > limit) && delta > 0
+          raise " Requires start >= limit when delta < 0" if (start < limit) && delta < 0
+
+          cur_step = start
+          r = []
+          Kernel.loop do
+            break if start == limit
+            break if (start < limit) && (cur_step >= limit)
+            break if (start > limit) && (cur_step <= limit)
+            r << cur_step
+            cur_step += delta
+          end
+          r
+        end
+
+        register_op :eye do |_context, tensor, inputs|
+          rows, columns = inputs
+
+          Array.new(rows) do |i|
+            Array.new(columns) do |col|
+              if fp_type?(tensor.data_type)
+                i == col ? 1.0 : 0.0
+              else
+                i == col ? 1 : 0
+              end
+            end
+          end
+        end
+
+        register_op %i[zeros ones zeros_like ones_like] do |_context, tensor, inputs|
+          shape = if %i[zeros_like ones_like].include?(tensor.operation)
+                    shape_eval(inputs[0])
+                  else
+                    inputs[0] || tensor.shape.shape
+                  end
+
+          func = if %i[zeros zeros_like].include?(tensor.operation)
+                   -> { int_type?(tensor.data_type) ? 0 : 0.0 }
+                 else
+                   -> { int_type?(tensor.data_type) ? 1 : 1.0 }
+                 end
+
+          if shape.is_a?(Array) && shape.size.zero?
+            func.call
+          else
+            shape = [shape.to_i] unless shape.is_a?(Array)
+
+            cache_key = "#{tensor.operation}_#{shape}"
+            if @context[:_cache].key?(cache_key)
+              @context[:_cache][cache_key]
+            else
+              generate_vector(shape, generator: func).tap do |v|
+                @context[:_cache][cache_key] = v
+              end
+            end
+          end
+        end
+
+        register_op :truncate do |_context, _tensor, inputs|
+          truncate(inputs[0], inputs[1])
+        end
+
+        register_op :rank do |_context, _tensor, inputs|
+          get_rank(inputs[0])
+        end
+
+        register_op :reshape do |_context, _tensor, inputs|
+          arr, new_shape = inputs
+
+          arr = [arr] unless arr.is_a?(Array)
+
+          flat_arr = arr.flatten
+          if new_shape.size.zero? && flat_arr.size == 1
+            flat_arr[0]
+          else
+            new_shape = TensorShape.fix_inferred_elements(new_shape, flat_arr.size)
+            TensorShape.reshape(flat_arr, new_shape)
+          end
+        end
+
+        register_op :pad do |context, tensor, inputs|
+          p = complete_eval(tensor.options[:paddings], context)
+
+          arr_pad(inputs[0], p, tensor.data_type)
+        end
+
+        register_op :tile do |_context, _tensor, inputs|
+          input, multiples = inputs
+          rank = get_rank(input)
+          raise '1D or higher tensor required' if rank.zero?
+          raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
+
+          tile = tile_arr(input, 0, multiples)
+          tile.nil? ? [] : tile
+        end
+
+        register_op :cond, noop: true do |context, tensor, inputs|
+          pred = global_eval(tensor, tensor.options[:pred], context)
+
+          if all_true?(pred)
+            global_eval(tensor, inputs[0], context)
+          else
+            global_eval(tensor, inputs[1], context)
+          end
+        end
+
+        register_op %i[select where] do |context, tensor, inputs|
+          pred = complete_eval(tensor.options[:pred], context)
+          call_3way_vector_op(pred, inputs[0], inputs[1], context, ->(t, u, v) { t ? u : v })
+        end
+      end
+    end
+  end
+end
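
To illustrate the semantics of the new `:squeeze` op registered above, a hypothetical usage sketch (assuming `squeeze` is exposed on the top-level module like the other ops, with a TensorFlow-style `axis` option):

```ruby
require 'tensor_stream'

ts = TensorStream
t = ts.constant([[[1], [2], [3]]]) # shape [1, 3, 1]

sess = ts.session
sess.run(ts.squeeze(t))            # all size-1 dims dropped => [1, 2, 3]
sess.run(ts.squeeze(t, axis: [2])) # only axis 2 dropped => [[1, 2, 3]]
```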
data/lib/tensor_stream/evaluator/ruby/images_ops.rb ADDED
@@ -0,0 +1,61 @@
+require 'chunky_png'
+
+module TensorStream
+  module ImagesOps
+    def ImagesOps.included(klass)
+      klass.class_eval do
+        register_op :decode_png do |_context, tensor, inputs|
+          content = inputs[0]
+          channels = tensor.options[:channels]
+          channels = 4 if channels.zero?
+
+          image = ChunkyPNG::Image.from_blob(content)
+
+          image.grayscale! if channels == 1
+          image_data = image.pixels.collect do |pixel|
+            color_values = if channels == 4
+                             [ChunkyPNG::Color.r(pixel),
+                              ChunkyPNG::Color.g(pixel),
+                              ChunkyPNG::Color.b(pixel),
+                              ChunkyPNG::Color.a(pixel)]
+                           elsif channels == 3
+                             [ChunkyPNG::Color.r(pixel),
+                              ChunkyPNG::Color.g(pixel),
+                              ChunkyPNG::Color.b(pixel)]
+                           elsif channels == 1
+                             [ChunkyPNG::Color.r(pixel)]
+                           else
+                             raise "Invalid channel value #{channels}"
+                           end
+
+            if fp_type?(tensor.data_type)
+              color_values.map! { |v| v.to_f }
+            end
+
+            color_values
+          end
+          TensorShape.reshape(image_data.flatten, [image.height, image.width, channels])
+        end
+
+        register_op :encode_png do |_context, tensor, inputs|
+          image_data = inputs[0]
+          height, width, channels = shape_eval(image_data)
+
+          png = ChunkyPNG::Image.new(width, height)
+          image_data.each_with_index do |rows, h_index|
+            rows.each_with_index do |p_data, w_index|
+              if channels == 4
+                png[w_index, h_index] = ChunkyPNG::Color.rgba(p_data[0], p_data[1], p_data[2], p_data[3])
+              elsif channels == 3
+                png[w_index, h_index] = ChunkyPNG::Color.rgb(p_data[0], p_data[1], p_data[2])
+              elsif channels == 1
+                png[w_index, h_index] = ChunkyPNG::Color.rgb(p_data[0], p_data[0], p_data[0])
+              end
+            end
+          end
+          png.to_s
+        end
+      end
+    end
+  end
+end
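
A hypothetical round trip through the new PNG ops, assuming they are exposed through the new `tensor_stream/images` module the way TensorFlow exposes `tf.image` (the `ts.image` namespace and the file names here are illustrative):

```ruby
require 'tensor_stream'

ts = TensorStream
blob = File.binread('input.png') # any small PNG (placeholder file name)

decoded = ts.image.decode_png(blob, channels: 3) # [height, width, 3] tensor
encoded = ts.image.encode_png(decoded)           # back to a PNG blob string

File.binwrite('output.png', ts.session.run(encoded))
```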
data/lib/tensor_stream/evaluator/ruby/math_ops.rb CHANGED
@@ -138,6 +138,117 @@ module TensorStream
         register_op :tanh_grad, no_eval: true do |context, _tensor, inputs|
           call_op(:tanh_grad, inputs[0], context, ->(t, _b) { 1 - Math.tanh(t) * Math.tanh(t) })
         end
+
+        register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
+          axis = tensor.options[:axis] || 0
+          rank = get_rank(inputs[0])
+          raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
+          get_op_with_axis(inputs[0], axis, 0, tensor.data_type)
+        end
+
+        register_op(%i[argmin arg_min]) do |_context, tensor, inputs|
+          axis = tensor.options[:axis] || 0
+          rank = get_rank(inputs[0])
+          raise TensorStream::InvalidArgumentError, "Expected dimension in the range [#{-rank},#{rank}) but got #{axis}" if axis < -rank || axis >= rank
+          get_op_with_axis(inputs[0], axis, 0, tensor.data_type, ->(a, b) { a < b })
+        end
+
+        register_op :cumprod do |context, tensor, inputs|
+          x = inputs[0]
+          c = fp_type?(tensor.data_type) ? 1.0 : 1
+          reverse_option = tensor.options[:reverse]
+          exclusive = tensor.options[:exclusive]
+
+          func = lambda do |arr|
+            return c if arr.nil?
+            count = arr.size
+
+            arr = arr.reverse if reverse_option
+            arr = [1] + arr if exclusive
+
+            start_prod = arr[0]
+            mapped = arr[1...count].map do |v|
+              start_prod = vector_op(start_prod, v, ->(a, b) { a * b })
+            end
+
+            arr = [arr[0]] + mapped
+            reverse_option ? arr.reverse : arr
+          end
+          reduction(context, tensor, func)
+        end
+
+        register_op :sum, noop: true do |context, tensor, _inputs|
+          func = lambda do |arr|
+            reduced_val = arr[0]
+            arr[1..arr.size].each do |v|
+              reduced_val = vector_op(reduced_val, v, ->(t, u) { t + u })
+            end
+            reduced_val
+          end
+
+          reduction(context, tensor, func)
+        end
+
+        register_op :prod, noop: true do |context, tensor, _inputs|
+          c = fp_type?(tensor.data_type) ? 1.0 : 1
+          func = lambda do |arr|
+            return c if arr.nil?
+
+            reduced_val = arr[0]
+            arr[1..arr.size].each do |v|
+              reduced_val = vector_op(reduced_val, v, ->(a, b) { a * b })
+            end
+            reduced_val
+          end
+
+          reduction(context, tensor, func)
+        end
+
+        register_op :sigmoid_grad, no_eval: true do |context, tensor, inputs|
+          a, b = inputs
+          call_vector_op(tensor, :sigmoid_grad, a, b, context, ->(t, u) { u * sigmoid(t) * (1 - sigmoid(t)) })
+        end
+
+        register_op :mean, noop: true do |context, tensor, _inputs|
+          c = fp_type?(tensor.data_type) ? 0.0 : 0
+          func = lambda do |arr|
+            return c if arr.nil?
+
+            reduced_val = arr[0]
+            arr[1..arr.size].each do |v|
+              reduced_val = vector_op(reduced_val, v, ->(a, b) { a + b })
+            end
+
+            vector_op(reduced_val, nil, ->(a, _b) { a / arr.size })
+          end
+
+          reduction(context, tensor, func)
+        end
+
+        register_op :mat_mul do |_context, tensor, inputs|
+          matrix_a, matrix_b = inputs
+          rank_a = get_rank(matrix_a)
+          rank_b = get_rank(matrix_b)
+          raise "#{tensor.inputs[0].name} rank must be greater than 1" if rank_a < 2
+          raise "#{tensor.inputs[1].name} rank must be greater than 1" if rank_b < 2
+
+          matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
+          matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
+
+          # check matrix dimensions
+          raise "incompatible shape sizes for matrix multiplication (#{matrix_a[0].size} != #{matrix_b.size}) #{shape_eval(matrix_a)} vs #{shape_eval(matrix_b)}" if matrix_a[0].size != matrix_b.size
+
+          (Matrix[*matrix_a] * Matrix[*matrix_b]).to_a
+        end
+
+        register_op %i[max maximum], noop: true do |context, tensor, inputs|
+          call_vector_op(tensor, :max, inputs[0], inputs[1], context, ->(t, u) { [t, u].max })
+        end
+
+        register_op %i[min minimum], noop: true do |context, tensor, inputs|
+          call_vector_op(tensor, :min, inputs[0], inputs[1], context, ->(t, u) { [t, u].min })
+        end
       end
     end
   end
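
The `:cumprod` op above folds TensorFlow's `exclusive` and `reverse` options into a single pass; here is a plain-Ruby sketch of the same semantics on a flat array (illustrative only, the evaluator version also handles nested arrays via `vector_op`):

```ruby
def cumprod(arr, exclusive: false, reverse: false)
  arr = arr.reverse if reverse
  arr = [1] + arr[0...-1] if exclusive # shift right, seed with 1
  acc = 1
  out = arr.map { |v| acc *= v }       # running product
  reverse ? out.reverse : out
end

cumprod([1, 2, 3, 4])                  # => [1, 2, 6, 24]
cumprod([1, 2, 3, 4], exclusive: true) # => [1, 1, 2, 6]
cumprod([1, 2, 3, 4], reverse: true)   # => [24, 24, 12, 4]
```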