tensor_stream 0.2.0 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (58) hide show
  1. checksums.yaml +5 -5
  2. data/.circleci/config.yml +2 -1
  3. data/CHANGELOG.md +5 -0
  4. data/README.md +28 -1
  5. data/benchmark/benchmark.rb +129 -0
  6. data/lib/tensor_stream.rb +7 -4
  7. data/lib/tensor_stream/evaluator/buffer.rb +10 -0
  8. data/lib/tensor_stream/evaluator/evaluator.rb +1 -0
  9. data/lib/tensor_stream/evaluator/kernels/_bool_operand.cl +45 -0
  10. data/lib/tensor_stream/evaluator/kernels/_operand.cl +45 -0
  11. data/lib/tensor_stream/evaluator/kernels/abs.cl +16 -0
  12. data/lib/tensor_stream/evaluator/kernels/add.cl +5 -0
  13. data/lib/tensor_stream/evaluator/kernels/argmax.cl +15 -0
  14. data/lib/tensor_stream/evaluator/kernels/argmin.cl +15 -0
  15. data/lib/tensor_stream/evaluator/kernels/cast.cl +15 -0
  16. data/lib/tensor_stream/evaluator/kernels/cond.cl.erb +5 -0
  17. data/lib/tensor_stream/evaluator/kernels/cos.cl +7 -0
  18. data/lib/tensor_stream/evaluator/kernels/div.cl.erb +5 -0
  19. data/lib/tensor_stream/evaluator/kernels/exp.cl +7 -0
  20. data/lib/tensor_stream/evaluator/kernels/gemm.cl +63 -0
  21. data/lib/tensor_stream/evaluator/kernels/log.cl +7 -0
  22. data/lib/tensor_stream/evaluator/kernels/log1p.cl +7 -0
  23. data/lib/tensor_stream/evaluator/kernels/max.cl +91 -0
  24. data/lib/tensor_stream/evaluator/kernels/mul.cl +5 -0
  25. data/lib/tensor_stream/evaluator/kernels/negate.cl +15 -0
  26. data/lib/tensor_stream/evaluator/kernels/pow.cl +130 -0
  27. data/lib/tensor_stream/evaluator/kernels/reciprocal.cl +15 -0
  28. data/lib/tensor_stream/evaluator/kernels/round.cl +7 -0
  29. data/lib/tensor_stream/evaluator/kernels/sigmoid.cl +8 -0
  30. data/lib/tensor_stream/evaluator/kernels/sigmoid_grad.cl +54 -0
  31. data/lib/tensor_stream/evaluator/kernels/sign.cl +23 -0
  32. data/lib/tensor_stream/evaluator/kernels/sin.cl +8 -0
  33. data/lib/tensor_stream/evaluator/kernels/sqrt.cl +8 -0
  34. data/lib/tensor_stream/evaluator/kernels/square.cl +15 -0
  35. data/lib/tensor_stream/evaluator/kernels/sub.cl +5 -0
  36. data/lib/tensor_stream/evaluator/kernels/tan.cl +7 -0
  37. data/lib/tensor_stream/evaluator/kernels/tanh.cl +7 -0
  38. data/lib/tensor_stream/evaluator/kernels/tanh_grad.cl +6 -0
  39. data/lib/tensor_stream/evaluator/kernels/where.cl +15 -0
  40. data/lib/tensor_stream/evaluator/opencl_buffer.rb +30 -0
  41. data/lib/tensor_stream/evaluator/opencl_evaluator.rb +1095 -0
  42. data/lib/tensor_stream/evaluator/opencl_template_helper.rb +58 -0
  43. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +27 -0
  44. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +20 -31
  45. data/lib/tensor_stream/graph.rb +4 -2
  46. data/lib/tensor_stream/math_gradients.rb +3 -0
  47. data/lib/tensor_stream/operation.rb +29 -2
  48. data/lib/tensor_stream/ops.rb +14 -2
  49. data/lib/tensor_stream/placeholder.rb +1 -1
  50. data/lib/tensor_stream/session.rb +10 -3
  51. data/lib/tensor_stream/tensor_shape.rb +1 -1
  52. data/lib/tensor_stream/train/saver.rb +1 -1
  53. data/lib/tensor_stream/variable.rb +7 -1
  54. data/lib/tensor_stream/version.rb +1 -1
  55. data/samples/logistic_regression.rb +2 -1
  56. data/samples/nearest_neighbor.rb +54 -0
  57. data/tensor_stream.gemspec +3 -1
  58. metadata +107 -28
