tensor_stream 0.2.0 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +5 -5
- data/.circleci/config.yml +2 -1
- data/CHANGELOG.md +5 -0
- data/README.md +28 -1
- data/benchmark/benchmark.rb +129 -0
- data/lib/tensor_stream.rb +7 -4
- data/lib/tensor_stream/evaluator/buffer.rb +10 -0
- data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
- data/lib/tensor_stream/evaluator/kernels/_bool_operand.cl +45 -0
- data/lib/tensor_stream/evaluator/kernels/_operand.cl +45 -0
- data/lib/tensor_stream/evaluator/kernels/abs.cl +16 -0
- data/lib/tensor_stream/evaluator/kernels/add.cl +5 -0
- data/lib/tensor_stream/evaluator/kernels/argmax.cl +15 -0
- data/lib/tensor_stream/evaluator/kernels/argmin.cl +15 -0
- data/lib/tensor_stream/evaluator/kernels/cast.cl +15 -0
- data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +5 -0
- data/lib/tensor_stream/evaluator/kernels/cos.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/div.cl.erb +5 -0
- data/lib/tensor_stream/evaluator/kernels/exp.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/gemm.cl +63 -0
- data/lib/tensor_stream/evaluator/kernels/log.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/log1p.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/max.cl +91 -0
- data/lib/tensor_stream/evaluator/kernels/mul.cl +5 -0
- data/lib/tensor_stream/evaluator/kernels/negate.cl +15 -0
- data/lib/tensor_stream/evaluator/kernels/pow.cl +130 -0
- data/lib/tensor_stream/evaluator/kernels/reciprocal.cl +15 -0
- data/lib/tensor_stream/evaluator/kernels/round.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/sigmoid.cl +8 -0
- data/lib/tensor_stream/evaluator/kernels/sigmoid_grad.cl +54 -0
- data/lib/tensor_stream/evaluator/kernels/sign.cl +23 -0
- data/lib/tensor_stream/evaluator/kernels/sin.cl +8 -0
- data/lib/tensor_stream/evaluator/kernels/sqrt.cl +8 -0
- data/lib/tensor_stream/evaluator/kernels/square.cl +15 -0
- data/lib/tensor_stream/evaluator/kernels/sub.cl +5 -0
- data/lib/tensor_stream/evaluator/kernels/tan.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/tanh.cl +7 -0
- data/lib/tensor_stream/evaluator/kernels/tanh_grad.cl +6 -0
- data/lib/tensor_stream/evaluator/kernels/where.cl +15 -0
- data/lib/tensor_stream/evaluator/opencl_buffer.rb +30 -0
- data/lib/tensor_stream/evaluator/opencl_evaluator.rb +1095 -0
- data/lib/tensor_stream/evaluator/opencl_template_helper.rb +58 -0
- data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +27 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +20 -31
- data/lib/tensor_stream/graph.rb +4 -2
- data/lib/tensor_stream/math_gradients.rb +3 -0
- data/lib/tensor_stream/operation.rb +29 -2
- data/lib/tensor_stream/ops.rb +14 -2
- data/lib/tensor_stream/placeholder.rb +1 -1
- data/lib/tensor_stream/session.rb +10 -3
- data/lib/tensor_stream/tensor_shape.rb +1 -1
- data/lib/tensor_stream/train/saver.rb +1 -1
- data/lib/tensor_stream/variable.rb +7 -1
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/logistic_regression.rb +2 -1
- data/samples/nearest_neighbor.rb +54 -0
- data/tensor_stream.gemspec +3 -1
- metadata +107 -28
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
require "bundler/setup"
|
4
4
|
require 'tensor_stream'
|
5
|
+
require 'pry-byebug'
|
5
6
|
|
6
7
|
tf = TensorStream
|
7
8
|
|
@@ -61,7 +62,7 @@ optimizer = TensorStream::Train::GradientDescentOptimizer.new(learning_rate)
|
|
61
62
|
goal = optimizer.minimize(loss)
|
62
63
|
prediction = tf.round(tf.sigmoid(mod))
|
63
64
|
# Bool into float32 type
|
64
|
-
correct = tf.cast(tf.equal(prediction, target),
|
65
|
+
correct = tf.cast(tf.equal(prediction, target), :float32)
|
65
66
|
# Average
|
66
67
|
accuracy = tf.reduce_mean(correct)
|
67
68
|
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# A nearest neighbor learning algorithm example using TensorStream.
# This example uses the MNIST database of handwritten digits
# (http://yann.lecun.com/exdb/mnist/)
#
# Based on the TensorFlow example by:
# Author: Aymeric Damien
# Project: https://github.com/aymericdamien/TensorFlow-Examples/
require "bundler/setup"
require 'tensor_stream'
require 'mnist-learn'
require 'tensor_stream/evaluator/opencl_evaluator'

tf = TensorStream

# Returns the class index encoded by a one-hot label vector, i.e. the
# position of its largest entry (the Ruby equivalent of numpy's argmax).
# Comparing `label.max` values instead would always yield 1 == 1 for
# one-hot vectors and make every prediction look correct.
def one_hot_class(label)
  label.index(label.max)
end

# Import MNIST data
mnist = Mnist.read_data_sets('/tmp/data', one_hot: true)

# In this example, we limit mnist data
Xtr, Ytr = mnist.train.next_batch(5000) # 5000 for training (nn candidates)
Xte, Yte = mnist.test.next_batch(200) # 200 for testing

# tf Graph Input
xtr = tf.placeholder(:float, shape: [nil, 784])
xte = tf.placeholder(:float, shape: [784])

# Nearest Neighbor calculation using L1 Distance
# Calculate L1 Distance
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), 1)
# Prediction: Get min distance index (Nearest neighbor)
pred = tf.argmin(distance, 0)

