tensor_stream-opencl 0.2.8 → 0.2.9

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9c305996e1d65ae1144ec6bbb874382ff4925b361b3128092c3b8a4c7bff5178
4
- data.tar.gz: 52d9d7d7fea60b1282012605bfa3dd320593f24500974a01f90a4ad605b01364
3
+ metadata.gz: acbe0eec553f30d1407ca6ab5c0401c7573450ebf08c10cac3fc0b1b579f9ca5
4
+ data.tar.gz: 160aa96505972cc291696a732d9021945483036fd6f4009b895ee09a12c04639
5
5
  SHA512:
6
- metadata.gz: be912476033f2be3fdc5b0bfe64cb826c4ec0fa3a6847a9b217cc402850a9b298c19eb3635a0aec34b7639a2fd50d11d258d4fb1cc4d2f75a744bd6eb8e4b118
7
- data.tar.gz: cb473fed480a78be78159c04a0bc794011213a692e80dd5286d21d80a2cc048f636117654ea77d4c03fa3a7040ff2cad274757726791f4d5b7ed22479cfa6e2a
6
+ metadata.gz: '0986e63d5fb000b36b66fc0b0bc4cbaf95d43d1d804347bf5f928d2444c631d32182a9b7893d84da5a2518a61444688a67debd82980029ef4d3642d921cf553c'
7
+ data.tar.gz: bd0033cbc8bc43bfcc29f948a13cc17e7895564280c1cf96afcfcdef91d7d121e1055383fff2e8520b50c1621f87aa19f03ebf10f1a05b9530ac5bc5d5ae3155
@@ -5,7 +5,7 @@ module TensorStream
5
5
  def MathOps.included(klass)
6
6
  klass.class_eval do
7
7
  %i[max min add real_div div sub floor_mod mod mul pow sigmoid_grad squared_difference].each do |op|
8
- register_op op do |context, tensor, inputs|
8
+ register_op op do |_context, tensor, inputs|
9
9
  execute_2_operand_func(op.to_s, tensor, inputs[0], inputs[1])
10
10
  end
11
11
  end
@@ -1,5 +1,5 @@
1
1
  module TensorStream
2
2
  module Opencl
3
- VERSION = "0.2.8"
3
+ VERSION = "0.2.9"
4
4
  end
5
5
  end
@@ -48,7 +48,7 @@ y_ = tf.placeholder(:float32, shape: [nil, 10])
48
48
  # step for variable learning rate
49
49
  step = tf.placeholder(:int32)
50
50
 
51
- pkeep = tf.placeholder(tf.float32)
51
+ pkeep = tf.placeholder(:float32)
52
52
 
53
53
  # three convolutional layers with their channel counts, and a
54
54
  # fully connected layer (the last layer has 10 softmax neurons)