@@ -1,5 +1,5 @@
1
1
  module TensorStream
2
- VERSION = '0.2.0'.freeze
2
+ VERSION = '0.3.0'.freeze
3
3
 
4
4
  def self.version
5
5
  VERSION
@@ -2,6 +2,7 @@
2
2
 
3
3
  require "bundler/setup"
4
4
  require 'tensor_stream'
5
+ require 'pry-byebug'
5
6
 
6
7
  tf = TensorStream
7
8
 
@@ -61,7 +62,7 @@ optimizer = TensorStream::Train::GradientDescentOptimizer.new(learning_rate)
61
62
  goal = optimizer.minimize(loss)
62
63
  prediction = tf.round(tf.sigmoid(mod))
63
64
  # Bool into float32 type
64
- correct = tf.cast(tf.equal(prediction, target), dtype: :float32)
65
+ correct = tf.cast(tf.equal(prediction, target), :float32)
65
66
  # Average
66
67
  accuracy = tf.reduce_mean(correct)
67
68
 
@@ -0,0 +1,54 @@
1
+ '''
2
+ A nearest neighbor learning algorithm example using TensorFlow library.
3
+ This example is using the MNIST database of handwritten digits
4
+ (http://yann.lecun.com/exdb/mnist/)
5
+
6
+ Author: Aymeric Damien
7
+ Project: https://github.com/aymericdamien/TensorFlow-Examples/
8
+ '''
9
+ require "bundler/setup"
10
+ require 'tensor_stream'
11
+ require 'mnist-learn'
12
+ require 'tensor_stream/evaluator/opencl_evaluator'
13
+
14
+ tf = TensorStream
15
+
16
+ # Import MNIST data
17
+ mnist = Mnist.read_data_sets('/tmp/data', one_hot: true)
18
+
19
+ # In this example, we limit mnist data
20
+ Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)
21
+ Xte, Yte = mnist.test.next_batch(200) #200 for testing
22
+
23
+ # tf Graph Input
24
+ xtr = tf.placeholder(:float, shape: [nil, 784])
25
+ xte = tf.placeholder(:float, shape: [784])
26
+
27
+ # Nearest Neighbor calculation using L1 Distance
28
+ # Calculate L1 Distance
29
+ distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), 1)
30
+ # Prediction: Get min distance index (Nearest neighbor)
31
+ pred = tf.argmin(distance, 0)
32
+
33
+ accuracy = 0.0
34
+
35
+ # Initialize the variables (i.e. assign their default value)
36
+ init = tf.global_variables_initializer()
37
+
38
+ # Start training
39
+ tf.session(:opencl_evaluator) do |sess|
40
+ # Run the initializer
41
+ sess.run(init)
42
+ Xte.size.times do |i|
43
+ # Get nearest neighbor
44
+ nn_index = sess.run(pred, feed_dict: {xtr => Xtr, xte => Xte[i]})
45
+ print("Test", i, "Prediction:",Ytr[nn_index].max, \
46
+ "True Class:", Yte[i].max)
47
+ if Ytr[nn_index].max == Yte[i].max
48
+ accuracy += 1.0/ Xte.size
49
+ end
50
+ end
51
+
52
+ print("Done!")
53
+ print("Accuracy:", accuracy)
54
+ end
@@ -35,10 +35,12 @@ Gem::Specification.new do |spec|
35
35
  spec.add_development_dependency "rspec", "~> 3.0"
36
36
  spec.add_development_dependency "awesome_print"
37
37
  spec.add_development_dependency "rubocop"
38
- # spec.add_development_dependency "pry-byebug"
38
+ spec.add_development_dependency "pry-byebug"
39
39
  spec.add_development_dependency "byepry"
40
40
  spec.add_development_dependency "colorize"
41
41
  spec.add_development_dependency "rspec_junit_formatter"
42
+ spec.add_development_dependency "mnist-learn"
43
+ spec.add_development_dependency "opencl_ruby_ffi"
42
44
  spec.add_dependency "deep_merge"
