tensor_stream 1.0.6 → 1.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +10 -3
- data/lib/tensor_stream.rb +1 -0
- data/lib/tensor_stream/evaluator/base_evaluator.rb +6 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +60 -0
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +53 -1
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +42 -5
- data/lib/tensor_stream/generated_stub/ops.rb +61 -5
- data/lib/tensor_stream/helpers/tensor_mixins.rb +10 -1
- data/lib/tensor_stream/math/math_ops.rb +22 -0
- data/lib/tensor_stream/math_gradients.rb +15 -1
- data/lib/tensor_stream/nn/embedding_lookup.rb +114 -0
- data/lib/tensor_stream/nn/nn_ops.rb +3 -0
- data/lib/tensor_stream/op_maker.rb +15 -3
- data/lib/tensor_stream/ops.rb +12 -0
- data/lib/tensor_stream/ops/rsqrt.rb +11 -0
- data/lib/tensor_stream/ops/strided_slice.rb +24 -0
- data/lib/tensor_stream/ops/sum.rb +4 -2
- data/lib/tensor_stream/ops/top_k.rb +23 -0
- data/lib/tensor_stream/session.rb +3 -0
- data/lib/tensor_stream/tensor_shape.rb +32 -1
- data/lib/tensor_stream/train/saver.rb +2 -2
- data/lib/tensor_stream/utils.rb +8 -0
- data/lib/tensor_stream/utils/py_ports.rb +11 -0
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/word_embeddings/word_embedding_1.rb +192 -0
- data/samples/word_embeddings/word_embedding_2.rb +203 -0
- data/tensor_stream.gemspec +3 -0
- metadata +40 -4
- data/samples/neural_networks/lstm.rb +0 -22
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 667188a4a1ebc020c6c03c2b8530505b26e0ccb80885e03abba16c09665d6247
|
4
|
+
data.tar.gz: 06aebda0444eaa155d986324a0d4939bda66bbc6c6bc34762d83e46e1ba41fec
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 309595e1452075516bd003a3ca99b5537430cef182000521ea951d9e51459356e236a64ea98216aeb4ea507313541ad702be6d67e3536102632af7a3ed6ca6fa
|
7
|
+
data.tar.gz: 194174036aaee864f96cbbd7768f5b46a91a1592b0b9dd4a167c7071c063d92f2b62db26961a00ebace8eaece74caebe9232c04b4f741293b267dcdb215928e9
|
data/.gitignore
CHANGED
data/CHANGELOG.md
CHANGED
@@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file.
|
|
4
4
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
|
5
5
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
6
6
|
|
7
|
+
## [1.0.7] - 2019-04-08
|
8
|
+
- [NEW] - Support for nn.embedding_lookup
|
9
|
+
- [NEW] - l2_normalize, dynamic_partition
|
10
|
+
- [NEW OP] - New Ops: rsqrt, top_k, strided_slice
|
11
|
+
- [NEW] - Support for ranges in tensors (e.g. t[0...2] via strided slice)
|
12
|
+
- [SAMPLES] - Add samples for handling word vectors
|
13
|
+
|
7
14
|
## [1.0.5] - 2019-03-20
|
8
15
|
- [BUG FIX] - Fix not wrapping a stack op on some arrays. Should fix rnn sample
|
9
|
-
|
16
|
+
|
10
17
|
## [0.9.10] - 2019-01-02
|
11
18
|
- [BUG FIX] - remove pry-byebug include (Thanks @samgooi4189)
|
12
19
|
- Update Changelog for 0.9.9
|
@@ -22,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|
22
29
|
- [NEW OP] Convolutional networks - conv2d, conv2d_backprop_filter, conv2d_backprop_input
|
23
30
|
- [IMAGE] Exposed image resampling options
|
24
31
|
- [BUG FIX] fix argmin, argmax handling of NaN values
|
25
|
-
|
32
|
+
|
26
33
|
## [0.9.5] - 2018-11-05
|
27
34
|
- [NEW OP] assert_equal, relu6
|
28
35
|
- [TRAINING] learning_rate_decay, dropout
|
@@ -137,4 +144,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|
137
144
|
- reworked auto differentiation, fix a number of bugs related to auto differentiation, smaller derivative programs
|
138
145
|
- alpha support for saving to pbtext format, added graphml generation
|
139
146
|
- significant number of ops added
|
140
|
-
- ops that support broadcasting now work better
|
147
|
+
- ops that support broadcasting now work better
|
data/lib/tensor_stream.rb
CHANGED
@@ -23,6 +23,7 @@ require "tensor_stream/operation"
|
|
23
23
|
require "tensor_stream/placeholder"
|
24
24
|
require "tensor_stream/control_flow"
|
25
25
|
require "tensor_stream/dynamic_stitch"
|
26
|
+
require "tensor_stream/math/math_ops"
|
26
27
|
require "tensor_stream/nn/nn_ops"
|
27
28
|
require "tensor_stream/evaluator/evaluator"
|
28
29
|
require "tensor_stream/graph_serializers/packer"
|
@@ -2,11 +2,17 @@ module TensorStream
|
|
2
2
|
# Evaluator base module
|
3
3
|
module Evaluator
|
4
4
|
# Groups the values produced by an op that emits several outputs at once.
# +outputs+ holds the values and +data_types+ holds one data type per value.
class OutputGroup
  include Enumerable

  attr_accessor :outputs, :data_types

  # @param outputs [Array] the individual output values
  # @param data_types [Array] data type for each entry of +outputs+
  def initialize(outputs = [], data_types = [])
    @outputs = outputs
    @data_types = data_types
  end

  # Yields each output in order; this is what backs the Enumerable mixin.
  #
  # @return [Enumerator] when no block is given
  # @return [Array] the values returned by the block otherwise
  def each
    # Without this guard a block-less call would fail on `yield`.
    return enum_for(:each) unless block_given?

    @outputs.map { |output| yield output }
  end