accuracy = 0.0

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
tf.session(:opencl_evaluator) do |sess|
  # Run the initializer
  sess.run(init)
  Xte.size.times do |i|
    # Get nearest neighbor (assumes pred evaluates to an Integer row index
    # into Xtr/Ytr, as the original indexing Ytr[nn_index] already did)
    nn_index = sess.run(pred, feed_dict: { xtr => Xtr, xte => Xte[i] })
    predicted = one_hot_class(Ytr[nn_index])
    actual = one_hot_class(Yte[i])
    # puts rather than print: Kernel#print joins its arguments with no
    # separator and emits no trailing newline.
    puts "Test #{i} Prediction: #{predicted} True Class: #{actual}"
    accuracy += 1.0 / Xte.size if predicted == actual
  end

  puts "Done!"
  puts "Accuracy: #{accuracy}"
end
|
data/tensor_stream.gemspec
CHANGED
@@ -35,10 +35,12 @@ Gem::Specification.new do |spec|
|
|
35
35
|
spec.add_development_dependency "rspec", "~> 3.0"
|
36
36
|
spec.add_development_dependency "awesome_print"
|
37
37
|
spec.add_development_dependency "rubocop"
|
38
|
-
|
38
|
+
spec.add_development_dependency "pry-byebug"
|
39
39
|
spec.add_development_dependency "byepry"
|
40
40
|
spec.add_development_dependency "colorize"
|
41
41
|
spec.add_development_dependency "rspec_junit_formatter"
|
42
|
+
spec.add_development_dependency "mnist-learn"
|
43
|
+
spec.add_development_dependency "opencl_ruby_ffi"
|
42
44
|
spec.add_dependency "deep_merge"
|
43
45
|
spec.add_dependency "concurrent-ruby"
|
44
46
|
spec.add_dependency "sciruby"
|
metadata
CHANGED
@@ -1,178 +1,220 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: tensor_stream
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Joseph Emmanuel Dayo
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-06-04 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
|
+
name: bundler
|
14
15
|
requirement: !ruby/object:Gem::Requirement
|
15
16
|
requirements:
|
16
17
|
- - "~>"
|
17
18
|
- !ruby/object:Gem::Version
|
18
19
|
version: '1.14'
|
19
|
-
name: bundler
|
20
|
-
prerelease: false
|
21
20
|
type: :development
|
21
|
+
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - "~>"
|
25
25
|
- !ruby/object:Gem::Version
|
26
26
|
version: '1.14'
|
27
27
|
- !ruby/object:Gem::Dependency
|
28
|
+
name: rake
|
28
29
|
requirement: !ruby/object:Gem::Requirement
|
29
30
|
requirements:
|
30
31
|
- - "~>"
|
31
32
|
- !ruby/object:Gem::Version
|
32
33
|
version: '10.0'
|
33
|
-
name: rake
|
34
|
-
prerelease: false
|
35
34
|
type: :development
|
35
|
+
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - "~>"
|
39
39
|
- !ruby/object:Gem::Version
|
40
40
|
version: '10.0'
|
41
41
|
- !ruby/object:Gem::Dependency
|
42
|
+
name: rspec
|
42
43
|
requirement: !ruby/object:Gem::Requirement
|
43
44
|
requirements:
|
44
45
|
- - "~>"
|
45
46
|
- !ruby/object:Gem::Version
|
46
47
|
version: '3.0'
|
47
|
-
name: rspec
|
48
|
-
prerelease: false
|
49
48
|
type: :development
|
49
|
+
prerelease: false
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
51
51
|
requirements:
|
52
52
|
- - "~>"
|
53
53
|
- !ruby/object:Gem::Version
|
54
54
|
version: '3.0'
|
55
55
|
- !ruby/object:Gem::Dependency
|
56
|
+
name: awesome_print
|
56
57
|
requirement: !ruby/object:Gem::Requirement
|
57
58
|
requirements:
|
58
59
|
- - ">="
|
59
60
|
- !ruby/object:Gem::Version
|
60
61
|
version: '0'
|
61
|
-
name: awesome_print
|
62
|
-
prerelease: false
|
63
62
|
type: :development
|
63
|
+
prerelease: false
|
64
64
|
version_requirements: !ruby/object:Gem::Requirement
|
65
65
|
requirements:
|
66
66
|
- - ">="
|
67
67
|
- !ruby/object:Gem::Version
|
68
68
|
version: '0'
|
69
69
|
- !ruby/object:Gem::Dependency
|
70
|
+
name: rubocop
|
70
71
|
requirement: !ruby/object:Gem::Requirement
|
71
72
|
requirements:
|
72
73
|
- - ">="
|
73
74
|
- !ruby/object:Gem::Version
|
74
75
|
version: '0'
|
75
|
-
name: rubocop
|
76
|
-
prerelease: false
|
77
76
|
type: :development
|
77
|
+
prerelease: false
|
78
78
|
version_requirements: !ruby/object:Gem::Requirement
|
79
79
|
requirements:
|
80
80
|
- - ">="
|
81
81
|
- !ruby/object:Gem::Version
|
82
82
|
version: '0'
|
83
83
|
- !ruby/object:Gem::Dependency
|
84
|
+
name: pry-byebug
|
84
85
|
requirement: !ruby/object:Gem::Requirement
|
85
86
|
requirements:
|
86
87
|
- - ">="
|
87
88
|
- !ruby/object:Gem::Version
|
88
89
|
version: '0'
|
89
|
-
name: byepry
|
90
|
-
prerelease: false
|
91
90
|
type: :development
|
91
|
+
prerelease: false
|
92
92
|
version_requirements: !ruby/object:Gem::Requirement
|
93
93
|
requirements:
|
94
94
|
- - ">="
|
95
95
|
- !ruby/object:Gem::Version
|
96
96
|
version: '0'
|
97
97
|
- !ruby/object:Gem::Dependency
|
98
|
+
name: byepry
|
98
99
|
requirement: !ruby/object:Gem::Requirement
|
99
100
|
requirements:
|
100
101
|
- - ">="
|
101
102
|
- !ruby/object:Gem::Version
|
102
103
|
version: '0'
|
103
|
-
name: colorize
|
104
|
-
prerelease: false
|
105
104
|
type: :development
|
105
|
+
prerelease: false
|
106
106
|
version_requirements: !ruby/object:Gem::Requirement
|
107
107
|
requirements:
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
110
|
version: '0'
|
111
111
|
- !ruby/object:Gem::Dependency
|
112
|
+
name: colorize
|
112
113
|
requirement: !ruby/object:Gem::Requirement
|
113
114
|
requirements:
|
114
115
|
- - ">="
|
115
116
|
- !ruby/object:Gem::Version
|
116
117
|
version: '0'
|
117
|
-
|
118
|
+
type: :development
|
118
119
|
prerelease: false
|
120
|
+
version_requirements: !ruby/object:Gem::Requirement
|
121
|
+
requirements:
|
122
|
+
- - ">="
|
123
|
+
- !ruby/object:Gem::Version
|
124
|
+
version: '0'
|
125
|
+
- !ruby/object:Gem::Dependency
|
126
|
+
name: rspec_junit_formatter
|
127
|
+
requirement: !ruby/object:Gem::Requirement
|
128
|
+
requirements:
|
129
|
+
- - ">="
|
130
|
+
- !ruby/object:Gem::Version
|
131
|
+
version: '0'
|
119
132
|
type: :development
|
133
|
+
prerelease: false
|
120
134
|
version_requirements: !ruby/object:Gem::Requirement
|
121
135
|
requirements:
|
122
136
|
- - ">="
|
123
137
|
- !ruby/object:Gem::Version
|
124
138
|
version: '0'
|
125
139
|
- !ruby/object:Gem::Dependency
|
140
|
+
name: mnist-learn
|
126
141
|
requirement: !ruby/object:Gem::Requirement
|
127
142
|
requirements:
|
128
143
|
- - ">="
|
129
144
|
- !ruby/object:Gem::Version
|
130
145
|
version: '0'
|
131
|
-
|
146
|
+
type: :development
|
132
147
|
prerelease: false
|
133
|
-
type: :runtime
|
134
148
|
version_requirements: !ruby/object:Gem::Requirement
|
135
149
|
requirements:
|
136
150
|
- - ">="
|
137
151
|
- !ruby/object:Gem::Version
|
138
152
|
version: '0'
|
139
153
|
- !ruby/object:Gem::Dependency
|
154
|
+
name: opencl_ruby_ffi
|
140
155
|
requirement: !ruby/object:Gem::Requirement
|
141
156
|
requirements:
|
142
157
|
- - ">="
|
143
158
|
- !ruby/object:Gem::Version
|
144
159
|
version: '0'
|
145
|
-
|
160
|
+
type: :development
|
146
161
|
prerelease: false
|
147
|
-
type: :runtime
|
148
162
|
version_requirements: !ruby/object:Gem::Requirement
|
149
163
|
requirements:
|
150
164
|
- - ">="
|
151
165
|
- !ruby/object:Gem::Version
|
152
166
|
version: '0'
|
153
167
|
- !ruby/object:Gem::Dependency
|
168
|
+
name: deep_merge
|
154
169
|
requirement: !ruby/object:Gem::Requirement
|
155
170
|
requirements:
|
156
171
|
- - ">="
|
157
172
|
- !ruby/object:Gem::Version
|
158
173
|
version: '0'
|
159
|
-
|
174
|
+
type: :runtime
|
160
175
|
prerelease: false
|
176
|
+
version_requirements: !ruby/object:Gem::Requirement
|
177
|
+
requirements:
|
178
|
+
- - ">="
|
179
|
+
- !ruby/object:Gem::Version
|
180
|
+
version: '0'
|
181
|
+
- !ruby/object:Gem::Dependency
|
182
|
+
name: concurrent-ruby
|
183
|
+
requirement: !ruby/object:Gem::Requirement
|
184
|
+
requirements:
|
185
|
+
- - ">="
|
186
|
+
- !ruby/object:Gem::Version
|
187
|
+
version: '0'
|
161
188
|
type: :runtime
|
189
|
+
prerelease: false
|
162
190
|
version_requirements: !ruby/object:Gem::Requirement
|
163
191
|
requirements:
|
164
192
|
- - ">="
|
165
193
|
- !ruby/object:Gem::Version
|
166
194
|
version: '0'
|
167
195
|
- !ruby/object:Gem::Dependency
|
196
|
+
name: sciruby
|
168
197
|
requirement: !ruby/object:Gem::Requirement
|
169
198
|
requirements:
|
170
199
|
- - ">="
|
171
200
|
- !ruby/object:Gem::Version
|
172
201
|
version: '0'
|
173
|
-
|
202
|
+
type: :runtime
|
174
203
|
prerelease: false
|
204
|
+
version_requirements: !ruby/object:Gem::Requirement
|
205
|
+
requirements:
|
206
|
+
- - ">="
|
207
|
+
- !ruby/object:Gem::Version
|
208
|
+
version: '0'
|
209
|
+
- !ruby/object:Gem::Dependency
|
210
|
+
name: distribution
|
211
|
+
requirement: !ruby/object:Gem::Requirement
|
212
|
+
requirements:
|
213
|
+
- - ">="
|
214
|
+
- !ruby/object:Gem::Version
|
215
|
+
version: '0'
|
175
216
|
type: :runtime
|
217
|
+
prerelease: false
|
176
218
|
version_requirements: !ruby/object:Gem::Requirement
|
177
219
|
requirements:
|
178
220
|
- - ">="
|
@@ -200,12 +242,48 @@ files:
|
|
200
242
|
- LICENSE.txt
|
201
243
|
- README.md
|
202
244
|
- Rakefile
|
245
|
+
- benchmark/benchmark.rb
|
203
246
|
- bin/console
|
204
247
|
- bin/setup
|
205
248
|
- lib/tensor_stream.rb
|
206
249
|
- lib/tensor_stream/control_flow.rb
|
207
250
|
- lib/tensor_stream/device.rb
|
251
|
+
- lib/tensor_stream/evaluator/buffer.rb
|
208
252
|
- lib/tensor_stream/evaluator/evaluator.rb
|
253
|
+
- lib/tensor_stream/evaluator/kernels/_bool_operand.cl
|
254
|
+
- lib/tensor_stream/evaluator/kernels/_operand.cl
|
255
|
+
- lib/tensor_stream/evaluator/kernels/abs.cl
|
256
|
+
- lib/tensor_stream/evaluator/kernels/add.cl
|
257
|
+
- lib/tensor_stream/evaluator/kernels/argmax.cl
|
258
|
+
- lib/tensor_stream/evaluator/kernels/argmin.cl
|
259
|
+
- lib/tensor_stream/evaluator/kernels/cast.cl
|
260
|
+
- lib/tensor_stream/evaluator/kernels/cond.cl.erb
|
261
|
+
- lib/tensor_stream/evaluator/kernels/cos.cl
|
262
|
+
- lib/tensor_stream/evaluator/kernels/div.cl.erb
|
263
|
+
- lib/tensor_stream/evaluator/kernels/exp.cl
|
264
|
+
- lib/tensor_stream/evaluator/kernels/gemm.cl
|
265
|
+
- lib/tensor_stream/evaluator/kernels/log.cl
|
266
|
+
- lib/tensor_stream/evaluator/kernels/log1p.cl
|
267
|
+
- lib/tensor_stream/evaluator/kernels/max.cl
|
268
|
+
- lib/tensor_stream/evaluator/kernels/mul.cl
|
269
|
+
- lib/tensor_stream/evaluator/kernels/negate.cl
|
270
|
+
- lib/tensor_stream/evaluator/kernels/pow.cl
|
271
|
+
- lib/tensor_stream/evaluator/kernels/reciprocal.cl
|
272
|
+
- lib/tensor_stream/evaluator/kernels/round.cl
|
273
|
+
- lib/tensor_stream/evaluator/kernels/sigmoid.cl
|
274
|
+
- lib/tensor_stream/evaluator/kernels/sigmoid_grad.cl
|
275
|
+
- lib/tensor_stream/evaluator/kernels/sign.cl
|
276
|
+
- lib/tensor_stream/evaluator/kernels/sin.cl
|
277
|
+
- lib/tensor_stream/evaluator/kernels/sqrt.cl
|
278
|
+
- lib/tensor_stream/evaluator/kernels/square.cl
|
279
|
+
- lib/tensor_stream/evaluator/kernels/sub.cl
|
280
|
+
- lib/tensor_stream/evaluator/kernels/tan.cl
|
281
|
+
- lib/tensor_stream/evaluator/kernels/tanh.cl
|
282
|
+
- lib/tensor_stream/evaluator/kernels/tanh_grad.cl
|
283
|
+
- lib/tensor_stream/evaluator/kernels/where.cl
|
284
|
+
- lib/tensor_stream/evaluator/opencl_buffer.rb
|
285
|
+
- lib/tensor_stream/evaluator/opencl_evaluator.rb
|
286
|
+
- lib/tensor_stream/evaluator/opencl_template_helper.rb
|
209
287
|
- lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb
|
210
288
|
- lib/tensor_stream/evaluator/operation_helpers/math_helper.rb
|
211
289
|
- lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb
|
@@ -237,6 +315,7 @@ files:
|
|
237
315
|
- samples/iris.data
|
238
316
|
- samples/linear_regression.rb
|
239
317
|
- samples/logistic_regression.rb
|
318
|
+
- samples/nearest_neighbor.rb
|
240
319
|
- tensor_stream.gemspec
|
241
320
|
- test_samples/error.graphml
|
242
321
|
- test_samples/gradient_sample.graphml
|
@@ -249,7 +328,7 @@ licenses:
|
|
249
328
|
- MIT
|
250
329
|
metadata:
|
251
330
|
allowed_push_host: https://rubygems.org
|
252
|
-
post_install_message:
|
331
|
+
post_install_message:
|
253
332
|
rdoc_options: []
|
254
333
|
require_paths:
|
255
334
|
- lib
|
@@ -264,9 +343,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
264
343
|
- !ruby/object:Gem::Version
|
265
344
|
version: '0'
|
266
345
|
requirements: []
|
267
|
-
rubyforge_project:
|
268
|
-
rubygems_version: 2.6.
|
269
|
-
signing_key:
|
346
|
+
rubyforge_project:
|
347
|
+
rubygems_version: 2.6.11
|
348
|
+
signing_key:
|
270
349
|
specification_version: 4
|
271
350
|
summary: A Pure ruby tensorflow implementation
|
272
351
|
test_files: []
|