43
45
  spec.add_dependency "concurrent-ruby"
44
46
  spec.add_dependency "sciruby"
metadata CHANGED
@@ -1,178 +1,220 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: tensor_stream
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Joseph Emmanuel Dayo
8
- autorequire:
8
+ autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-05-26 00:00:00.000000000 Z
11
+ date: 2018-06-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
+ name: bundler
14
15
  requirement: !ruby/object:Gem::Requirement
15
16
  requirements:
16
17
  - - "~>"
17
18
  - !ruby/object:Gem::Version
18
19
  version: '1.14'
19
- name: bundler
20
- prerelease: false
21
20
  type: :development
21
+ prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - "~>"
25
25
  - !ruby/object:Gem::Version
26
26
  version: '1.14'
27
27
  - !ruby/object:Gem::Dependency
28
+ name: rake
28
29
  requirement: !ruby/object:Gem::Requirement
29
30
  requirements:
30
31
  - - "~>"
31
32
  - !ruby/object:Gem::Version
32
33
  version: '10.0'
33
- name: rake
34
- prerelease: false
35
34
  type: :development
35
+ prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
40
  version: '10.0'
41
41
  - !ruby/object:Gem::Dependency
42
+ name: rspec
42
43
  requirement: !ruby/object:Gem::Requirement
43
44
  requirements:
44
45
  - - "~>"
45
46
  - !ruby/object:Gem::Version
46
47
  version: '3.0'
47
- name: rspec
48
- prerelease: false
49
48
  type: :development
49
+ prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
52
  - - "~>"
53
53
  - !ruby/object:Gem::Version
54
54
  version: '3.0'
55
55
  - !ruby/object:Gem::Dependency
56
+ name: awesome_print
56
57
  requirement: !ruby/object:Gem::Requirement
57
58
  requirements:
58
59
  - - ">="
59
60
  - !ruby/object:Gem::Version
60
61
  version: '0'
61
- name: awesome_print
62
- prerelease: false
63
62
  type: :development
63
+ prerelease: false
64
64
  version_requirements: !ruby/object:Gem::Requirement
65
65
  requirements:
66
66
  - - ">="
67
67
  - !ruby/object:Gem::Version
68
68
  version: '0'
69
69
  - !ruby/object:Gem::Dependency
70
+ name: rubocop
70
71
  requirement: !ruby/object:Gem::Requirement
71
72
  requirements:
72
73
  - - ">="
73
74
  - !ruby/object:Gem::Version
74
75
  version: '0'
75
- name: rubocop
76
- prerelease: false
77
76
  type: :development
77
+ prerelease: false
78
78
  version_requirements: !ruby/object:Gem::Requirement
79
79
  requirements:
80
80
  - - ">="
81
81
  - !ruby/object:Gem::Version
82
82
  version: '0'
83
83
  - !ruby/object:Gem::Dependency
84
+ name: pry-byebug
84
85
  requirement: !ruby/object:Gem::Requirement
85
86
  requirements:
86
87
  - - ">="
87
88
  - !ruby/object:Gem::Version
88
89
  version: '0'
89
- name: byepry
90
- prerelease: false
91
90
  type: :development
91
+ prerelease: false
92
92
  version_requirements: !ruby/object:Gem::Requirement
93
93
  requirements:
94
94
  - - ">="
95
95
  - !ruby/object:Gem::Version
96
96
  version: '0'
97
97
  - !ruby/object:Gem::Dependency
98
+ name: byepry
98
99
  requirement: !ruby/object:Gem::Requirement
99
100
  requirements:
100
101
  - - ">="
101
102
  - !ruby/object:Gem::Version
102
103
  version: '0'
103
- name: colorize
104
- prerelease: false
105
104
  type: :development
105
+ prerelease: false
106
106
  version_requirements: !ruby/object:Gem::Requirement
107
107
  requirements:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
110
  version: '0'
111
111
  - !ruby/object:Gem::Dependency
112
+ name: colorize
112
113
  requirement: !ruby/object:Gem::Requirement
113
114
  requirements:
114
115
  - - ">="
115
116
  - !ruby/object:Gem::Version
116
117
  version: '0'