end
|
11
17
|
|
12
18
|
class UnsupportedOp < RuntimeError
|
@@ -30,6 +30,16 @@ module TensorStream
|
|
30
30
|
end
|
31
31
|
end
|
32
32
|
|
33
|
+
# Overwrites the contents of +input+ in place with the values of +value+,
# walking nested arrays in lockstep.
#
# @param input [Array] (possibly nested) array to overwrite
# @param value [Array] array with the same nesting as +input+
# @return [Array] +input+
def array_set!(input, value)
  input.each_with_index do |element, index|
    if element.is_a?(Array)
      # Descend into the sub-array of +value+ that lines up with +element+.
      array_set!(element, value[index])
    else
      input[index] = value[index]
    end
  end
end
|
42
|
+
|
33
43
|
def truncate(input, target_shape)
|
34
44
|
rank = get_rank(input)
|
35
45
|
return input if rank.zero?
|
@@ -331,5 +341,55 @@ module TensorStream
|
|
331
341
|
value.nil? ? arr : value
|
332
342
|
end
|
333
343
|
end
|
344
|
+
|
345
|
+
# Recursively extracts a strided slice from a (possibly nested) array.
#
# @param value the array (or leaf element) to slice
# @param slices [Array<Array>] one [begin, end, stride] triple per dimension,
#   outermost dimension first
# @return the sliced array, or +value+ itself once no triples remain
def strided_slice(value, slices = [])
  head, *remaining = slices
  return value if head.nil?

  first, last, step = head

  # Negative bounds count from the end of this dimension; a negative end is
  # shifted one extra position, so -1 resolves to value.size.
  first += value.size if first < 0
  last += value.size + 1 if last < 0

  positions =
    if step < 0
      # Walk down from first to last (inclusive) in hops of |step|.
      first.step(last, step).to_a
    else
      (first...last).step(step)
    end

  positions.map { |position| strided_slice(value[position], remaining) }
