tensor_stream 1.0.6 → 1.0.7
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +10 -3
- data/lib/tensor_stream.rb +1 -0
- data/lib/tensor_stream/evaluator/base_evaluator.rb +6 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +60 -0
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +53 -1
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +42 -5
- data/lib/tensor_stream/generated_stub/ops.rb +61 -5
- data/lib/tensor_stream/helpers/tensor_mixins.rb +10 -1
- data/lib/tensor_stream/math/math_ops.rb +22 -0
- data/lib/tensor_stream/math_gradients.rb +15 -1
- data/lib/tensor_stream/nn/embedding_lookup.rb +114 -0
- data/lib/tensor_stream/nn/nn_ops.rb +3 -0
- data/lib/tensor_stream/op_maker.rb +15 -3
- data/lib/tensor_stream/ops.rb +12 -0
- data/lib/tensor_stream/ops/rsqrt.rb +11 -0
- data/lib/tensor_stream/ops/strided_slice.rb +24 -0
- data/lib/tensor_stream/ops/sum.rb +4 -2
- data/lib/tensor_stream/ops/top_k.rb +23 -0
- data/lib/tensor_stream/session.rb +3 -0
- data/lib/tensor_stream/tensor_shape.rb +32 -1
- data/lib/tensor_stream/train/saver.rb +2 -2
- data/lib/tensor_stream/utils.rb +8 -0
- data/lib/tensor_stream/utils/py_ports.rb +11 -0
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/word_embeddings/word_embedding_1.rb +192 -0
- data/samples/word_embeddings/word_embedding_2.rb +203 -0
- data/tensor_stream.gemspec +3 -0
- metadata +40 -4
- data/samples/neural_networks/lstm.rb +0 -22
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 667188a4a1ebc020c6c03c2b8530505b26e0ccb80885e03abba16c09665d6247
+  data.tar.gz: 06aebda0444eaa155d986324a0d4939bda66bbc6c6bc34762d83e46e1ba41fec
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 309595e1452075516bd003a3ca99b5537430cef182000521ea951d9e51459356e236a64ea98216aeb4ea507313541ad702be6d67e3536102632af7a3ed6ca6fa
+  data.tar.gz: 194174036aaee864f96cbbd7768f5b46a91a1592b0b9dd4a167c7071c063d92f2b62db26961a00ebace8eaece74caebe9232c04b4f741293b267dcdb215928e9
data/.gitignore
CHANGED
data/CHANGELOG.md
CHANGED
@@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [1.0.7] - 2019-04-08
+- [NEW] - Support for nn.embedding_lookup
+- [NEW] - l2_normalize, dynamic_partition
+- [NEW OP] - New Ops: rsqrt, top_k, strided_slice
+- [NEW] - Support for ranges in tensors (e.g. t[0...2] via strided slice)
+- [SAMPLES] - Add samples for handling word vectors
+
 ## [1.0.5] - 2019-03-20
 - [BUG FIX] - Fix not wrapping a stack op on some arrays. Should fix rnn sample
-
+
 ## [0.9.10] - 2019-01-02
 - [BUG FIX] - remove pry-byebug include (Thanks @samgooi4189)
 - Update Changelog for 0.9.9
@@ -22,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - [NEW OP] Convolutional networks - conv2d, conv2d_backprop_filter, conv2d_backprop_input
 - [IMAGE] Exposed image resampling options
 - [BUG FIX] fix argmin, argmax handling of NaN values
-
+
 ## [0.9.5] - 2018-11-05
 - [NEW OP] assert_equal, relu6
 - [TRAINING] learning_rate_decay, dropout
@@ -137,4 +144,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - reworked auto differentiation, fix a number of bugs related to auto differentiation, smaller derivative programs
 - alpha support for saving to pbtext format, added graphml generation
 - significant number of ops added
-- ops that support broadcasting now work better
+- ops that support broadcasting now work better
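The range-indexing entry above lowers Ruby ranges onto the new strided_slice op. A quick illustrative sketch (not part of the release; it assumes the standard TensorStream.session API from the README, and the output shown is expected rather than verified):

    require "tensor_stream"

    ts = TensorStream
    t = ts.constant([1, 2, 3, 4, 5])
    slice = t[0...2]       # lowered to a strided_slice op internally
    ts.session.run(slice)  # expected => [1, 2]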
data/lib/tensor_stream.rb
CHANGED
@@ -23,6 +23,7 @@ require "tensor_stream/operation"
 require "tensor_stream/placeholder"
 require "tensor_stream/control_flow"
 require "tensor_stream/dynamic_stitch"
+require "tensor_stream/math/math_ops"
 require "tensor_stream/nn/nn_ops"
 require "tensor_stream/evaluator/evaluator"
 require "tensor_stream/graph_serializers/packer"
data/lib/tensor_stream/evaluator/base_evaluator.rb
CHANGED
@@ -2,11 +2,17 @@ module TensorStream
   # Evaluator base module
   module Evaluator
     class OutputGroup
+      include Enumerable
+
       attr_accessor :outputs, :data_types
       def initialize(outputs = [], data_types = [])
         @outputs = outputs
         @data_types = data_types
       end
+
+      def each
+        @outputs.map { |output| yield output }
+      end
     end
 
     class UnsupportedOp < RuntimeError
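Since each yields every output, including Enumerable gives OutputGroup the full iterator toolkit (map, to_a, zip, and so on). A minimal sketch of what that enables:

    og = TensorStream::Evaluator::OutputGroup.new([[1, 2], [3]], [:int32, :int32])
    og.map(&:size)  # => [2, 1], map is supplied by Enumerable via the new each
    og.to_a         # => [[1, 2], [3]]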
data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb
CHANGED
@@ -30,6 +30,16 @@ module TensorStream
       end
     end
 
+    def array_set!(input, value)
+      input.each_with_index do |element, index|
+        if element.is_a?(Array)
+          array_set(element, value)
+        else
+          input[index] = value[index]
+        end
+      end
+    end
+
     def truncate(input, target_shape)
       rank = get_rank(input)
       return input if rank.zero?
@@ -331,5 +341,55 @@ module TensorStream
       value.nil? ? arr : value
     end
   end
+
+    def strided_slice(value, slices = [])
+      current_slice = slices.dup
+      selection = current_slice.shift
+      return value if selection.nil?
+
+      b, e, stride = selection
+
+      b = value.size + b if b < 0
+      e = value.size + e + 1 if e < 0
+
+      indexes = if stride < 0
+                  b.downto(e).select.with_index { |elem, index| (index % stride.abs) == 0 }
+                else
+                  (b...e).step(stride)
+                end
+
+      indexes.map do |index|
+        strided_slice(value[index], current_slice)
+      end
+    end
+
+    def strided_slice_grad(value, grad, x, slices)
+      current_slice = slices.dup
+      selection = current_slice.shift
+      current_shape = x.shift
+
+      if selection.nil?
+        array_set!(value, grad)
+      end
+
+      b, e, stride = selection
+
+      b = value.size + b if b < 0
+      e = value.size + e + 1 if e < 0
+
+      indexes = if stride < 0
+                  b.downto(e).select.with_index { |elem, index| (index % stride.abs) == 0 }
+                else
+                  (b...e).step(stride)
+                end
+
+      indexes.each_with_index do |index, grad_index|
+        if (value[index].is_a?(Array))
+          strided_slice_grad(value[index], grad[grad_index], x.dup, current_slice.dup)
+        else
+          value[index] = grad[grad_index]
+        end
+      end
+    end
   end
 end
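For intuition about the recursive helper above, here is a standalone re-implementation of its positive-stride path applied to a plain nested Ruby array (a sketch for illustration, not the gem's API; demo_strided_slice is a hypothetical name):

    # Slices one dimension per [begin, end, stride] triple, recursing inward.
    def demo_strided_slice(value, slices)
      selection = slices.first
      return value if selection.nil?

      b, e, stride = selection
      (b...e).step(stride).map { |i| demo_strided_slice(value[i], slices[1..-1]) }
    end

    demo_strided_slice([[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[0, 2, 1], [0, 2, 1]])
    # => [[1, 2], [4, 5]]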
data/lib/tensor_stream/evaluator/ruby/array_ops.rb
CHANGED
@@ -22,8 +22,9 @@ module TensorStream
       merged
     end
 
-    register_op :gather do |_context, _tensor, inputs|
+    register_op :gather do |_context, tensor, inputs|
       params, indexes = inputs
+      raise "axis !=0 not supported" if tensor.options[:axis] != 0
       gather(params, indexes)
     end
 
@@ -216,7 +217,14 @@ module TensorStream
 
     register_op :range do |_context, _tensor, inputs|
       start, limit, delta = inputs
+
       raise " delta !=0 " if delta.zero?
+
+      if limit.zero?
+        limit = start
+        start = 0
+      end
+
       raise " Requires start <= limit when delta > 0" if (start > limit) && delta > 0
       raise " Requires start >= limit when delta < 0" if (start < limit) && delta < 0
 
@@ -399,6 +407,50 @@ module TensorStream
       end
     end
 
+    register_op :dynamic_partition do |context, tensor, inputs|
+      data, partitions = inputs
+      num_partitions = tensor.options[:num_partitions]
+      output_arr = Array.new(num_partitions) { [] }
+
+      partitions.each_with_index do |part, index|
+        output_arr[part] << data[index]
+      end
+      TensorStream::Evaluator::OutputGroup.new(output_arr, num_partitions.times.map { tensor.data_type })
+    end
+
+    register_op :gather_grad do |context, tensor, inputs|
+      grad, indexes, input_shape = inputs
+      output = Array.new(input_shape.reduce(:*)) { fp_type?(tensor.data_type) ? 0.0 : 0 }
+      indexes.each_with_index.map do |x, index|
+        output[x] += grad[index]
+      end
+      TensorShape.reshape(output, input_shape)
+    end
+
+    register_op :strided_slice do |_context, _tensor, inputs|
+      value, b_index, e_index, stride = inputs
+      slices = b_index.zip(e_index).zip(stride).map do |params|
+        selection, stride = params
+        s, e = selection
+        [s, e, stride]
+      end
+      strided_slice(value, slices)
+    end
+
+    register_op :strided_slice_grad do |_context, tensor, inputs|
+      x, b_index, e_index, stride, grad = inputs
+      slices = b_index.zip(e_index).zip(stride).map do |params|
+        selection, stride = params
+        s, e = selection
+        [s, e, stride]
+      end
+
+      target_val = generate_vector(x, generator: ->() { fp_type?(tensor.data_type) ? 0.0 : 0 })
+
+      strided_slice_grad(target_val, grad, x.dup, slices.dup)
+      target_val
+    end
+
     def merge_dynamic_stitch(merged, indexes, data, context)
       indexes.each_with_index do |ind, m|
         if ind.is_a?(Array)
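Usage sketch for the new :dynamic_partition evaluator op, assuming the public wrapper added in ops.rb (not shown in this diff) mirrors TensorFlow's dynamic_partition(data, partitions, num_partitions) signature and returns one tensor per partition; outputs are expected values, not verified:

    ts = TensorStream
    data = ts.constant([10, 20, 30, 40])
    partitions = ts.constant([0, 1, 0, 1])
    p0, p1 = ts.dynamic_partition(data, partitions, 2)
    ts.session.run(p0, p1)  # expected => [10, 30] and [20, 40]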
data/lib/tensor_stream/evaluator/ruby/math_ops.rb
CHANGED
@@ -129,6 +129,15 @@ module TensorStream
       call_op(inputs[0], context) { |t, _b| Math.sqrt(t) }
     end
 
+    register_op :rsqrt, no_eval: true do |context, _tensor, inputs|
+      call_op(inputs[0], context) { |t, _b| 1 / Math.sqrt(t) }
+    end
+
+    register_op :rsqrt_grad, no_eval: true do |context, tensor, inputs|
+      y, grad = inputs
+      call_vector_op(tensor, :rsqrt_grad, y, grad, context) { |_y, g| 0.5 * g * (_y ** 3) }
+    end
+
     register_op :floor, no_eval: true do |context, _tensor, inputs|
       call_op(inputs[0], context) { |t, _b| t.floor }
     end
@@ -153,6 +162,25 @@ module TensorStream
       call_op(inputs[0], context) { |t, _b| 1 - Math.tanh(t) * Math.tanh(t) }
     end
 
+    register_op :top_k do |context, tensor, inputs|
+      values, k = inputs
+      v_shape = shape_eval(values)
+
+      sorted = tensor.options[:sorted]
+      work_values = TensorShape.reshape(values, [-1, v_shape.last])
+      work_values.map! do |row|
+        last_k = row.map.with_index { |r, index| [r, index] }.sort! { |a,b| a[0] <=> b[0] }.last(k) rescue binding.pry
+        last_k.reverse! if sorted
+        last_k
+      end
+
+      top_k = work_values.map { |row| row.map { |r| r[0] } }
+      top_indices = work_values.map { |row| row.map { |r| r[1] } }
+      v_shape[-1] = k
+
+      TensorStream::Evaluator::OutputGroup.new([TensorShape.reshape(top_k, v_shape), TensorShape.reshape(top_indices, v_shape)], [tensor.inputs[0].data_type, :int32])
+    end
+
     register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
       axis = inputs[1] || 0
       rank = get_rank(inputs[0])
@@ -259,13 +287,22 @@ module TensorStream
       raise "#{tensor.inputs[0].name} rank must be greater than 1" if rank_a < 2
       raise "#{tensor.inputs[1].name} rank must be greater than 1" if rank_b < 2
 
-      matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
-      matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
-
       # check matrix dimensions
-
+      if rank_a >= 3
+        matrix_a.zip(matrix_b).map do |m_a, m_b|
+          matmul(m_a, m_b, tensor)
+        end
+      else
+        matmul(matrix_a, matrix_b, tensor)
+      end
+    end
+
+    def matmul(m_a, m_b, tensor)
+      m_a = m_a.transpose if tensor.options[:transpose_a]
+      m_b = m_b.transpose if tensor.options[:transpose_b]
+      raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{m_a[0].size} != #{m_b.size}) #{shape_eval(m_a)} vs #{shape_eval(m_b)}" if m_a[0].size != m_b.size
 
-      (Matrix[*
+      (Matrix[*m_a] * Matrix[*m_b]).to_a
     end
 
     register_op %i[max maximum], noop: true do |context, tensor, inputs|
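The :top_k evaluator sorts each row by value and keeps the last k pairs, returning values and indices as an OutputGroup. A usage sketch via the public stub defined further below (outputs are expected values, not verified):

    ts = TensorStream
    values, indices = ts.top_k(ts.constant([[1.0, 5.0, 3.0]]), 2)
    sess = ts.session
    sess.run(values)   # expected => [[5.0, 3.0]]
    sess.run(indices)  # expected => [[1, 2]]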
data/lib/tensor_stream/generated_stub/ops.rb
CHANGED
@@ -536,6 +536,21 @@ module TensorStream
     end
 
 
+    ##
+    # Computes reciprocal of square root of x element-wise.
+    #
+    #
+    # @param input_a tensor X (of type FLOATING_POINT_TYPES)
+    #
+    # Options:
+    # @option name Optional name
+    # @return Tensor
+    def rsqrt(input_a, name: nil)
+      check_allowed_types(input_a, TensorStream::Ops::FLOATING_POINT_TYPES)
+      _op(:rsqrt, input_a, name: name)
+    end
+
+
     ##
     # This operation returns a 1-D integer tensor representing the shape of input
     #
@@ -615,6 +630,28 @@ module TensorStream
     end
 
 
+    ##
+    # Extracts a strided slice of a tensor
+    # this op extracts a slice of size `(end-begin)/stride`
+    # from the given `input_` tensor. Starting at the location specified by `begin`
+    # the slice continues by adding `stride` to the index until all dimensions are
+    # not less than `end`.
+    # Note that a stride can be negative, which causes a reverse slice.
+    #
+    #
+    # @param input A tensor
+    # @param _begin start index
+    # @param _end end index
+    # @param strides end index
+    #
+    # Options:
+    # @option name Optional name
+    # @return Tensor
+    def strided_slice(input, _begin, _end, strides = nil, name: nil)
+      _op(:strided_slice, input, _begin, _end, strides, name: name)
+    end
+
+
     ##
     # Returns x - y element-wise.
     #
@@ -642,18 +679,20 @@ module TensorStream
     #
     #
     # @param input_a tensor X
-    # @param
+    # @param axis_p tensor X (of type INTEGER_TYPES)
     #
     # Options:
+    # @option axis axis
     # @option name Optional name
     # @option keepdims If true, retains reduced dimensions with length 1. default (false)
     # @return Tensor
-    def sum(input_a,
-    check_allowed_types(
+    def sum(input_a, axis_p = nil, axis: nil, name: nil, keepdims: false)
+      check_allowed_types(axis_p, TensorStream::Ops::INTEGER_TYPES)
       input_a = TensorStream.convert_to_tensor(input_a)
       return input_a if input_a.shape.scalar?
-
-
+      axis_p = axis_p || axis
+      axis_p = cast_axis(input_a, axis_p)
+      _op(:sum, input_a, axis_p, name: name, keepdims: keepdims)
     end
 
     alias_method :reduce_sum, :sum
@@ -706,6 +745,23 @@ module TensorStream
     end
 
 
+    ##
+    # Finds values and indices of the `k` largest entries for the last dimension.
+    #
+    #
+    # @param input 1-D or higher `Tensor` with last dimension at least `k`.
+    # @param k 0-D `int32` `Tensor`. Number of top elements to look for along the last dimension (along each row for matrices)
+    #
+    # Options:
+    # @option sorted If true the resulting `k` elements will be sorted by the values in descending order. default (true)
+    # @option name Optional name
+    # @return Tensor
+    def top_k(input, k = 1, sorted: true, name: nil)
+      result = _op(:top_k, input, k, sorted: sorted, name: name)
+      [result[0], result[1]]
+    end
+
+
     ##
     # Creates a tensor with all elements set to zero
     #
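A combined sketch of the stubs added or changed above (rsqrt, strided_slice, and the new axis: keyword on sum); outputs are expected values, not verified:

    ts = TensorStream
    x = ts.constant([[1.0, 4.0], [9.0, 16.0]])

    ts.session.run(ts.rsqrt(x))
    # expected => [[1.0, 0.5], [0.333..., 0.25]]

    ts.session.run(ts.sum(x, axis: 1))
    # expected => [5.0, 25.0]

    ts.session.run(ts.strided_slice(x, [0], [1], [1]))
    # expected => [[1.0, 4.0]]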
data/lib/tensor_stream/helpers/tensor_mixins.rb
CHANGED
@@ -5,7 +5,16 @@ module TensorStream
     end
 
     def [](index)
-      _op(:index, self, index)
+      if index.is_a?(Range)
+        last = if index.end.nil?
+                 [TensorStream.shape(self)[0]]
+               else
+                 [index.max + 1]
+               end
+        _op(:strided_slice, self, [index.min], last, [1])
+      else
+        _op(:index, self, index)
+      end
     end
 
     def *(other)
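The Range branch converts a Ruby range into begin/end/stride triples for strided_slice; both inclusive and exclusive ranges work because the end index is computed as index.max + 1. A sketch with expected outputs:

    t = TensorStream.constant([10, 20, 30, 40])
    TensorStream.session.run(t[1...3])  # (1...3).max + 1 == 3, expected => [20, 30]
    TensorStream.session.run(t[0..1])   # (0..1).max + 1 == 2, expected => [10, 20]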
data/lib/tensor_stream/math/math_ops.rb
ADDED
@@ -0,0 +1,22 @@
+module TensorStream
+  # High level math functions
+  class Maths
+    extend TensorStream::OpHelper
+
+    module MathFunctions
+
+      ##
+      # Normalizes along dimension axis using an L2 norm.
+      def l2_normalize(x, axis: nil, epsilon: 1e-12, name: nil)
+        TensorStream.name_scope(name, "l2_normalize", values: [x]) do |name|
+          x = TensorStream.convert_to_tensor(x, name: "x")
+          square_sum = TensorStream.reduce_sum(TensorStream.square(x), axis, keepdims: true)
+          x_inv_norm = TensorStream.rsqrt(TensorStream.maximum(square_sum, epsilon))
+          TensorStream.multiply(x, x_inv_norm, name: name)
+        end
+      end
+    end
+
+    extend MathFunctions
+  end
+end
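Usage sketch for the new module (the file itself shows it is exposed as TensorStream::Maths, and it builds on the rsqrt op added above); the printed value is the expected unit-norm result, not verified here:

    x = TensorStream.constant([3.0, 4.0])
    normed = TensorStream::Maths.l2_normalize(x)
    TensorStream.session.run(normed)  # expected => [0.6, 0.8], since the norm of [3, 4] is 5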
data/lib/tensor_stream/math_gradients.rb
CHANGED
@@ -136,7 +136,7 @@ module TensorStream
         when :sparse_softmax_cross_entropy_with_logits
           output = node
           [_broadcast_mul(grad, output[1]), nil]
-
+
         when :zeros_like
           # non differentiable
           nil
         when :transpose
@@ -165,12 +165,26 @@ module TensorStream
           ts.stack(grad, axis: node.options[:axis])
         when :conv2d
           _Conv2DGrad(node, grad)
+        when :flow_dynamic_stitch
+          num_values = node.inputs.size / 2
+          indices_grad = [nil] * num_values
+
+          inputs = (0...num_values).map { |i| _int32(node, node.inputs[i]) }
+
+          values_grad = inputs.map { |inp| TensorStream.gather(grad, inp) }
+          indices_grad + values_grad
+        when :gather
+          [_op(:gather_grad, grad, node.inputs[1], TensorStream.shape(node.inputs[0])), nil]
         else
          TensorStream::OpMaker.gradient_op(self, node, grad)
         end
       end
     end
 
+    def self._int32(node, x)
+      (node.inputs[0].data_type == :int32 ? x : TensorStream.cast(x, :int32))
+    end
+
    def self._reshape_to_input(node, grad)
      ts.reshape(grad, ts.shape(node.inputs[0]))
    end
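The new :gather gradient routes the incoming gradient through the gather_grad op registered in the Ruby evaluator above, scattering it back into the shape of the source tensor. An illustrative sketch, assuming the ts.gradients API (expected output, not verified):

    ts = TensorStream
    x = ts.constant([1.0, 2.0, 3.0])
    y = ts.gather(x, [0, 2])
    grad = ts.gradients(y, [x]).first
    ts.session.run(grad)  # expected => [1.0, 0.0, 1.0]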