117
- name: rspec_junit_formatter
118
+ type: :development
118
119
  prerelease: false
120
+ version_requirements: !ruby/object:Gem::Requirement
121
+ requirements:
122
+ - - ">="
123
+ - !ruby/object:Gem::Version
124
+ version: '0'
125
+ - !ruby/object:Gem::Dependency
126
+ name: rspec_junit_formatter
127
+ requirement: !ruby/object:Gem::Requirement
128
+ requirements:
129
+ - - ">="
130
+ - !ruby/object:Gem::Version
131
+ version: '0'
119
132
  type: :development
133
+ prerelease: false
120
134
  version_requirements: !ruby/object:Gem::Requirement
121
135
  requirements:
122
136
  - - ">="
123
137
  - !ruby/object:Gem::Version
124
138
  version: '0'
125
139
  - !ruby/object:Gem::Dependency
140
+ name: mnist-learn
126
141
  requirement: !ruby/object:Gem::Requirement
127
142
  requirements:
128
143
  - - ">="
129
144
  - !ruby/object:Gem::Version
130
145
  version: '0'
131
- name: deep_merge
146
+ type: :development
132
147
  prerelease: false
133
- type: :runtime
134
148
  version_requirements: !ruby/object:Gem::Requirement
135
149
  requirements:
136
150
  - - ">="
137
151
  - !ruby/object:Gem::Version
138
152
  version: '0'
139
153
  - !ruby/object:Gem::Dependency
154
+ name: opencl_ruby_ffi
140
155
  requirement: !ruby/object:Gem::Requirement
141
156
  requirements:
142
157
  - - ">="
143
158
  - !ruby/object:Gem::Version
144
159
  version: '0'
145
- name: concurrent-ruby
160
+ type: :development
146
161
  prerelease: false
147
- type: :runtime
148
162
  version_requirements: !ruby/object:Gem::Requirement
149
163
  requirements:
150
164
  - - ">="
151
165
  - !ruby/object:Gem::Version
152
166
  version: '0'
153
167
  - !ruby/object:Gem::Dependency
168
+ name: deep_merge
154
169
  requirement: !ruby/object:Gem::Requirement
155
170
  requirements:
156
171
  - - ">="
157
172
  - !ruby/object:Gem::Version
158
173
  version: '0'
159
- name: sciruby
174
+ type: :runtime
160
175
  prerelease: false
176
+ version_requirements: !ruby/object:Gem::Requirement
177
+ requirements:
178
+ - - ">="
179
+ - !ruby/object:Gem::Version
180
+ version: '0'
181
+ - !ruby/object:Gem::Dependency
182
+ name: concurrent-ruby
183
+ requirement: !ruby/object:Gem::Requirement
184
+ requirements:
185
+ - - ">="
186
+ - !ruby/object:Gem::Version
187
+ version: '0'
161
188
  type: :runtime
189
+ prerelease: false
162
190
  version_requirements: !ruby/object:Gem::Requirement
163
191
  requirements:
164
192
  - - ">="
165
193
  - !ruby/object:Gem::Version
166
194
  version: '0'
167
195
  - !ruby/object:Gem::Dependency
196
+ name: sciruby
168
197
  requirement: !ruby/object:Gem::Requirement
169
198
  requirements:
170
199
  - - ">="
171
200
  - !ruby/object:Gem::Version
172
201
  version: '0'
173
- name: distribution
202
+ type: :runtime
174
203
  prerelease: false
204
+ version_requirements: !ruby/object:Gem::Requirement
205
+ requirements:
206
+ - - ">="
207
+ - !ruby/object:Gem::Version
208
+ version: '0'
209
+ - !ruby/object:Gem::Dependency
210
+ name: distribution
211
+ requirement: !ruby/object:Gem::Requirement
212
+ requirements:
213
+ - - ">="
214
+ - !ruby/object:Gem::Version
215
+ version: '0'
175
216
  type: :runtime
217
+ prerelease: false
176
218
  version_requirements: !ruby/object:Gem::Requirement
177
219
  requirements:
178
220
  - - ">="
@@ -200,12 +242,48 @@ files:
200
242
  - LICENSE.txt
201
243
  - README.md
202
244
  - Rakefile
245
+ - benchmark/benchmark.rb
203
246
  - bin/console