end
|
365
|
+
|
366
|
+
# Scatters +grad+ back into +value+ at the positions a strided slice selected,
# mutating +value+ in place.
#
# @param value [Array] (possibly nested) target array that receives the gradient
# @param grad [Array] gradient values laid out like the slice result
# @param x [Array] shape of the sliced input; kept for interface compatibility,
#   it is passed along but not otherwise consulted
# @param slices [Array<Array>] one [begin, end, stride] triple per dimension,
#   outermost dimension first
# @return [Array] +value+
def strided_slice_grad(value, grad, x, slices)
  selection, *remaining = slices

  # No triple left for this dimension: the whole sub-array was selected, so copy
  # the gradient across and stop. Falling through would destructure a nil
  # selection and fail on the bound checks below.
  if selection.nil?
    array_set!(value, grad)
    return value
  end

  b, e, stride = selection

  # Negative bounds count from the end of this dimension (same convention as
  # the forward strided_slice helper).
  b = value.size + b if b < 0
  e = value.size + e + 1 if e < 0

  indexes = if stride < 0
    b.downto(e).select.with_index { |_elem, position| (position % stride.abs) == 0 }
  else
    (b...e).step(stride)
  end

  indexes.each_with_index do |index, grad_index|
    if value[index].is_a?(Array)
      strided_slice_grad(value[index], grad[grad_index], x, remaining)
    else
      value[index] = grad[grad_index]
    end
  end

  value
end
|
334
394
|
end
|
335
395
|
end
|
@@ -22,8 +22,9 @@ module TensorStream
|
|
22
22
|
merged
|
23
23
|
end
|
24
24
|
|
25
|
-
register_op :gather do |_context,
|
25
|
+
# Gathers entries of +params+ selected by +indexes+; only axis 0 is handled.
register_op :gather do |_context, tensor, inputs|
  params, indexes = inputs
  axis = tensor.options[:axis]
  raise "axis !=0 not supported" unless axis == 0

  gather(params, indexes)
end
|
29
30
|
|
@@ -216,7 +217,14 @@ module TensorStream
|
|
216
217
|
|
217
218
|
register_op :range do |_context, _tensor, inputs|
|
218
219
|
start, limit, delta = inputs
|
220
|
+
|
219
221
|
raise " delta !=0 " if delta.zero?
|
222
|
+
|
223
|
+
if limit.zero?
|
224
|
+
limit = start
|
225
|
+
start = 0
|
226
|
+
end
|
227
|
+
|
220
228
|
raise " Requires start <= limit when delta > 0" if (start > limit) && delta > 0
|
221
229
|
raise " Requires start >= limit when delta < 0" if (start < limit) && delta < 0
|
222
230
|
|
@@ -399,6 +407,50 @@ module TensorStream
|
|
399
407
|
end
|
400
408
|
end
|
401
409
|
|
410
|
+
# Distributes the entries of +data+ over :num_partitions buckets; entry i goes
# to the bucket named by partitions[i]. The buckets are returned as an
# OutputGroup carrying one data type per bucket.
register_op :dynamic_partition do |context, tensor, inputs|
  data, partitions = inputs
  bucket_count = tensor.options[:num_partitions]
  buckets = Array.new(bucket_count) { [] }

  partitions.each_with_index do |bucket, position|
    buckets[bucket] << data[position]
  end

  bucket_types = Array.new(bucket_count) { tensor.data_type }
  TensorStream::Evaluator::OutputGroup.new(buckets, bucket_types)
