tensor_stream 0.8.6 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +1 -3
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +4 -4
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +1 -0
- data/lib/tensor_stream/operation.rb +14 -16
- data/lib/tensor_stream/ops.rb +1 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/mnist_data.rb +5 -4
- metadata +3 -5
- data/benchmark/benchmark.rb +0 -88
- data/samples/rnn.rb +0 -105
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d5c4d8c8c2e586ef8c9cce0b57d31b68d9b30ddba682c79e0de5e0c905c76a22
|
4
|
+
data.tar.gz: c503cbef1dd51f563479a248b209f62b035fc076a531269b1d26656521a928a0
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 8a9fb88f5ac1c4f8ec47809405370f28009630155b56e6f1f9e0b1cf0d975188bcfe522216226d425c135a43551e08ddc9603c6ccf2c37e6d78ea3321c52ff5d
|
7
|
+
data.tar.gz: '080d79da1a32290f9ee5b411270835a80f9dd9d1b9d1d10ee28d67ad82c856f2bb3b31bfd122812ba6ff3206b3ff25ccf5294a5a2837d22cd038940229161ce4'
|
data/README.md
CHANGED
@@ -1,6 +1,4 @@
|
|
1
|
-
[![Gem Version](https://badge.fury.io/rb/tensor_stream.svg)](https://badge.fury.io/rb/tensor_stream)
|
2
|
-
|
3
|
-
[![CircleCI](https://circleci.com/gh/jedld/tensor_stream.svg?style=svg)](https://circleci.com/gh/jedld/tensor_stream)
|
1
|
+
[![Gem Version](https://badge.fury.io/rb/tensor_stream.svg)](https://badge.fury.io/rb/tensor_stream)[![CircleCI](https://circleci.com/gh/jedld/tensor_stream.svg?style=svg)](https://circleci.com/gh/jedld/tensor_stream) [![Join the chat at https://gitter.im/tensor_stream/Lobby](https://badges.gitter.im/tensor_stream/Lobby.svg)](https://gitter.im/tensor_stream/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
4
2
|
|
5
3
|
# TensorStream
|
6
4
|
|
@@ -122,7 +122,7 @@ module TensorStream
|
|
122
122
|
|
123
123
|
res = TensorShape.reshape(output_buffer, new_shape)
|
124
124
|
|
125
|
-
TensorStream::Evaluator::OutputGroup.new(res)
|
125
|
+
TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
|
126
126
|
end
|
127
127
|
|
128
128
|
register_op :squeeze do |_context, tensor, inputs|
|
@@ -302,7 +302,7 @@ module TensorStream
|
|
302
302
|
split_tensor(value, begin_index, end_index, axis)
|
303
303
|
end
|
304
304
|
end
|
305
|
-
TensorStream::Evaluator::OutputGroup.new(res)
|
305
|
+
TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
|
306
306
|
end
|
307
307
|
|
308
308
|
register_op :reshape do |_context, _tensor, inputs|
|
@@ -352,11 +352,11 @@ module TensorStream
|
|
352
352
|
shape_eval(inputs[0], tensor.options[:out_type])
|
353
353
|
end
|
354
354
|
|
355
|
-
register_op :shape_n do |_context,
|
355
|
+
register_op :shape_n do |_context, tensor, inputs|
|
356
356
|
shapes = inputs.collect do |input|
|
357
357
|
shape_eval(input)
|
358
358
|
end
|
359
|
-
TensorStream::Evaluator::OutputGroup.new(shapes)
|
359
|
+
TensorStream::Evaluator::OutputGroup.new(shapes, shapes.map { tensor.options[:out_type] })
|
360
360
|
end
|
361
361
|
end
|
362
362
|
end
|
@@ -40,6 +40,7 @@ module TensorStream
|
|
40
40
|
register_op :apply_adagrad do |_context, tensor, inputs|
|
41
41
|
target_var, accum, lr, grad = inputs
|
42
42
|
assign = tensor.inputs[0] || tensor
|
43
|
+
|
43
44
|
assign.value = multi_array_op(->(v, a, g) { v - (g * lr * (1.0 / Math.sqrt(a))) }, target_var, accum, grad)
|
44
45
|
assign.value
|
45
46
|
end
|
@@ -224,7 +224,7 @@ module TensorStream
|
|
224
224
|
when :index
|
225
225
|
input_shape = inputs[0].shape.shape
|
226
226
|
return nil if input_shape.nil?
|
227
|
-
|
227
|
+
input_shape[1, input_shape.size]
|
228
228
|
when :mean, :prod, :sum
|
229
229
|
return [] if inputs[1].nil?
|
230
230
|
return nil if inputs[0].nil?
|
@@ -235,7 +235,7 @@ module TensorStream
|
|
235
235
|
axis = inputs[1].is_a?(Tensor) ? inputs[1].value : inputs[1]
|
236
236
|
|
237
237
|
axis = [axis] unless axis.is_a?(Array)
|
238
|
-
|
238
|
+
input_shape.each_with_index.map do |s, index|
|
239
239
|
next nil if axis.include?(index)
|
240
240
|
s
|
241
241
|
end.compact
|
@@ -246,27 +246,27 @@ module TensorStream
|
|
246
246
|
|
247
247
|
input_shape = inputs[0].shape.shape
|
248
248
|
return new_shape if input_shape.nil?
|
249
|
-
|
250
|
-
|
249
|
+
return nil if input_shape.include?(nil)
|
250
|
+
TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(:*))
|
251
251
|
when :flow_group
|
252
|
-
|
252
|
+
[]
|
253
253
|
when :zeros, :ones, :fill
|
254
|
-
|
254
|
+
inputs[0] ? inputs[0].value : options[:shape]
|
255
255
|
when :zeros_like, :ones_like
|
256
256
|
inputs[0].shape.shape
|
257
257
|
when :shape
|
258
|
-
|
258
|
+
inputs[0].shape.shape ? [inputs[0].shape.shape.size] : nil
|
259
259
|
when :mat_mul
|
260
260
|
shape1 = inputs[0].shape.shape.nil? ? nil : inputs[0].shape.shape[0]
|
261
261
|
shape2 = inputs[1].shape.shape.nil? ? nil : inputs[1].shape.shape[1]
|
262
|
-
|
262
|
+
[shape1, shape2]
|
263
263
|
when :transpose
|
264
264
|
return nil unless shape_full_specified(inputs[0])
|
265
265
|
return nil if inputs[1].is_a?(Tensor)
|
266
266
|
|
267
267
|
rank = inputs[0].shape.shape.size
|
268
268
|
perm = inputs[1] || (0...rank).to_a.reverse
|
269
|
-
|
269
|
+
perm.map { |p| inputs[0].shape.shape[p] }
|
270
270
|
when :stack
|
271
271
|
return nil unless shape_full_specified(inputs[0])
|
272
272
|
|
@@ -276,7 +276,7 @@ module TensorStream
|
|
276
276
|
rank = inputs[0].shape.shape.size + 1
|
277
277
|
axis = rank + axis if axis < 0
|
278
278
|
rotated_shape = Array.new(axis + 1) { new_shape.shift }
|
279
|
-
|
279
|
+
rotated_shape.rotate! + new_shape
|
280
280
|
when :concat
|
281
281
|
return nil if inputs[0].value.nil?
|
282
282
|
|
@@ -293,18 +293,16 @@ module TensorStream
|
|
293
293
|
|
294
294
|
new_shape = inputs[1].shape.shape.dup
|
295
295
|
new_shape[axis] = axis_size
|
296
|
-
|
296
|
+
new_shape
|
297
297
|
when :slice, :squeeze
|
298
|
-
|
298
|
+
nil
|
299
299
|
when :tile
|
300
|
-
|
300
|
+
nil
|
301
301
|
else
|
302
302
|
return nil if inputs[0].nil?
|
303
303
|
return inputs[0].shape.shape if inputs.size == 1
|
304
|
-
|
304
|
+
TensorShape.infer_shape(inputs[0].shape.shape, inputs[1].shape.shape) if inputs.size == 2 && inputs[0] && inputs[1]
|
305
305
|
end
|
306
|
-
|
307
|
-
nil
|
308
306
|
end
|
309
307
|
|
310
308
|
def propagate_consumer(consumer)
|
data/lib/tensor_stream/ops.rb
CHANGED
@@ -115,7 +115,7 @@ module TensorStream
|
|
115
115
|
end
|
116
116
|
|
117
117
|
if shapes_known
|
118
|
-
inputs.collect { |input| cons(input.shape.shape dtype: out_type) }
|
118
|
+
inputs.collect { |input| cons(input.shape.shape, dtype: out_type) }
|
119
119
|
else
|
120
120
|
res = _op(:shape_n, *inputs, out_type: out_type, name: name)
|
121
121
|
Array.new(inputs.size) do |index|
|
data/samples/mnist_data.rb
CHANGED
@@ -11,7 +11,7 @@ require 'tensor_stream'
|
|
11
11
|
require 'mnist-learn'
|
12
12
|
|
13
13
|
# Enable OpenCL hardware accelerated computation, not using OpenCL can be very slow
|
14
|
-
require 'tensor_stream/
|
14
|
+
# require 'tensor_stream/opencl'
|
15
15
|
|
16
16
|
tf = TensorStream
|
17
17
|
|
@@ -20,11 +20,11 @@ puts "downloading minst data"
|
|
20
20
|
mnist = Mnist.read_data_sets('/tmp/data', one_hot: true)
|
21
21
|
puts "downloading finished"
|
22
22
|
|
23
|
-
x = tf.placeholder(:float32, shape: [nil,
|
23
|
+
x = tf.placeholder(:float32, shape: [nil, 784])
|
24
24
|
w = tf.variable(tf.zeros([784, 10]))
|
25
25
|
b = tf.variable(tf.zeros([10]))
|
26
26
|
|
27
|
-
|
27
|
+
|
28
28
|
|
29
29
|
# model
|
30
30
|
y = tf.nn.softmax(tf.matmul(tf.reshape(x, [-1, 784]), w) + b)
|
@@ -37,10 +37,11 @@ cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
|
|
37
37
|
is_correct = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
38
38
|
accuracy = tf.reduce_mean(tf.cast(is_correct, :float32))
|
39
39
|
|
40
|
-
optimizer = TensorStream::Train::
|
40
|
+
optimizer = TensorStream::Train::AdamOptimizer.new
|
41
41
|
train_step = optimizer.minimize(cross_entropy)
|
42
42
|
|
43
43
|
sess = tf.session
|
44
|
+
init = tf.global_variables_initializer
|
44
45
|
sess.run(init)
|
45
46
|
|
46
47
|
(0...1000).each do |i|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: tensor_stream
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Joseph Emmanuel Dayo
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-10-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: bundler
|
@@ -243,7 +243,6 @@ files:
|
|
243
243
|
- LICENSE.txt
|
244
244
|
- README.md
|
245
245
|
- Rakefile
|
246
|
-
- benchmark/benchmark.rb
|
247
246
|
- benchmark_intel.txt
|
248
247
|
- benchmark_nvidia.txt
|
249
248
|
- benchmark_ryzen_amd.txt
|
@@ -314,7 +313,6 @@ files:
|
|
314
313
|
- samples/multigpu.rb
|
315
314
|
- samples/nearest_neighbor.rb
|
316
315
|
- samples/raw_neural_net_sample.rb
|
317
|
-
- samples/rnn.rb
|
318
316
|
- tensor_stream.gemspec
|
319
317
|
homepage: http://www.github.com/jedld/tensor_stream
|
320
318
|
licenses:
|
@@ -337,7 +335,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
337
335
|
version: '0'
|
338
336
|
requirements: []
|
339
337
|
rubyforge_project:
|
340
|
-
rubygems_version:
|
338
|
+
rubygems_version: 2.7.7
|
341
339
|
signing_key:
|
342
340
|
specification_version: 4
|
343
341
|
summary: A Pure ruby tensorflow implementation
|
data/benchmark/benchmark.rb
DELETED
@@ -1,88 +0,0 @@
|
|
1
|
-
require "bundler/setup"
|
2
|
-
require 'tensor_stream'
|
3
|
-
require 'benchmark'
|
4
|
-
require 'pry-byebug'
|
5
|
-
require 'awesome_print'
|
6
|
-
require 'tensor_stream/evaluator/opencl/opencl_evaluator'
|
7
|
-
|
8
|
-
def tr(t, places = 1)
|
9
|
-
if t.is_a?(Array)
|
10
|
-
return t.collect do |v|
|
11
|
-
tr(v, places)
|
12
|
-
end
|
13
|
-
end
|
14
|
-
|
15
|
-
return t unless t.is_a?(Float)
|
16
|
-
|
17
|
-
t.round(places)
|
18
|
-
end
|
19
|
-
|
20
|
-
tf = TensorStream
|
21
|
-
|
22
|
-
srand(5)
|
23
|
-
seed = 5
|
24
|
-
tf.set_random_seed(seed)
|
25
|
-
|
26
|
-
SHAPES = [32, 32]
|
27
|
-
|
28
|
-
sess = tf.session(:ruby_evaluator)
|
29
|
-
|
30
|
-
a = tf.constant(sess.run(tf.random_uniform(SHAPES)))
|
31
|
-
a_int = tf.constant([
|
32
|
-
[1, 2, 3, 4, 4, 1, 4, 8, 3, 4, 1, 1],
|
33
|
-
[2, 2, 3, 4, 4, 1, 1, 1, 1, 4, 1, 1],
|
34
|
-
[3, 2, 3, 4, 0, 1, 1, 2, 1, 1, 2, 1],
|
35
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 1, 1, 3, 1],
|
36
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 1, 1, 4, 1],
|
37
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 1, 5, 1],
|
38
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 1, 6, 1],
|
39
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 0, 0, 1],
|
40
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 2, 6, 1],
|
41
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 2, 1, 1],
|
42
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 2, 1, 2],
|
43
|
-
[4, 2, 3, 4, 0, 1, 1, 0, 0, 2, 1, 2],
|
44
|
-
])
|
45
|
-
|
46
|
-
b = tf.constant(sess.run(tf.random_uniform(SHAPES)))
|
47
|
-
|
48
|
-
c = tf.constant(sess.run(tf.random_uniform(SHAPES)))
|
49
|
-
|
50
|
-
d = tf.constant(sess.run(tf.random_uniform(SHAPES)))
|
51
|
-
|
52
|
-
p = tf.placeholder('float')
|
53
|
-
q = tf.placeholder('float')
|
54
|
-
|
55
|
-
model = -tf.sin(a.dot(b + p) + c).dot(a) + tf.cos(a.dot(d + q))
|
56
|
-
single_function_test = (tf.sigmoid(a * p) * tf.sigmoid(b * q)) + c
|
57
|
-
pow_f = tf.pow(a, 3)
|
58
|
-
pow_i = tf.pow(a_int, 3)
|
59
|
-
matmul = tf.matmul(a, b)
|
60
|
-
out_of_order = tf.matmul(a, b) + tf.matmul(a, c)
|
61
|
-
softmax = tf.nn.softmax(a)
|
62
|
-
add_n = tf.add_n([a,b,c,d])
|
63
|
-
|
64
|
-
puts TensorStream::Evaluator.default_evaluators
|
65
|
-
|
66
|
-
sess2 = tf.session
|
67
|
-
|
68
|
-
puts `cat /proc/cpuinfo | grep "model name" | head -1`
|
69
|
-
device = TensorStream::Evaluator::OpenclEvaluator.default_device.native_device
|
70
|
-
puts "OpenCL device #{device.platform.to_s} #{device.name}"
|
71
|
-
Benchmark.bmbm do |x|
|
72
|
-
x.report("pure ruby add_n :") { 100.times do sess.run(add_n) end }
|
73
|
-
x.report("opencl ruby add_n :") { 100.times do sess2.run(add_n) end }
|
74
|
-
x.report("pure ruby ooo matmul :") { 100.times do sess.run(out_of_order) end }
|
75
|
-
x.report("opencl ooo matmul :") { 100.times do sess2.run(out_of_order) end }
|
76
|
-
x.report("pure ruby softmax :") { 100.times do sess.run(softmax) end }
|
77
|
-
x.report("opencl softmax :") { 100.times do sess2.run(softmax) end }
|
78
|
-
x.report("pure ruby matmul :") { 100.times do sess.run(matmul) end }
|
79
|
-
x.report("opencl matmul :") { 100.times do sess2.run(matmul) end }
|
80
|
-
x.report("pure ruby :") { 100.times do sess.run(model, feed_dict: { p => rand, q => rand }) end }
|
81
|
-
x.report("opencl :") { 100.times do sess2.run(model, feed_dict: { p => rand, q => rand }) end }
|
82
|
-
x.report("pure ruby single function:") { 100.times do sess.run(single_function_test, feed_dict: { p => rand, q => rand }) end }
|
83
|
-
x.report("opencl singlefunction:") { 100.times do sess2.run(single_function_test, feed_dict: { p => rand, q => rand }) end }
|
84
|
-
x.report("pure ruby pow float:") { 100.times do sess.run(pow_f, feed_dict: { p => rand, q => rand }) end }
|
85
|
-
x.report("opencl pow float:") { 100.times do sess2.run(pow_f, feed_dict: { p => rand, q => rand }) end }
|
86
|
-
x.report("pure ruby pow int:") { 100.times do sess.run(pow_i, feed_dict: { p => rand, q => rand }) end }
|
87
|
-
x.report("opencl pow int:") { 100.times do sess2.run(pow_i, feed_dict: { p => rand, q => rand }) end }
|
88
|
-
end
|
data/samples/rnn.rb
DELETED
@@ -1,105 +0,0 @@
|
|
1
|
-
# RNN sample
|
2
|
-
#
|
3
|
-
# Ruby port Example based on article by Erik Hallström
|
4
|
-
# https://medium.com/@erikhallstrm/hello-world-rnn-83cd7105b767
|
5
|
-
#
|
6
|
-
#
|
7
|
-
|
8
|
-
require "bundler/setup"
|
9
|
-
require 'tensor_stream'
|
10
|
-
|
11
|
-
tf = TensorStream
|
12
|
-
|
13
|
-
num_epochs = 100
|
14
|
-
total_series_length = 50000
|
15
|
-
truncated_backprop_length = 15
|
16
|
-
state_size = 4
|
17
|
-
num_classes = 2
|
18
|
-
echo_step = 3
|
19
|
-
batch_size = 5
|
20
|
-
num_batches = total_series_length / batch_size / truncated_backprop_length
|
21
|
-
randomizer = TensorStream.random_uniform([total_series_length], minval: 0, maxval: 2)
|
22
|
-
|
23
|
-
|
24
|
-
def generate_data(randomizer, total_series_length, batch_size, echo_step)
|
25
|
-
x = randomizer.eval
|
26
|
-
y = x.rotate(-echo_step)
|
27
|
-
|
28
|
-
y[echo_step] = 0
|
29
|
-
|
30
|
-
x = TensorStream::TensorShape.reshape(x, [batch_size, -1]) # The first index changing slowest, subseries as rows
|
31
|
-
y = TensorStream::TensorShape.reshape(y, [batch_size, -1])
|
32
|
-
[x, y]
|
33
|
-
end
|
34
|
-
|
35
|
-
batchX_placeholder = tf.placeholder(:float32, shape: [batch_size, truncated_backprop_length], name: 'batch_x')
|
36
|
-
batchY_placeholder = tf.placeholder(:int32, shape: [batch_size, truncated_backprop_length], name: 'batch_y')
|
37
|
-
|
38
|
-
init_state = tf.placeholder(:float32, shape: [batch_size, state_size], name: 'init_state')
|
39
|
-
|
40
|
-
|
41
|
-
W = tf.variable(tf.random_uniform([state_size+1, state_size]), dtype: :float32, name: 'W')
|
42
|
-
b = tf.variable(tf.zeros([state_size]), dtype: :float32, name: 'b')
|
43
|
-
|
44
|
-
W2 = tf.variable(tf.random_uniform([state_size, num_classes]), dtype: :float32, name: 'W2')
|
45
|
-
b2 = tf.variable(tf.zeros([num_classes]), dtype: :float32, name: 'b2')
|
46
|
-
|
47
|
-
|
48
|
-
inputs_series = tf.unpack(batchX_placeholder, axis: 1)
|
49
|
-
labels_series = tf.unpack(batchY_placeholder, axis: 1)
|
50
|
-
|
51
|
-
current_state = init_state
|
52
|
-
states_series = []
|
53
|
-
|
54
|
-
inputs_series.each do |current_input|
|
55
|
-
current_input = tf.reshape(current_input, [batch_size, 1])
|
56
|
-
input_and_state_concatenated = tf.concat([current_input, current_state], 1) # Increasing number of columns
|
57
|
-
next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W) + b) # Broadcasted addition
|
58
|
-
states_series << next_state
|
59
|
-
current_state = next_state
|
60
|
-
end
|
61
|
-
|
62
|
-
logits_series = states_series.collect do |state|
|
63
|
-
tf.matmul(state, W2) + b2
|
64
|
-
end
|
65
|
-
|
66
|
-
predictions_series = logits_series.collect do |logits|
|
67
|
-
tf.nn.softmax(logits)
|
68
|
-
end
|
69
|
-
|
70
|
-
losses = logits_series.zip(labels_series).collect do |logits, labels|
|
71
|
-
tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: labels)
|
72
|
-
end
|
73
|
-
total_loss = tf.reduce_mean(losses)
|
74
|
-
|
75
|
-
train_step = TensorStream::Train::AdagradOptimizer.new(0.3).minimize(total_loss)
|
76
|
-
|
77
|
-
puts "#{tf.get_default_graph.nodes.keys.size} nodes created"
|
78
|
-
zeros_state = tf.zeros([batch_size, state_size]).eval
|
79
|
-
tf.session do |sess|
|
80
|
-
sess.run(tf.global_variables_initializer)
|
81
|
-
(0..num_epochs).each do |epoch_idx|
|
82
|
-
x, y = generate_data(randomizer, total_series_length, batch_size, echo_step)
|
83
|
-
_current_state = zeros_state
|
84
|
-
print("New data, epoch", epoch_idx, "\n")
|
85
|
-
(0..num_batches - 1).each do |batch_idx|
|
86
|
-
start_idx = batch_idx * truncated_backprop_length
|
87
|
-
end_idx = start_idx + truncated_backprop_length
|
88
|
-
|
89
|
-
batchX = x.map { |x| x[start_idx...end_idx] }
|
90
|
-
batchY = y.map { |y| y[start_idx...end_idx] }
|
91
|
-
|
92
|
-
_total_loss, _train_step, _current_state, _predictions_series = sess.run(
|
93
|
-
[total_loss, train_step, current_state, predictions_series],
|
94
|
-
feed_dict: {
|
95
|
-
batchX_placeholder => batchX,
|
96
|
-
batchY_placeholder => batchY,
|
97
|
-
init_state => _current_state
|
98
|
-
})
|
99
|
-
|
100
|
-
if batch_idx%100 == 0
|
101
|
-
print("Step",batch_idx, " Loss ", _total_loss, "\n")
|
102
|
-
end
|
103
|
-
end
|
104
|
-
end
|
105
|
-
end
|