204
247
  - bin/setup
205
248
  - lib/tensor_stream.rb
206
249
  - lib/tensor_stream/control_flow.rb
207
250
  - lib/tensor_stream/device.rb
251
+ - lib/tensor_stream/evaluator/buffer.rb
208
252
  - lib/tensor_stream/evaluator/evaluator.rb
253
+ - lib/tensor_stream/evaluator/kernels/_bool_operand.cl
254
+ - lib/tensor_stream/evaluator/kernels/_operand.cl
255
+ - lib/tensor_stream/evaluator/kernels/abs.cl
256
+ - lib/tensor_stream/evaluator/kernels/add.cl
257
+ - lib/tensor_stream/evaluator/kernels/argmax.cl
258
+ - lib/tensor_stream/evaluator/kernels/argmin.cl
259
+ - lib/tensor_stream/evaluator/kernels/cast.cl
260
+ - lib/tensor_stream/evaluator/kernels/cond.cl.erb
261
+ - lib/tensor_stream/evaluator/kernels/cos.cl
262
+ - lib/tensor_stream/evaluator/kernels/div.cl.erb
263
+ - lib/tensor_stream/evaluator/kernels/exp.cl
264
+ - lib/tensor_stream/evaluator/kernels/gemm.cl
265
+ - lib/tensor_stream/evaluator/kernels/log.cl
266
+ - lib/tensor_stream/evaluator/kernels/log1p.cl
267
+ - lib/tensor_stream/evaluator/kernels/max.cl
268
+ - lib/tensor_stream/evaluator/kernels/mul.cl
269
+ - lib/tensor_stream/evaluator/kernels/negate.cl
270
+ - lib/tensor_stream/evaluator/kernels/pow.cl
271
+ - lib/tensor_stream/evaluator/kernels/reciprocal.cl
272
+ - lib/tensor_stream/evaluator/kernels/round.cl
273
+ - lib/tensor_stream/evaluator/kernels/sigmoid.cl
274
+ - lib/tensor_stream/evaluator/kernels/sigmoid_grad.cl
275
+ - lib/tensor_stream/evaluator/kernels/sign.cl
276
+ - lib/tensor_stream/evaluator/kernels/sin.cl
277
+ - lib/tensor_stream/evaluator/kernels/sqrt.cl
278
+ - lib/tensor_stream/evaluator/kernels/square.cl
279
+ - lib/tensor_stream/evaluator/kernels/sub.cl
280
+ - lib/tensor_stream/evaluator/kernels/tan.cl
281
+ - lib/tensor_stream/evaluator/kernels/tanh.cl
282
+ - lib/tensor_stream/evaluator/kernels/tanh_grad.cl
283
+ - lib/tensor_stream/evaluator/kernels/where.cl
284
+ - lib/tensor_stream/evaluator/opencl_buffer.rb
285
+ - lib/tensor_stream/evaluator/opencl_evaluator.rb
286
+ - lib/tensor_stream/evaluator/opencl_template_helper.rb
209
287
  - lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb
210
288
  - lib/tensor_stream/evaluator/operation_helpers/math_helper.rb
211
289
  - lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb
@@ -237,6 +315,7 @@ files:
237
315
  - samples/iris.data
238
316
  - samples/linear_regression.rb
239
317
  - samples/logistic_regression.rb
318
+ - samples/nearest_neighbor.rb
240
319
  - tensor_stream.gemspec
241
320
  - test_samples/error.graphml
242
321
  - test_samples/gradient_sample.graphml
@@ -249,7 +328,7 @@ licenses:
249
328
  - MIT
250
329
  metadata:
251
330
  allowed_push_host: https://rubygems.org
252
- post_install_message:
331
+ post_install_message:
253
332
  rdoc_options: []
254
333
  require_paths:
255
334
  - lib
@@ -264,9 +343,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
264
343
  - !ruby/object:Gem::Version
265
344
  version: '0'
266
345
  requirements: []
267
- rubyforge_project:
268
- rubygems_version: 2.6.13
269
- signing_key:
346
+ rubyforge_project:
347
+ rubygems_version: 2.6.11
348
+ signing_key:
270
349
  specification_version: 4
271
350
  summary: A Pure ruby tensorflow implementation
272
351
  test_files: []