end
|
420
|
+
|
421
|
+
# Gradient of gather: accumulates each gradient entry into the slot named by
# its index, then reshapes the accumulator to +input_shape+.
# NOTE(review): the accumulator is a flat array addressed directly by the
# gathered index, which looks like it only fits rank-1 params — confirm.
register_op :gather_grad do |context, tensor, inputs|
  grad, indexes, input_shape = inputs
  element_count = input_shape.reduce(:*)
  accumulator = Array.new(element_count) { fp_type?(tensor.data_type) ? 0.0 : 0 }

  indexes.each_with_index do |slot, position|
    accumulator[slot] += grad[position]
  end

  TensorShape.reshape(accumulator, input_shape)
end
|
429
|
+
|
430
|
+
# Evaluates a strided slice: pairs up begin/end/stride per dimension and hands
# the resulting [begin, end, stride] triples to the strided_slice helper.
register_op :strided_slice do |_context, _tensor, inputs|
  value, b_index, e_index, stride = inputs
  triples = b_index.zip(e_index, stride)
  strided_slice(value, triples)
end
|
439
|
+
|
440
|
+
# Gradient of strided_slice: builds a zero-filled array from shape +x+ and
# writes +grad+ into the positions the slice selected.
register_op :strided_slice_grad do |_context, tensor, inputs|
  x, b_index, e_index, stride, grad = inputs
  triples = b_index.zip(e_index, stride)

  zero_filler = ->() { fp_type?(tensor.data_type) ? 0.0 : 0 }
  target_val = generate_vector(x, generator: zero_filler)

  # The helper mutates target_val in place; copies keep x and triples intact.
  strided_slice_grad(target_val, grad, x.dup, triples.dup)
  target_val
end
|
453
|
+
|
402
454
|
def merge_dynamic_stitch(merged, indexes, data, context)
|
403
455
|
indexes.each_with_index do |ind, m|
|
404
456
|
if ind.is_a?(Array)
|
@@ -129,6 +129,15 @@ module TensorStream
|
|
129
129
|
call_op(inputs[0], context) { |t, _b| Math.sqrt(t) }
|
130
130
|
end
|
131
131
|
|
132
|
+
# Element-wise reciprocal of the square root: 1 / sqrt(t).
register_op :rsqrt, no_eval: true do |context, _tensor, inputs|
  operand = inputs[0]
  call_op(operand, context) do |t, _b|
    1 / Math.sqrt(t)
  end
end
|
135
|
+
|
136
|
+
# Element-wise 0.5 * g * y**3 over the rsqrt output +y+ and incoming gradient.
# NOTE(review): the derivative of x**-0.5 is -0.5 * y**3; confirm the sign is
# applied by the caller, otherwise this factor should be negative.
register_op :rsqrt_grad, no_eval: true do |context, tensor, inputs|
  y, grad = inputs
  call_vector_op(tensor, :rsqrt_grad, y, grad, context) do |y_elem, g|
    0.5 * g * (y_elem ** 3)
  end
end
|
140
|
+
|
132
141
|
register_op :floor, no_eval: true do |context, _tensor, inputs|
|
133
142
|
call_op(inputs[0], context) { |t, _b| t.floor }
|
134
143
|
end
|
@@ -153,6 +162,25 @@ module TensorStream
|
|
153
162
|
call_op(inputs[0], context) { |t, _b| 1 - Math.tanh(t) * Math.tanh(t) }
|
154
163
|
end
|
155
164
|
|
165
|
+
# Finds the k largest entries along the last dimension of +values+.
# Returns an OutputGroup of [top values, their indices within the row]; the
# indices are tagged :int32 and the values keep the input's data type.
register_op :top_k do |context, tensor, inputs|
  values, k = inputs
  v_shape = shape_eval(values)

  sorted = tensor.options[:sorted]
  # Flatten every leading dimension so each row is one last-dimension vector.
  work_values = TensorShape.reshape(values, [-1, v_shape.last])
  work_values.map! do |row|
    # Pair each entry with its position, sort ascending and keep the k largest.
    # Errors are allowed to propagate instead of dropping into a debugger.
    last_k = row.map.with_index { |r, index| [r, index] }.sort { |a, b| a[0] <=> b[0] }.last(k)
    # Ascending order is flipped to descending when :sorted is requested.
    last_k.reverse! if sorted
    last_k
  end

  top_k = work_values.map { |row| row.map { |r| r[0] } }
  top_indices = work_values.map { |row| row.map { |r| r[1] } }
  v_shape[-1] = k

  TensorStream::Evaluator::OutputGroup.new([TensorShape.reshape(top_k, v_shape), TensorShape.reshape(top_indices, v_shape)], [tensor.inputs[0].data_type, :int32])