@@ -0,0 +1,623 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "In this notebook we will be demonstrating how to train convolutional networks using TensorStream and its OpenCL backend.\n",
8
+ "\n",
9
+ "Note that the code here is based on Martin Gorner's talk \"TensorFlow and Deep Learning without a PhD, Part 1 (Google Cloud Next '17)\"\n",
10
+ "\n",
11
+ "https://www.youtube.com/watch?v=u4alGiomYP4"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "Include the OpenCL backend as working with images is compute intensive. Note that this requires OpenCL to be configured properly on your machine."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "metadata": {},
25
+ "outputs": [
26
+ {
27
+ "name": "stderr",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "Warning OpenCL 2.0 loader detected!\n"
31
+ ]
32
+ },
33
+ {
34
+ "name": "stdout",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "Tensorstream version 1.0.0 with OpenCL lib 0.2.8\n"
38
+ ]
39
+ }
40
+ ],
41
+ "source": [
42
+ "require 'tensor_stream'\n",
43
+ "require 'mnist-learn'\n",
44
+ "require 'csv'\n",
45
+ "\n",
46
+ "require 'tensor_stream/opencl'\n",
47
+ "\n",
48
+ "ts = TensorStream\n",
49
+ "puts \"Tensorstream version #{ts.__version__} with OpenCL lib #{TensorStream::Opencl::VERSION}\""
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {},
55
+ "source": [
56
+ "Download the MNIST data set which we will use for training the network"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "downloading minst data\n",
69
+ "downloading finished\n"
70
+ ]
71
+ }
72
+ ],
73
+ "source": [
74
+ "# Import MNIST data\n",
75
+ "puts \"downloading minst data\"\n",
76
+ "# Download images and labels into mnist.test (10K images+labels) and mnist.train (60K images+labels)\n",
77
+ "mnist = Mnist.read_data_sets('/tmp/data', one_hot: true)\n",
78
+ "\n",
79
+ "puts \"downloading finished\""
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "metadata": {},
85
+ "source": [
86
+ "Setup parameters that we will use for the network:"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 3,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "data": {
96
+ "text/plain": [
97
+ "2000"
98
+ ]
99
+ },
100
+ "execution_count": 3,
101
+ "metadata": {},
102
+ "output_type": "execute_result"
103
+ }
104
+ ],
105
+ "source": [
106
+ "K = 4 # first convolutional layer output depth\n",
107
+ "L = 8 # second convolutional layer output depth\n",
108
+ "M = 12 # third convolutional layer\n",
109
+ "N = 200 # fully connected layer\n",
110
+ "EPOCH = 2000"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "markdown",
115
+ "metadata": {},
116
+ "source": [
117
+ "Setup placeholders. Placeholders are like input parameters that your model can accept; they have no definite value until you give them one during sess.run"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 4,
123
+ "metadata": {},
124
+ "outputs": [
125
+ {
126
+ "data": {
127
+ "text/plain": [
128
+ "Placeholder(Placeholder_4 shape: ? data_type: float32)"
129
+ ]
130
+ },
131
+ "execution_count": 4,
132
+ "metadata": {},
133
+ "output_type": "execute_result"
134
+ }
135
+ ],
136
+ "source": [
137
+ "# input X: 28x28 grayscale images, the first dimension (None) will index the images in the mini-batch\n",
138
+ "x = ts.placeholder(:float32, shape: [nil, 28, 28, 1])\n",
139
+ "\n",
140
+ "# correct answers will go here\n",
141
+ "y_ = ts.placeholder(:float32, shape: [nil, 10])\n",
142
+ "\n",
143
+ "# step for variable learning rate\n",
144
+ "step_ = ts.placeholder(:int32)\n",
145
+ "\n",
146
+ "pkeep = ts.placeholder(:float32)"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "markdown",
151
+ "metadata": {},
152
+ "source": [
153
+ "Here we declare variables. The contents of these variables are randomized initially; however, they get updated automatically during training. Variables contain the weights of the network, and their values serve as the neural connections that make the system learn."
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 5,
159
+ "metadata": {},
160
+ "outputs": [
161
+ {
162
+ "data": {
163
+ "text/plain": [
164
+ "Variable(Variable_10:0 shape: TensorShape([Dimension(10)]) data_type: float32)"
165
+ ]
166
+ },
167
+ "execution_count": 5,
168
+ "metadata": {},
169
+ "output_type": "execute_result"
170
+ }
171
+ ],
172
+ "source": [
173
+ "w1 = ts.variable(ts.truncated_normal([6, 6, 1, K], stddev: 0.1))\n",
174
+ "b1 = ts.variable(ts.ones([K])/10)\n",
175
+ "\n",
176
+ "w2 = ts.variable(ts.truncated_normal([5, 5, K, L], stddev: 0.1))\n",
177
+ "b2 = ts.variable(ts.ones([L])/10)\n",
178
+ "\n",
179
+ "w3 = ts.variable(ts.truncated_normal([4, 4, L, M], stddev: 0.1))\n",
180
+ "b3 = ts.variable(ts.ones([M])/10)\n",
181
+ "\n",
182
+ "w4 = ts.variable(ts.truncated_normal([7 * 7 * M, N], stddev: 0.1))\n",
183
+ "b4 = ts.variable(ts.ones([N])/10)\n",
184
+ "\n",
185
+ "w5 = ts.variable(ts.truncated_normal([N, 10], stddev: 0.1))\n",
186
+ "b5 = ts.variable(ts.ones([10])/10)\n"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {},
192
+ "source": [
193
+ "Here we declare the model itself. These define the computations that make up the structure of the neural network. In this case we are setting up 3 convolutional layers and 2 fully connected layers. We are also using relu as the activation function. The kinds of functions to use are based on decades of research and this can change depending on new findings."
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": 6,
199
+ "metadata": {},
200
+ "outputs": [
201
+ {
202
+ "data": {
203
+ "text/plain": [
204
+ "Op(softmax name: out shape: ? data_type: float32)"
205
+ ]
206
+ },
207
+ "execution_count": 6,
208
+ "metadata": {},
209
+ "output_type": "execute_result"
210
+ }
211
+ ],
212
+ "source": [
213
+ "# The model\n",
214
+ "stride = 1 # output is 28x28\n",
215
+ "y1 = ts.nn.relu(ts.nn.conv2d(x.reshape([-1, 28, 28, 1]), w1, [1, stride, stride, 1], 'SAME') + b1)\n",
216
+ "stride = 2 # output is 14x14\n",
217
+ "y2 = ts.nn.relu(ts.nn.conv2d(y1, w2, [1, stride, stride, 1], 'SAME') + b2)\n",
218
+ "stride = 2 # output is 7x7\n",
219
+ "y3 = ts.nn.relu(ts.nn.conv2d(y2, w3, [1, stride, stride, 1], 'SAME') + b3)\n",
220
+ "\n",
221
+ "# reshape the output from the third convolution for the fully connected layer\n",
222
+ "yy = y3.reshape([-1, 7 * 7 * M])\n",
223
+ "y4 = ts.nn.relu(ts.matmul(yy, w4) + b4)\n",
224
+ "\n",
225
+ "ylogits = ts.matmul(y4, w5) + b5\n",
226
+ "\n",
227
+ "# model\n",
228
+ "y = ts.nn.softmax(ylogits, name: 'out')"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "markdown",
233
+ "metadata": {},
234
+ "source": [
235
+ "Now we define the error function to use and the optimization algorithm. There are various error functions to choose from as well as optimization algorithms; most have their pros and cons. However, for this type of neural network the softmax cross entropy and the Adam optimizer seem the most appropriate."
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 7,
241
+ "metadata": {},
242
+ "outputs": [
243
+ {
244
+ "data": {
245
+ "text/plain": [
246
+ "Op(flow_group name: Adam/flow_group shape: TensorShape([Dimension(12)]) data_type: )"
247
+ ]
248
+ },
249
+ "execution_count": 7,
250
+ "metadata": {},
251
+ "output_type": "execute_result"
252
+ }
253
+ ],
254
+ "source": [
255
+ "cross_entropy = ts.nn.softmax_cross_entropy_with_logits(logits: ylogits, labels: y_)\n",
256
+ "cross_entropy = ts.reduce_mean(cross_entropy)*100\n",
257
+ "\n",
258
+ "is_correct = ts.equal(ts.argmax(y, 1), ts.argmax(y_, 1))\n",
259
+ "accuracy = ts.reduce_mean(is_correct.cast(:float32))\n",
260
+ "\n",
261
+ "# training step, learning rate = 0.003\n",
262
+ "lr = 0.0001.t + ts.train.exponential_decay(0.003, step_, 2000, 1/Math::E)\n",
263
+ "train_step = TensorStream::Train::AdamOptimizer.new(lr).minimize(cross_entropy)"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "markdown",
268
+ "metadata": {},
269
+ "source": [
270
+ "Setup test data and use a saver so that progress can be continued on the next run. Here we also\n",
271
+ "initialize the variables, otherwise they will contain null values and cause errors during the next sess.run"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 10,
277
+ "metadata": {},
278
+ "outputs": [
279
+ {
280
+ "ename": "Interrupt",
281
+ "evalue": "",
282
+ "output_type": "error",
283
+ "traceback": [
284
+ "\u001b[31mInterrupt\u001b[0m: ",
285
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/czmq-ffi-gen-0.15.0/lib/czmq-ffi-gen/errors.rb:10:in `zmq_errno'\u001b[0m",
286
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/czmq-ffi-gen-0.15.0/lib/czmq-ffi-gen/errors.rb:10:in `strerror'\u001b[0m",
287
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/cztop-0.13.1/lib/cztop/has_ffi_delegate.rb:48:in `raise_zmq_err'\u001b[0m",
288
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/cztop-0.13.1/lib/cztop/message.rb:76:in `receive_from'\u001b[0m",
289
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/cztop-0.13.1/lib/cztop/send_receive_methods.rb:32:in `receive'\u001b[0m",
290
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/iruby-0.3/lib/iruby/session/cztop.rb:59:in `recv'\u001b[0m",
291
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/iruby-0.3/lib/iruby/kernel.rb:42:in `dispatch'\u001b[0m",
292
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/iruby-0.3/lib/iruby/kernel.rb:37:in `run'\u001b[0m",
293
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/iruby-0.3/lib/iruby/command.rb:70:in `run_kernel'\u001b[0m",
294
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/iruby-0.3/lib/iruby/command.rb:34:in `run'\u001b[0m",
295
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/gems/iruby-0.3/bin/iruby:5:in `<top (required)>'\u001b[0m",
296
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/bin/iruby:23:in `load'\u001b[0m",
297
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/bin/iruby:23:in `<main>'\u001b[0m",
298
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/bin/ruby_executable_hooks:24:in `eval'\u001b[0m",
299
+ "\u001b[37m/home/jedld/.rvm/gems/ruby-2.5.1/bin/ruby_executable_hooks:24:in `<main>'\u001b[0m"
300
+ ]
301
+ }
302
+ ],
303
+ "source": [
304
+ "sess = ts.session\n",
305
+ "# Add ops to save and restore all the variables.\n",
306
+ "\n",
307
+ "init = ts.global_variables_initializer\n",
308
+ "\n",
309
+ "sess.run(init)\n",
310
+ "\n",
311
+ "#Setup save and restore\n",
312
+ "model_save_path = \"test_models/mnist_data_3.0\"\n",
313
+ "saver = TensorStream::Train::Saver.new\n",
314
+ "saver.restore(sess, model_save_path)\n",
315
+ "\n",
316
+ "mnist_train = mnist.train\n",
317
+ "test_data = { x => mnist.test.images, y_ => mnist.test.labels, pkeep => 1.0 }\n",
318
+ "\n",
319
+ "nil"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": 9,
325
+ "metadata": {},
326
+ "outputs": [
327
+ {
328
+ "name": "stdout",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "0: accuracy:0.9599999785423279 loss:15.933753967285156 (lr:0.003100000089034438)\n",
332
+ "0: ******** test accuracy: 0.9283999800682068 test loss: 2699.62841796875\n",
333
+ "10: accuracy:0.949999988079071 loss:11.798951148986816 (lr:0.003085037227720022)\n",
334
+ "20: accuracy:0.9399999976158142 loss:13.785243034362793 (lr:0.003070149337872863)\n",
335
+ "30: accuracy:0.9300000071525574 loss:29.51731300354004 (lr:0.003055335721001029)\n",
336
+ "40: accuracy:0.9700000286102295 loss:7.979790687561035 (lr:0.0030405959114432335)\n",
337
+ "50: accuracy:0.9300000071525574 loss:22.561899185180664 (lr:0.0030259296763688326)\n",
338
+ "60: accuracy:0.9700000286102295 loss:17.135129928588867 (lr:0.003011336550116539)\n",
339
+ "70: accuracy:0.949999988079071 loss:12.989716529846191 (lr:0.002996816299855709)\n",
340
+ "80: accuracy:0.9599999785423279 loss:23.976905822753906 (lr:0.002982368227094412)\n",
341
+ "90: accuracy:0.8999999761581421 loss:25.88566017150879 (lr:0.002967992564663291)\n",
342
+ "100: accuracy:0.9700000286102295 loss:9.034774780273438 (lr:0.002953688381239772)\n",
343
+ "100: ******** test accuracy: 0.9628000259399414 test loss: 1461.940185546875\n",
344
+ "110: accuracy:0.9599999785423279 loss:11.933794975280762 (lr:0.002939455443993211)\n",
345
+ "120: accuracy:0.9300000071525574 loss:25.46670150756836 (lr:0.002925293752923608)\n",
346
+ "130: accuracy:0.9399999976158142 loss:21.21111297607422 (lr:0.0029112021438777447)\n",
347
+ "140: accuracy:1.0 loss:4.727890968322754 (lr:0.002897181548178196)\n",
348
+ "150: accuracy:0.9700000286102295 loss:6.854841232299805 (lr:0.002883230336010456)\n",
349
+ "160: accuracy:0.9200000166893005 loss:29.101383209228516 (lr:0.0028693489730358124)\n",
350
+ "170: accuracy:1.0 loss:3.9741010665893555 (lr:0.002855536760762334)\n",
351
+ "180: accuracy:0.9700000286102295 loss:6.293881893157959 (lr:0.0028417936991900206)\n",
352
+ "190: accuracy:1.0 loss:4.598720550537109 (lr:0.002828118856996298)\n",
353
+ "200: accuracy:0.949999988079071 loss:11.564444541931152 (lr:0.0028145122341811657)\n",
354
+ "200: ******** test accuracy: 0.9682000279426575 test loss: 1092.38916015625\n",
355
+ "210: accuracy:0.9700000286102295 loss:9.555869102478027 (lr:0.0028009735979139805)\n",
356
+ "220: accuracy:0.9800000190734863 loss:5.317028999328613 (lr:0.002787502482533455)\n",
357
+ "230: accuracy:0.9700000286102295 loss:8.040860176086426 (lr:0.0027740984223783016)\n",
358
+ "240: accuracy:0.9700000286102295 loss:10.434488296508789 (lr:0.002760761184617877)\n",
359
+ "250: accuracy:0.9800000190734863 loss:5.38808012008667 (lr:0.002747490769252181)\n",
360
+ "260: accuracy:0.9599999785423279 loss:12.20152473449707 (lr:0.002734286244958639)\n",
361
+ "270: accuracy:0.9900000095367432 loss:3.077124834060669 (lr:0.002721147844567895)\n",
362
+ "280: accuracy:0.9700000286102295 loss:7.027272701263428 (lr:0.0027080748695880175)\n",
363
+ "290: accuracy:0.9599999785423279 loss:16.264310836791992 (lr:0.002695067087188363)\n",
364
+ "300: accuracy:0.9700000286102295 loss:8.991308212280273 (lr:0.002682123798877001)\n",
365
+ "300: ******** test accuracy: 0.9660000205039978 test loss: 1412.0006103515625\n",
366
+ "310: accuracy:0.9900000095367432 loss:6.156497001647949 (lr:0.002669245470315218)\n",
367
+ "320: accuracy:0.9800000190734863 loss:3.826826810836792 (lr:0.0026564314030110836)\n",
368
+ "330: accuracy:0.9700000286102295 loss:10.53378677368164 (lr:0.0026436811313033104)\n",
369
+ "340: accuracy:0.9900000095367432 loss:4.218532085418701 (lr:0.0026309946551918983)\n",
370
+ "350: accuracy:0.9900000095367432 loss:4.505715847015381 (lr:0.0026183712761849165)\n",
371
+ "360: accuracy:0.9900000095367432 loss:5.784036159515381 (lr:0.002605810761451721)\n",
372
+ "370: accuracy:0.9900000095367432 loss:4.932862281799316 (lr:0.0025933128781616688)\n",
373
+ "380: accuracy:0.9700000286102295 loss:10.328872680664062 (lr:0.0025808773934841156)\n",
374
+ "390: accuracy:0.9900000095367432 loss:3.35548996925354 (lr:0.002568504074588418)\n",
375
+ "400: accuracy:0.9700000286102295 loss:5.675891876220703 (lr:0.0025561924558132887)\n",
376
+ "400: ******** test accuracy: 0.9735000133514404 test loss: 1068.105224609375\n",
377
+ "410: accuracy:0.9900000095367432 loss:3.3151674270629883 (lr:0.0025439420714974403)\n",
378
+ "420: accuracy:0.9900000095367432 loss:2.646207571029663 (lr:0.002531752921640873)\n",
379
+ "430: accuracy:0.949999988079071 loss:15.31964111328125 (lr:0.0025196243077516556)\n",
380
+ "440: accuracy:0.949999988079071 loss:15.334563255310059 (lr:0.002507556462660432)\n",
381
+ "450: accuracy:0.9900000095367432 loss:5.680765151977539 (lr:0.002495548687875271)\n",
382
+ "460: accuracy:0.9599999785423279 loss:6.461648464202881 (lr:0.0024836009833961725)\n",
383
+ "470: accuracy:0.9900000095367432 loss:4.117606163024902 (lr:0.0024717124179005623)\n",
384
+ "480: accuracy:0.9800000190734863 loss:4.43185567855835 (lr:0.002459883689880371)\n",
385
+ "490: accuracy:0.9599999785423279 loss:11.783940315246582 (lr:0.0024481136351823807)\n",
386
+ "500: accuracy:1.0 loss:1.5767478942871094 (lr:0.0024364024866372347)\n",
387
+ "500: ******** test accuracy: 0.9782999753952026 test loss: 843.5567016601562\n",
388
+ "510: accuracy:0.9800000190734863 loss:3.927738666534424 (lr:0.002424749545753002)\n",
389
+ "520: accuracy:1.0 loss:1.8157836198806763 (lr:0.0024131545796990395)\n",
390
+ "530: accuracy:0.9700000286102295 loss:4.323674201965332 (lr:0.0024016178213059902)\n",
391
+ "540: accuracy:0.9700000286102295 loss:8.708285331726074 (lr:0.00239013833925128)\n",
392
+ "550: accuracy:1.0 loss:1.5547138452529907 (lr:0.002378716366365552)\n",
393
+ "560: accuracy:0.9900000095367432 loss:5.274013042449951 (lr:0.0023673512041568756)\n",
394
+ "570: accuracy:0.9900000095367432 loss:4.134397029876709 (lr:0.002356042852625251)\n",
395
+ "580: accuracy:0.9900000095367432 loss:4.712724685668945 (lr:0.0023447906132787466)\n",
396
+ "590: accuracy:1.0 loss:2.3797760009765625 (lr:0.0023335947189480066)\n",
397
+ "600: accuracy:0.9900000095367432 loss:2.464310884475708 (lr:0.0023224544711411)\n",
398
+ "600: ******** test accuracy: 0.9768999814987183 test loss: 969.0513916015625\n",
399
+ "610: accuracy:0.9800000190734863 loss:4.269423484802246 (lr:0.00231137010268867)\n",
400
+ "620: accuracy:0.9900000095367432 loss:1.96079683303833 (lr:0.0023003409150987864)\n",
401
+ "630: accuracy:0.9900000095367432 loss:3.370682954788208 (lr:0.002289366675540805)\n",
402
+ "640: accuracy:0.9900000095367432 loss:3.629244804382324 (lr:0.002278447151184082)\n",
403
+ "650: accuracy:0.9900000095367432 loss:2.6084139347076416 (lr:0.002267582109197974)\n",
404
+ "660: accuracy:0.9700000286102295 loss:10.114944458007812 (lr:0.0022567713167518377)\n",
405
+ "670: accuracy:0.9800000190734863 loss:7.03450870513916 (lr:0.0022460143081843853)\n",
406
+ "680: accuracy:0.9900000095367432 loss:2.3672995567321777 (lr:0.002235311083495617)\n",
407
+ "690: accuracy:0.9900000095367432 loss:2.9936716556549072 (lr:0.0022246609441936016)\n",
408
+ "700: accuracy:0.9900000095367432 loss:3.0124707221984863 (lr:0.0022140643559396267)\n",
409
+ "700: ******** test accuracy: 0.9732999801635742 test loss: 1081.31103515625\n",
410
+ "710: accuracy:0.9800000190734863 loss:4.998941421508789 (lr:0.0022035203874111176)\n",
411
+ "720: accuracy:0.9900000095367432 loss:3.2429025173187256 (lr:0.002193029038608074)\n",
412
+ "730: accuracy:0.9900000095367432 loss:5.626744270324707 (lr:0.0021825898438692093)\n",
413
+ "740: accuracy:0.9900000095367432 loss:1.948408603668213 (lr:0.0021722030360251665)\n",
414
+ "750: accuracy:0.9900000095367432 loss:6.581701278686523 (lr:0.0021618676837533712)\n",
415
+ "760: accuracy:0.9800000190734863 loss:6.1198039054870605 (lr:0.0021515842527151108)\n",
416
+ "770: accuracy:1.0 loss:2.308049201965332 (lr:0.002141352044418454)\n",
417
+ "780: accuracy:1.0 loss:1.1454635858535767 (lr:0.002131170593202114)\n",
418
+ "790: accuracy:0.9900000095367432 loss:3.5103325843811035 (lr:0.002121040364727378)\n",
419
+ "800: accuracy:1.0 loss:0.9047347903251648 (lr:0.0021109601948410273)\n",
420
+ "800: ******** test accuracy: 0.9776999950408936 test loss: 1062.073974609375\n",
421
+ "810: accuracy:0.9800000190734863 loss:3.2006778717041016 (lr:0.0021009305492043495)\n",
422
+ "820: accuracy:0.9700000286102295 loss:8.762024879455566 (lr:0.0020909507293254137)\n",
423
+ "830: accuracy:1.0 loss:1.0636988878250122 (lr:0.00208102073520422)\n",
424
+ "840: accuracy:0.9800000190734863 loss:3.2064943313598633 (lr:0.002071140566840768)\n",
425
+ "850: accuracy:1.0 loss:0.9395076632499695 (lr:0.002061309525743127)\n",
426
+ "860: accuracy:1.0 loss:3.0294976234436035 (lr:0.0020515271462500095)\n",
427
+ "870: accuracy:1.0 loss:1.2438673973083496 (lr:0.002041793894022703)\n",
428
+ "880: accuracy:0.9800000190734863 loss:5.025793075561523 (lr:0.0020321093033999205)\n",
429
+ "890: accuracy:1.0 loss:1.1006050109863281 (lr:0.002022472908720374)\n",
430
+ "900: accuracy:0.9700000286102295 loss:7.187697410583496 (lr:0.0020128844771534204)\n",
431
+ "900: ******** test accuracy: 0.9746999740600586 test loss: 1035.771240234375\n",
432
+ "910: accuracy:1.0 loss:1.6806761026382446 (lr:0.0020033440086990595)\n",
433
+ "920: accuracy:1.0 loss:1.3582149744033813 (lr:0.001993851037696004)\n"
434
+ ]
435
+ },
436
+ {
437
+ "name": "stdout",
438
+ "output_type": "stream",
439
+ "text": [
440
+ "930: accuracy:1.0 loss:0.5913202166557312 (lr:0.0019844050984829664)\n",
441
+ "940: accuracy:1.0 loss:1.3453789949417114 (lr:0.0019750066567212343)\n",
442
+ "950: accuracy:0.9900000095367432 loss:1.7747830152511597 (lr:0.0019656552467495203)\n",
443
+ "960: accuracy:1.0 loss:1.236012578010559 (lr:0.0019563501700758934)\n",
444
+ "970: accuracy:1.0 loss:1.853969693183899 (lr:0.0019470915431156754)\n",
445
+ "980: accuracy:1.0 loss:0.845487117767334 (lr:0.0019378792494535446)\n",
446
+ "990: accuracy:0.9900000095367432 loss:1.9294251203536987 (lr:0.00192871259059757)\n",
447
+ "1000: accuracy:1.0 loss:0.9650002121925354 (lr:0.001919591915793717)\n",
448
+ "1000: ******** test accuracy: 0.9814000129699707 test loss: 1020.56982421875\n",
449
+ "1010: accuracy:0.9599999785423279 loss:16.393333435058594 (lr:0.0019105166429653764)\n",
450
+ "1020: accuracy:0.9900000095367432 loss:3.5150673389434814 (lr:0.0019014865392819047)\n",
451
+ "1030: accuracy:1.0 loss:0.7717859745025635 (lr:0.0018925016047433019)\n",
452
+ "1040: accuracy:0.9700000286102295 loss:6.396111965179443 (lr:0.0018835614901036024)\n",
453
+ "1050: accuracy:0.9900000095367432 loss:6.6839680671691895 (lr:0.0018746658461168408)\n",
454
+ "1060: accuracy:1.0 loss:1.2511272430419922 (lr:0.001865814789198339)\n",
455
+ "1070: accuracy:0.9599999785423279 loss:6.292105674743652 (lr:0.0018570076208561659)\n",
456
+ "1080: accuracy:1.0 loss:2.3918440341949463 (lr:0.001848244690336287)\n",
457
+ "1090: accuracy:0.9800000190734863 loss:5.536012649536133 (lr:0.0018395251827314496)\n",
458
+ "1100: accuracy:0.9900000095367432 loss:2.6255433559417725 (lr:0.0018308493308722973)\n",
459
+ "1100: ******** test accuracy: 0.978600025177002 test loss: 959.9609375\n",
460
+ "1110: accuracy:0.9900000095367432 loss:2.0644290447235107 (lr:0.0018222166690975428)\n",
461
+ "1120: accuracy:0.9900000095367432 loss:2.614457368850708 (lr:0.001813627197407186)\n",
462
+ "1130: accuracy:1.0 loss:1.4448753595352173 (lr:0.0018050804501399398)\n",
463
+ "1140: accuracy:1.0 loss:3.2684671878814697 (lr:0.0017965761944651604)\n",
464
+ "1150: accuracy:1.0 loss:1.832627534866333 (lr:0.0017881145467981696)\n",
465
+ "1160: accuracy:0.9900000095367432 loss:1.4943240880966187 (lr:0.0017796949250623584)\n",
466
+ "1170: accuracy:1.0 loss:1.4607481956481934 (lr:0.0017713174456730485)\n",
467
+ "1180: accuracy:0.9800000190734863 loss:4.0511040687561035 (lr:0.0017629817593842745)\n",
468
+ "1190: accuracy:0.9900000095367432 loss:5.078634738922119 (lr:0.0017546876333653927)\n",
469
+ "1200: accuracy:0.9900000095367432 loss:2.203090190887451 (lr:0.0017464348347857594)\n",
470
+ "1200: ******** test accuracy: 0.9796000123023987 test loss: 1017.877685546875\n",
471
+ "1210: accuracy:0.9900000095367432 loss:3.175906181335449 (lr:0.0017382231308147311)\n",
472
+ "1220: accuracy:0.9800000190734863 loss:8.9501314163208 (lr:0.0017300526378676295)\n",
473
+ "1230: accuracy:1.0 loss:1.6223065853118896 (lr:0.0017219226574525237)\n",
474
+ "1240: accuracy:1.0 loss:1.0438904762268066 (lr:0.0017138333059847355)\n",
475
+ "1250: accuracy:0.9700000286102295 loss:6.996125221252441 (lr:0.0017057841178029776)\n",
476
+ "1260: accuracy:0.9900000095367432 loss:2.394702911376953 (lr:0.0016977754421532154)\n",
477
+ "1270: accuracy:1.0 loss:1.5201783180236816 (lr:0.0016898063477128744)\n",
478
+ "1280: accuracy:0.9900000095367432 loss:2.346707344055176 (lr:0.0016818773001432419)\n",
479
+ "1290: accuracy:0.9900000095367432 loss:2.6199910640716553 (lr:0.001673987484537065)\n",
480
+ "1300: accuracy:0.9900000095367432 loss:3.752833127975464 (lr:0.0016661373665556312)\n",
481
+ "1300: ******** test accuracy: 0.9790999889373779 test loss: 931.9585571289062\n",
482
+ "1310: accuracy:1.0 loss:1.3101643323898315 (lr:0.0016583260148763657)\n",
483
+ "1320: accuracy:1.0 loss:0.6096019148826599 (lr:0.0016505538951605558)\n",
484
+ "1330: accuracy:0.9900000095367432 loss:3.4166314601898193 (lr:0.0016428205417469144)\n",
485
+ "1340: accuracy:0.9900000095367432 loss:10.107922554016113 (lr:0.0016351256053894758)\n",
486
+ "1350: accuracy:0.9900000095367432 loss:2.611990213394165 (lr:0.001627469202503562)\n",
487
+ "1360: accuracy:1.0 loss:1.1654564142227173 (lr:0.0016198508674278855)\n",
488
+ "1370: accuracy:1.0 loss:2.004689931869507 (lr:0.0016122704837471247)\n",
489
+ "1380: accuracy:1.0 loss:0.6705881357192993 (lr:0.0016047279350459576)\n",
490
+ "1390: accuracy:1.0 loss:0.29386502504348755 (lr:0.0015972232213243842)\n",
491
+ "1400: accuracy:0.9900000095367432 loss:1.863643765449524 (lr:0.0015897557605057955)\n",
492
+ "1400: ******** test accuracy: 0.9807000160217285 test loss: 836.1314086914062\n",
493
+ "1410: accuracy:1.0 loss:0.421527624130249 (lr:0.0015823256690055132)\n",
494
+ "1420: accuracy:1.0 loss:1.3263750076293945 (lr:0.00157493248116225)\n",
495
+ "1430: accuracy:1.0 loss:0.2853686809539795 (lr:0.001567576196976006)\n",
496
+ "1440: accuracy:0.9900000095367432 loss:2.1042652130126953 (lr:0.0015602567000314593)\n",
497
+ "1450: accuracy:0.9800000190734863 loss:4.364713191986084 (lr:0.0015529736410826445)\n",
498
+ "1460: accuracy:1.0 loss:0.7905405163764954 (lr:0.0015457269037142396)\n",
499
+ "1470: accuracy:1.0 loss:1.9527945518493652 (lr:0.001538516255095601)\n",
500
+ "1480: accuracy:1.0 loss:0.7064335942268372 (lr:0.001531341695226729)\n",
501
+ "1490: accuracy:1.0 loss:0.493858277797699 (lr:0.001524202642031014)\n",
502
+ "1500: accuracy:1.0 loss:1.2484592199325562 (lr:0.0015170994447544217)\n",
503
+ "1500: ******** test accuracy: 0.9805999994277954 test loss: 1007.0768432617188\n",
504
+ "1510: accuracy:0.9900000095367432 loss:3.031127691268921 (lr:0.0015100316377356648)\n",
505
+ "1520: accuracy:1.0 loss:0.6348403692245483 (lr:0.0015029989881440997)\n",
506
+ "1530: accuracy:1.0 loss:0.6699919104576111 (lr:0.00149600172881037)\n",
507
+ "1540: accuracy:0.9900000095367432 loss:1.3584109544754028 (lr:0.0014890391612425447)\n",
508
+ "1550: accuracy:1.0 loss:1.6908869743347168 (lr:0.0014821112854406238)\n",
509
+ "1560: accuracy:0.9900000095367432 loss:2.1196155548095703 (lr:0.0014752179849892855)\n",
510
+ "1570: accuracy:0.9900000095367432 loss:1.5407307147979736 (lr:0.001468359143473208)\n",
511
+ "1580: accuracy:1.0 loss:0.14396381378173828 (lr:0.001461534295231104)\n",
512
+ "1590: accuracy:1.0 loss:1.1674902439117432 (lr:0.001454743673093617)\n",
513
+ "1600: accuracy:1.0 loss:1.0761510133743286 (lr:0.0014479868113994598)\n",
514
+ "1600: ******** test accuracy: 0.9836999773979187 test loss: 809.2349243164062\n",
515
+ "1610: accuracy:0.9900000095367432 loss:1.558372974395752 (lr:0.0014412635937333107)\n",
516
+ "1620: accuracy:1.0 loss:1.0245006084442139 (lr:0.0014345741365104914)\n",
517
+ "1630: accuracy:0.9900000095367432 loss:1.3378311395645142 (lr:0.0014279178576543927)\n",
518
+ "1640: accuracy:1.0 loss:0.7805245518684387 (lr:0.0014212947571650147)\n",
519
+ "1650: accuracy:1.0 loss:0.48196879029273987 (lr:0.0014147048350423574)\n",
520
+ "1660: accuracy:1.0 loss:0.9297369122505188 (lr:0.0014081477420404553)\n",
521
+ "1670: accuracy:0.9900000095367432 loss:1.3147106170654297 (lr:0.0014016233617439866)\n",
522
+ "1680: accuracy:0.9900000095367432 loss:1.4754124879837036 (lr:0.0013951313449069858)\n",
523
+ "1690: accuracy:0.9900000095367432 loss:2.6645889282226562 (lr:0.0013886719243600965)\n",
524
+ "1700: accuracy:1.0 loss:2.6606810092926025 (lr:0.0013822447508573532)\n",
525
+ "1700: ******** test accuracy: 0.9786999821662903 test loss: 955.8195190429688\n",
526
+ "1710: accuracy:0.9800000190734863 loss:8.284337997436523 (lr:0.0013758495915681124)\n",
527
+ "1720: accuracy:0.9900000095367432 loss:4.359131336212158 (lr:0.0013694862136617303)\n",
528
+ "1730: accuracy:1.0 loss:0.9269058704376221 (lr:0.0013631545007228851)\n",
529
+ "1740: accuracy:0.9900000095367432 loss:3.155236005783081 (lr:0.001356854452751577)\n",
530
+ "1750: accuracy:1.0 loss:0.8599931597709656 (lr:0.0013505859533324838)\n",
531
+ "1760: accuracy:1.0 loss:0.5180749297142029 (lr:0.0013443486532196403)\n",
532
+ "1770: accuracy:0.9800000190734863 loss:2.8685896396636963 (lr:0.0013381424359977245)\n",
533
+ "1780: accuracy:1.0 loss:1.209244728088379 (lr:0.0013319671852514148)\n",
534
+ "1790: accuracy:1.0 loss:0.3605708181858063 (lr:0.0013258226681500673)\n",
535
+ "1800: accuracy:1.0 loss:1.2704296112060547 (lr:0.0013197088846936822)\n",
536
+ "1800: ******** test accuracy: 0.982200026512146 test loss: 896.2050170898438\n",
537
+ "1810: accuracy:1.0 loss:1.8222142457962036 (lr:0.0013136254856362939)\n",
538
+ "1820: accuracy:1.0 loss:0.24535450339317322 (lr:0.0013075725873932242)\n",
539
+ "1830: accuracy:1.0 loss:0.6276077032089233 (lr:0.001301549724303186)\n",
540
+ "1840: accuracy:0.9900000095367432 loss:5.3721489906311035 (lr:0.0012955570127815008)\n",
541
+ "1850: accuracy:0.9800000190734863 loss:9.77694320678711 (lr:0.0012895942199975252)\n",
542
+ "1860: accuracy:1.0 loss:0.6020952463150024 (lr:0.0012836609967052937)\n",
543
+ "1870: accuracy:1.0 loss:1.3472830057144165 (lr:0.001277757459320128)\n",
544
+ "1880: accuracy:1.0 loss:0.9408558011054993 (lr:0.0012718833750113845)\n",
545
+ "1890: accuracy:1.0 loss:0.6806813478469849 (lr:0.0012660386273637414)\n",
546
+ "1900: accuracy:1.0 loss:0.7370650768280029 (lr:0.001260222983546555)\n",
547
+ "1900: ******** test accuracy: 0.9839000105857849 test loss: 833.6243896484375\n"
548
+ ]
549
+ },
550
+ {
551
+ "name": "stdout",
552
+ "output_type": "stream",
553
+ "text": [
554
+ "1910: accuracy:1.0 loss:0.2219347357749939 (lr:0.0012544363271445036)\n",
555
+ "1920: accuracy:0.9900000095367432 loss:5.134439468383789 (lr:0.0012486785417422652)\n",
556
+ "1930: accuracy:1.0 loss:0.9207942485809326 (lr:0.001242949510924518)\n",
557
+ "1940: accuracy:1.0 loss:0.47080865502357483 (lr:0.0012372490018606186)\n",
558
+ "1950: accuracy:1.0 loss:1.2955037355422974 (lr:0.0012315770145505667)\n",
559
+ "1960: accuracy:0.9900000095367432 loss:2.417428493499756 (lr:0.0012259331997483969)\n",
560
+ "1970: accuracy:1.0 loss:0.6553899049758911 (lr:0.0012203175574541092)\n",
561
+ "1980: accuracy:1.0 loss:0.49434810876846313 (lr:0.00121472985483706)\n",
562
+ "1990: accuracy:0.9900000095367432 loss:1.8686177730560303 (lr:0.0012091703247278929)\n",
563
+ "2000: accuracy:1.0 loss:0.46403393149375916 (lr:0.001203638268634677)\n",
564
+ "2000: ******** test accuracy: 0.9843000173568726 test loss: 821.3646240234375\n"
565
+ ]
566
+ },
567
+ {
568
+ "data": {
569
+ "text/plain": [
570
+ "0..2000"
571
+ ]
572
+ },
573
+ "execution_count": 9,
574
+ "metadata": {},
575
+ "output_type": "execute_result"
576
+ }
577
+ ],
578
+ "source": [
579
+ "(0..EPOCH).each do |i|\n",
580
+ " # load batch of images and correct answers\n",
581
+ " batch_x, batch_y = mnist_train.next_batch(100)\n",
582
+ " train_data = { x => batch_x, y_ => batch_y, step_ => i, pkeep => 0.75 }\n",
583
+ "\n",
584
+ " # train\n",
585
+ " sess.run(train_step, feed_dict: train_data)\n",
586
+ "\n",
587
+ " if (i % 10 == 0)\n",
588
+ " # result = TensorStream::ReportTool.profile_for(sess)\n",
589
+ " # File.write(\"profile.csv\", result.map(&:to_csv).join(\"\\n\"))\n",
590
+ " # success? add code to print it\n",
591
+ " a_train, c_train, l = sess.run([accuracy, cross_entropy, lr], feed_dict: { x => batch_x, y_ => batch_y, step_ => i, pkeep => 1.0})\n",
592
+ " puts \"#{i}: accuracy:#{a_train} loss:#{c_train} (lr:#{l})\"\n",
593
+ " end\n",
594
+ "\n",
595
+ " if (i % 100 == 0)\n",
596
+ " # success on test data?\n",
597
+ " a_test, c_test = sess.run([accuracy, cross_entropy], feed_dict: test_data, pkeep => 1.0)\n",
598
+ " puts(\"#{i}: ******** test accuracy: #{a_test} test loss: #{c_test}\")\n",
599
+ "\n",
600
+ " # save current state of the model\n",
601
+ " save_path = saver.save(sess, model_save_path)\n",
602
+ " end\n",
603
+ "end\n",
604
+ "\n"
605
+ ]
606
+ }
607
+ ],
608
+ "metadata": {
609
+ "kernelspec": {
610
+ "display_name": "Ruby 2.5.1",
611
+ "language": "ruby",
612
+ "name": "ruby"
613
+ },
614
+ "language_info": {
615
+ "file_extension": ".rb",
616
+ "mimetype": "application/x-ruby",
617
+ "name": "ruby",
618
+ "version": "2.5.1"
619
+ }
620
+ },
621
+ "nbformat": 4,
622
+ "nbformat_minor": 2
623
+ }
@@ -39,7 +39,7 @@ Gem::Specification.new do |spec|
39
39
  spec.add_development_dependency "awesome_print"
40
40
  spec.add_development_dependency "mnist-learn"
41
41
  spec.add_development_dependency "simplecov"
42
- spec.add_dependency "tensor_stream", "1.0.0"
42
+ spec.add_dependency "tensor_stream", "1.0.4"
43
43
  spec.add_dependency "opencl_ruby_ffi"
44
44
  spec.add_dependency "oily_png"
45
45
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: tensor_stream-opencl
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.8
4
+ version: 0.2.9
5
5
  platform: ruby
6
6
  authors:
7
7
  - Joseph Dayo
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-01-06 00:00:00.000000000 Z
11
+ date: 2019-03-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: bundler
@@ -114,14 +114,14 @@ dependencies:
114
114
  requirements:
115
115
  - - '='
116
116
  - !ruby/object:Gem::Version
117
- version: 1.0.0
117
+ version: 1.0.4
118
118
  type: :runtime
119
119
  prerelease: false
120
120
  version_requirements: !ruby/object:Gem::Requirement
121
121
  requirements:
122
122
  - - '='
123
123
  - !ruby/object:Gem::Version
124
- version: 1.0.0
124
+ version: 1.0.4
125
125
  - !ruby/object:Gem::Dependency
126
126
  name: opencl_ruby_ffi
127
127
  requirement: !ruby/object:Gem::Requirement
@@ -256,6 +256,7 @@ files:
256
256
  - samples/mnist_data_2.2.rb
257
257
  - samples/mnist_data_2.3.rb
258
258
  - samples/mnist_data_3.0.rb
259
+ - samples/mnist_image.ipynb
259
260
  - samples/multigpu.rb
260
261
  - samples/nearest_neighbor.rb
261
262
  - samples/rnn.rb