end
|
183
|
+
|
156
184
|
register_op(%i[argmax arg_max]) do |_context, tensor, inputs|
|
157
185
|
axis = inputs[1] || 0
|
158
186
|
rank = get_rank(inputs[0])
|
@@ -259,13 +287,22 @@ module TensorStream
|
|
259
287
|
raise "#{tensor.inputs[0].name} rank must be greater than 1" if rank_a < 2
|
260
288
|
raise "#{tensor.inputs[1].name} rank must be greater than 1" if rank_b < 2
|
261
289
|
|
262
|
-
matrix_a = matrix_a.transpose if tensor.options[:transpose_a]
|
263
|
-
matrix_b = matrix_b.transpose if tensor.options[:transpose_b]
|
264
|
-
|
265
290
|
# check matrix dimensions
|
266
|
-
|
291
|
+
if rank_a >= 3
|
292
|
+
matrix_a.zip(matrix_b).map do |m_a, m_b|
|
293
|
+
matmul(m_a, m_b, tensor)
|
294
|
+
end
|
295
|
+
else
|
296
|
+
matmul(matrix_a, matrix_b, tensor)
|
297
|
+
end
|
298
|
+
end
|
299
|
+
|
300
|
+
# Multiplies two rank-2 arrays, honouring the tensor's :transpose_a and
# :transpose_b options.
#
# @param m_a [Array<Array>] left operand
# @param m_b [Array<Array>] right operand
# @param tensor the op whose options carry the transpose flags
# @return [Array<Array>] the matrix product
# @raise [TensorStream::ValueError] when the inner dimensions do not match
def matmul(m_a, m_b, tensor)
  lhs = tensor.options[:transpose_a] ? m_a.transpose : m_a
  rhs = tensor.options[:transpose_b] ? m_b.transpose : m_b

  unless lhs[0].size == rhs.size
    raise TensorStream::ValueError, "incompatible shape sizes for matrix multiplication (#{lhs[0].size} != #{rhs.size}) #{shape_eval(lhs)} vs #{shape_eval(rhs)}"
  end

  product = Matrix[*lhs] * Matrix[*rhs]
  product.to_a
end
|
270
307
|
|
271
308
|
register_op %i[max maximum], noop: true do |context, tensor, inputs|
|
@@ -536,6 +536,21 @@ module TensorStream
|
|
536
536
|
end
|
537
537
|
|
538
538
|
|
539
|
+
##
|
540
|
+
# Computes reciprocal of square root of x element-wise.
|
541
|
+
#
|
542
|
+
#
|
543
|
+
# @param input_a tensor X (of type FLOATING_POINT_TYPES)
|
544
|
+
#
|
545
|
+
# Options:
|
546
|
+
# @option name Optional name
|
547
|
+
# @return Tensor
|
548
|
+
def rsqrt(input_a, name: nil)
  # Only floating point inputs are accepted.
  allowed_types = TensorStream::Ops::FLOATING_POINT_TYPES
  check_allowed_types(input_a, allowed_types)

  _op(:rsqrt, input_a, name: name)
end
|
552
|
+
|
553
|
+
|
539
554
|
##
|
540
555
|
# This operation returns a 1-D integer tensor representing the shape of input
|
541
556
|
#
|
@@ -615,6 +630,28 @@ module TensorStream
|
|
615
630
|
end
|
616
631
|
|
617
632
|
|
633
|
+
##
|
634
|
+
# Extracts a strided slice of a tensor
|
635
|
+
# this op extracts a slice of size `(end-begin)/stride`
|
636
|
+
# from the given `input_` tensor. Starting at the location specified by `begin`
|
637
|
+
# the slice continues by adding `stride` to the index until all dimensions are
|
638
|
+
# not less than `end`.
|
639
|
+
# Note that a stride can be negative, which causes a reverse slice.
|
640
|
+
#
|
641
|
+
#
|
642
|
+
# @param input A tensor
|
643
|
+
# @param _begin start index
|
644
|
+
# @param _end end index
|
645
|
+
# @param strides stride (step) for each dimension
|
646
|
+
#
|
647
|
+
# Options:
|
648
|
+
# @option name Optional name
|
649
|
+
# @return Tensor
|
650
|
+
def strided_slice(input, _begin, _end, strides = nil, name: nil)
  # Thin wrapper: forwards all four operands to the :strided_slice op.
  operands = [input, _begin, _end, strides]
  _op(:strided_slice, *operands, name: name)
end
|
653
|
+
|
654
|
+
|
618
655
|
##
|
619
656
|
# Returns x - y element-wise.
|
620
657
|
#
|
@@ -642,18 +679,20 @@ module TensorStream
|
|
642
679
|
#
|
643
680
|
#
|
644
681
|
# @param input_a tensor X
|
645
|
-
# @param
|
682
|
+
# @param axis_p tensor X (of type INTEGER_TYPES)
|
646
683
|
#
|
647
684
|
# Options:
|
685
|
+
# @option axis axis
|
648
686
|
# @option name Optional name
|
649
687
|
# @option keepdims If true, retains reduced dimensions with length 1. default (false)
|
650
688
|
# @return Tensor
|
651
|
-
def sum(input_a,
|
652
|
-
check_allowed_types(
|
689
|
+
def sum(input_a, axis_p = nil, axis: nil, name: nil, keepdims: false)
  # NOTE(review): only the positional axis is type-checked here; the axis:
  # keyword is merged in afterwards — confirm that is intended.
  check_allowed_types(axis_p, TensorStream::Ops::INTEGER_TYPES)
  input_a = TensorStream.convert_to_tensor(input_a)
  # A scalar has nothing to reduce.
  return input_a if input_a.shape.scalar?

  # The positional axis wins; the keyword is the fallback.
  axis_p ||= axis
  reduction_axis = cast_axis(input_a, axis_p)
  _op(:sum, input_a, reduction_axis, name: name, keepdims: keepdims)
end
|
658
697
|
|
659
698
|
alias_method :reduce_sum, :sum
|
@@ -706,6 +745,23 @@ module TensorStream
|
|
706
745
|
end
|
707
746
|
|
708
747
|
|
748
|
+
##
|
749
|
+
# Finds values and indices of the `k` largest entries for the last dimension.
|
750
|
+
#
|
751
|
+
#
|
752
|
+
# @param input 1-D or higher `Tensor` with last dimension at least `k`.
|
753
|
+
# @param k 0-D `int32` `Tensor`. Number of top elements to look for along the last dimension (along each row for matrices)
|
754
|
+
#
|
755
|
+
# Options:
|
756
|
+
# @option sorted If true the resulting `k` elements will be sorted by the values in descending order. default (true)
|
757
|
+
# @option name Optional name
|
758
|
+
# @return Tensor
|
759
|
+
def top_k(input, k = 1, sorted: true, name: nil)
  # The op yields two outputs; they are unpacked into a plain two-element array.
  outputs = _op(:top_k, input, k, sorted: sorted, name: name)
  first_output = outputs[0]
  second_output = outputs[1]
  [first_output, second_output]
end
|
763
|
+
|
764
|
+
|
709
765
|
##
|
710
766
|
# Creates a tensor with all elements set to zero
|
711
767
|
#
|
@@ -5,7 +5,16 @@ module TensorStream
|
|
5
5
|
end
|
6
6
|
|
7
7
|
# Indexes the tensor. A Range becomes a stride-1 strided_slice over the first
# dimension; anything else is forwarded to the :index op.
def [](index)
  return _op(:index, self, index) unless index.is_a?(Range)

  # An endless range runs to the size of the first dimension; otherwise the
  # exclusive upper bound is one past the largest element of the range.
  upper = index.end.nil? ? TensorStream.shape(self)[0] : index.max + 1
  _op(:strided_slice, self, [index.min], [upper], [1])
end
|
10
19
|
|
11
20
|
def *(other)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module TensorStream
  # High level math functions
  class Maths
    extend TensorStream::OpHelper

    module MathFunctions
      ##
      # Normalizes along dimension axis using an L2 norm, i.e.
      # x * rsqrt(max(sum(x ** 2), epsilon)).
      #
      # @param x value to normalize (converted to a tensor)
      # @option axis dimension(s) to reduce over when summing the squares
      # @option epsilon lower bound applied to the squared sum
      # @option name Optional name
      # @return Tensor
      def l2_normalize(x, axis: nil, epsilon: 1e-12, name: nil)
        TensorStream.name_scope(name, "l2_normalize", values: [x]) do |scope_name|
          tensor = TensorStream.convert_to_tensor(x, name: "x")
          squares = TensorStream.square(tensor)
          squared_total = TensorStream.reduce_sum(squares, axis, keepdims: true)
          clamped_total = TensorStream.maximum(squared_total, epsilon)
          inverse_norm = TensorStream.rsqrt(clamped_total)
          TensorStream.multiply(tensor, inverse_norm, name: scope_name)
        end
      end
    end

    extend MathFunctions
  end
end
|
@@ -136,7 +136,7 @@ module TensorStream
|
|
136
136
|
when :sparse_softmax_cross_entropy_with_logits
|
137
137
|
output = node
|
138
138
|
[_broadcast_mul(grad, output[1]), nil]
|
139
|
-
|
139
|
+
when :zeros_like
|
140
140
|
# non differentiable
|
141
141
|
nil
|
142
142
|
when :transpose
|
@@ -165,12 +165,26 @@ module TensorStream
|
|
165
165
|
ts.stack(grad, axis: node.options[:axis])
|
166
166
|
when :conv2d
|
167
167
|
_Conv2DGrad(node, grad)
|
168
|
+
when :flow_dynamic_stitch
|
169
|
+
num_values = node.inputs.size / 2
|
170
|
+
indices_grad = [nil] * num_values
|
171
|
+
|
172
|
+
inputs = (0...num_values).map { |i| _int32(node, node.inputs[i]) }
|
173
|
+
|
174
|
+
values_grad = inputs.map { |inp| TensorStream.gather(grad, inp) }
|
175
|
+
indices_grad + values_grad
|
176
|
+
when :gather
|
177
|
+
[_op(:gather_grad, grad, node.inputs[1], TensorStream.shape(node.inputs[0])), nil]
|
168
178
|
else
|
169
179
|
TensorStream::OpMaker.gradient_op(self, node, grad)
|
170
180
|
end
|
171
181
|
end
|
172
182
|
end
|
173
183
|
|
184
|
+
# Returns +x+ unchanged when the node's first input is already :int32,
# otherwise returns +x+ cast to :int32.
def self._int32(node, x)
  return x if node.inputs[0].data_type == :int32

  TensorStream.cast(x, :int32)
end
|
187
|
+
|
174
188
|
# Reshapes +grad+ to the shape of the node's first input.
def self._reshape_to_input(node, grad)
  first_input = node.inputs[0]
  ts.reshape(grad, ts.shape(first_input))
end
|